From 756df9385ff7d6dabed53110931d4d8d3f8e5d09 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 28 Jul 2020 12:57:19 +0800 Subject: [PATCH] make stmt cache size configurable --- pkg/orm/constant.go | 1 + pkg/orm/db_alias.go | 105 ++++++++++++++++++++++----------------- pkg/orm/db_alias_test.go | 53 ++++++++++++++++++++ pkg/orm/orm.go | 35 +++++++++++-- 4 files changed, 144 insertions(+), 50 deletions(-) diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go index 14f40a7b..54550492 100644 --- a/pkg/orm/constant.go +++ b/pkg/orm/constant.go @@ -18,4 +18,5 @@ const ( MaxIdleConnsKey = "MaxIdleConns" MaxOpenConnsKey = "MaxOpenConns" ConnMaxLifetimeKey = "ConnMaxLifetime" + MaxStmtCacheSize = "MaxStmtCacheSize" ) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index a3f2a0b9..a9961649 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -109,8 +109,9 @@ func (ac *_dbCache) getDefault() (al *alias) { type DB struct { *sync.RWMutex - DB *sql.DB - stmtDecorators *lru.Cache + DB *sql.DB + stmtDecorators *lru.Cache + stmtDecoratorsLimit int } var _ dbQuerier = new(DB) @@ -165,16 +166,14 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error } func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Exec(args...) + return d.ExecContext(context.Background(), query, args...) } func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if d.stmtDecorators == nil { + return d.DB.ExecContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -185,16 +184,14 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) } func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Query(args...) + return d.QueryContext(context.Background(), query, args...) } func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if d.stmtDecorators == nil { + return d.DB.QueryContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -205,24 +202,21 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} } func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRow(args...) - + return d.QueryRowContext(context.Background(), query, args...) } func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + if d.stmtDecorators == nil { + return d.DB.QueryRowContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { panic(err) } stmt := sd.getStmt() defer sd.release() - return stmt.QueryRowContext(ctx, args) + return stmt.QueryRowContext(ctx, args...) } type TxDB struct { @@ -345,14 +339,31 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al := new(alias) al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } if dr, ok := drivers[driverName]; ok { @@ -371,12 +382,22 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) } + detectTZ(al) + + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) + return al, nil } // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) error { + _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } @@ -388,7 +409,6 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common al *alias ) - kvs := common.NewKVs(params...) db, err = sql.Open(driverName, dataSource) if err != nil { @@ -396,23 +416,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common goto end } - al, err = addAliasWthDB(aliasName, driverName, db) + al, err = addAliasWthDB(aliasName, driverName, db, params...) if err != nil { goto end } al.DataSource = dataSource - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - end: if err != nil { if db != nil { @@ -517,9 +527,12 @@ func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { } } -func newStmtDecoratorLruWithEvict() *lru.Cache { - cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { +func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { + cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { value.(*stmtDecorator).destroy() }) - return cache + if err != nil { + return nil, err + } + return cache, nil } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index a0cdcd44..85cdd82f 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -42,3 +42,56 @@ func TestRegisterDataBase(t *testing.T) { assert.Equal(t, al.MaxOpenConns, 300) assert.Equal(t, al.ConnMaxLifetime, time.Minute) } + +func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: -1, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 0, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 1, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 1) +} + +func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 841, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) +} + diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 8ef761f4..441fcfc0 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -58,6 +58,8 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/pkg/common" + lru "github.com/hashicorp/golang-lru" "os" "reflect" "sync" @@ -609,7 +611,7 @@ func NewOrm() Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { var al *alias if dr, ok := drivers[driverName]; ok { @@ -620,16 +622,41 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { return nil, fmt.Errorf("driver name `%s` have not registered", driverName) } + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } detectTZ(al) + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) + o := new(orm) o.alias = al