From a28d294a83699d6fc8cfb5ca019ac33b0f975313 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 23 Jun 2020 13:46:19 +0800 Subject: [PATCH] upgrade acquire method return *sql.Stmt --- orm/db_alias.go | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index a69259ed..a38f1d60 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -125,7 +125,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { c, ok := d.stmtDecorators.Get(query) d.RUnlock() if ok { - c.(*stmtDecorator).acquire() return c.(*stmtDecorator), nil } @@ -133,7 +132,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { c, ok = d.stmtDecorators.Get(query) if ok { d.Unlock() - c.(*stmtDecorator).acquire() return c.(*stmtDecorator), nil } @@ -146,7 +144,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.stmtDecorators.Add(query, sd) d.Unlock() - sd.acquire() return sd, nil } @@ -163,7 +160,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { if err != nil { return nil, err } - stmt := sd.getStmt() + stmt := sd.acquire() defer sd.release() return stmt.Exec(args...) } @@ -173,7 +170,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) if err != nil { return nil, err } - stmt := sd.getStmt() + stmt := sd.acquire() defer sd.release() return stmt.ExecContext(ctx, args...) } @@ -183,7 +180,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { if err != nil { return nil, err } - stmt := sd.getStmt() + stmt := sd.acquire() defer sd.release() return stmt.Query(args...) } @@ -193,7 +190,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} if err != nil { return nil, err } - stmt := sd.getStmt() + stmt := sd.acquire() defer sd.release() return stmt.QueryContext(ctx, args...) } @@ -203,7 +200,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { if err != nil { panic(err) } - stmt := sd.getStmt() + stmt := sd.acquire() defer sd.release() return stmt.QueryRow(args...) @@ -214,7 +211,7 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac if err != nil { panic(err) } - stmt := sd.getStmt() + stmt := sd.acquire() defer sd.release() return stmt.QueryRowContext(ctx, args) } @@ -423,15 +420,14 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { type stmtDecorator struct { wg sync.WaitGroup + lastUse int64 stmt *sql.Stmt } -func (s *stmtDecorator) getStmt() *sql.Stmt { - return s.stmt -} - -func (s *stmtDecorator) acquire() { +func (s *stmtDecorator) acquire() *sql.Stmt{ s.wg.Add(1) + s.lastUse = time.Now().Unix() + return s.stmt } func (s *stmtDecorator) release() { @@ -447,6 +443,7 @@ func (s *stmtDecorator) destroy() { func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { return &stmtDecorator{ stmt: sqlStmt, + lastUse: time.Now().Unix(), } }