diff --git a/orm/db_alias.go b/orm/db_alias.go index 81d0d329..cf6a5935 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -123,14 +123,17 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) - d.RUnlock() if ok { + c.(*stmtDecorator).acquire() + d.RUnlock() return c.(*stmtDecorator), nil } + d.RUnlock() d.Lock() c, ok = d.stmtDecorators.Get(query) if ok { + c.(*stmtDecorator).acquire() d.Unlock() return c.(*stmtDecorator), nil } @@ -141,6 +144,7 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { return nil, err } sd := newStmtDecorator(stmt) + sd.acquire() d.stmtDecorators.Add(query, sd) d.Unlock() @@ -160,7 +164,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.Exec(args...) } @@ -170,7 +174,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.ExecContext(ctx, args...) } @@ -180,7 +184,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.Query(args...) } @@ -190,7 +194,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.QueryContext(ctx, args...) } @@ -200,7 +204,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { if err != nil { panic(err) } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.QueryRow(args...) @@ -211,7 +215,7 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac if err != nil { panic(err) } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.QueryRowContext(ctx, args) } @@ -425,10 +429,13 @@ type stmtDecorator struct { stmt *sql.Stmt } -func (s *stmtDecorator) acquire() *sql.Stmt{ +func (s *stmtDecorator) getStmt() *sql.Stmt { + return s.stmt +} + +func (s *stmtDecorator) acquire() { s.wg.Add(1) s.lastUse = time.Now().Unix() - return s.stmt } func (s *stmtDecorator) release() { @@ -437,8 +444,10 @@ func (s *stmtDecorator) release() { //garbage recycle for stmt func (s *stmtDecorator) destroy() { - s.wg.Wait() - _ = s.stmt.Close() + go func() { + s.wg.Wait() + _ = s.stmt.Close() + }() } func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {