From 6d9862b924166430bb603127db53a713b4e9c48e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 23 Jun 2020 22:29:41 +0800 Subject: [PATCH] acquire() in Lock --- orm/db_alias.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index a38f1d60..36625f0a 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -123,14 +123,17 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) - d.RUnlock() if ok { + c.(*stmtDecorator).acquire() + d.RUnlock() return c.(*stmtDecorator), nil } + d.RUnlock() d.Lock() c, ok = d.stmtDecorators.Get(query) if ok { + c.(*stmtDecorator).acquire() d.Unlock() return c.(*stmtDecorator), nil } @@ -141,6 +144,7 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { return nil, err } sd := newStmtDecorator(stmt) + sd.acquire() d.stmtDecorators.Add(query, sd) d.Unlock() @@ -160,7 +164,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.Exec(args...) } @@ -170,7 +174,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.ExecContext(ctx, args...) } @@ -180,7 +184,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.Query(args...) } @@ -190,7 +194,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} if err != nil { return nil, err } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.QueryContext(ctx, args...) } @@ -200,7 +204,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { if err != nil { panic(err) } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.QueryRow(args...) @@ -211,7 +215,7 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac if err != nil { panic(err) } - stmt := sd.acquire() + stmt := sd.getStmt() defer sd.release() return stmt.QueryRowContext(ctx, args) } @@ -424,10 +428,13 @@ type stmtDecorator struct { stmt *sql.Stmt } -func (s *stmtDecorator) acquire() *sql.Stmt{ +func (s *stmtDecorator) getStmt() *sql.Stmt { + return s.stmt +} + +func (s *stmtDecorator) acquire() { s.wg.Add(1) s.lastUse = time.Now().Unix() - return s.stmt } func (s *stmtDecorator) release() {