diff --git a/orm/db_alias.go b/orm/db_alias.go index 094aeb46..81d0d329 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" @@ -106,8 +107,8 @@ func (ac *_dbCache) getDefault() (al *alias) { type DB struct { *sync.RWMutex - DB *sql.DB - stmts map[string]*sql.Stmt + DB *sql.DB + stmtDecorators *lru.Cache } func (d *DB) Begin() (*sql.Tx, error) { @@ -118,19 +119,20 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -func (d *DB) getStmt(query string) (*sql.Stmt, error) { +//su must call release to release *sql.Stmt after using +func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() - c, ok := d.stmts[query] + c, ok := d.stmtDecorators.Get(query) d.RUnlock() if ok { - return c, nil + return c.(*stmtDecorator), nil } d.Lock() - c, ok = d.stmts[query] + c, ok = d.stmtDecorators.Get(query) if ok { d.Unlock() - return c, nil + return c.(*stmtDecorator), nil } stmt, err := d.Prepare(query) @@ -138,9 +140,11 @@ func (d *DB) getStmt(query string) (*sql.Stmt, error) { d.Unlock() return nil, err } - d.stmts[query] = stmt + sd := newStmtDecorator(stmt) + d.stmtDecorators.Add(query, sd) d.Unlock() - return stmt, nil + + return sd, nil } func (d *DB) Prepare(query string) (*sql.Stmt, error) { @@ -152,52 +156,63 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error } func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - stmt, err := d.getStmt(query) + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err } + stmt := sd.acquire() + defer sd.release() return stmt.Exec(args...) } func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - stmt, err := d.getStmt(query) + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err } + stmt := sd.acquire() + defer sd.release() return stmt.ExecContext(ctx, args...) } func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - stmt, err := d.getStmt(query) + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err } + stmt := sd.acquire() + defer sd.release() return stmt.Query(args...) } func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - stmt, err := d.getStmt(query) + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err } + stmt := sd.acquire() + defer sd.release() return stmt.QueryContext(ctx, args...) } func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - stmt, err := d.getStmt(query) + sd, err := d.getStmtDecorator(query) if err != nil { panic(err) } + stmt := sd.acquire() + defer sd.release() return stmt.QueryRow(args...) } func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - - stmt, err := d.getStmt(query) + sd, err := d.getStmtDecorator(query) if err != nil { panic(err) } + stmt := sd.acquire() + defer sd.release() return stmt.QueryRowContext(ctx, args) } @@ -275,9 +290,9 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmts: make(map[string]*sql.Stmt), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), } if dr, ok := drivers[driverName]; ok { @@ -403,3 +418,39 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) } + +type stmtDecorator struct { + wg sync.WaitGroup + lastUse int64 + stmt *sql.Stmt +} + +func (s *stmtDecorator) acquire() *sql.Stmt{ + s.wg.Add(1) + s.lastUse = time.Now().Unix() + return s.stmt +} + +func (s *stmtDecorator) release() { + s.wg.Done() +} + +//garbage recycle for stmt +func (s *stmtDecorator) destroy() { + s.wg.Wait() + _ = s.stmt.Close() +} + +func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { + return &stmtDecorator{ + stmt: sqlStmt, + lastUse: time.Now().Unix(), + } +} + +func newStmtDecoratorLruWithEvict() *lru.Cache { + cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { + value.(*stmtDecorator).destroy() + }) + return cache +} diff --git a/orm/orm.go b/orm/orm.go index 11e38fd9..0551b1cd 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -559,9 +559,9 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmts: make(map[string]*sql.Stmt), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), } detectTZ(al)