diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 336ec54b..5f1e3ea3 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -341,12 +340,30 @@ func detectTZ(al *alias) { } func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + if _, ok := dataBaseCache.get(aliasName); ok { + return nil, existErr + } + + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + if !dataBaseCache.add(aliasName, al) { + return nil, existErr + } + + return al, nil +} + +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV)(*alias, error){ kvs := common.NewKVs(params...) var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) if maxStmtCacheSize > 0 { _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) if errC != nil { @@ -379,18 +396,20 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) } - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - detectTZ(al) - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) + kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxIdleConns(al, m) + } + }).IfContains(maxOpenConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxOpenConns(al, m) + } + }).IfContains(connMaxLifetimeKey, func(value interface{}) { + if m, ok := value.(time.Duration); ok { + SetConnMaxLifetime(al, m) + } }) return al, nil @@ -458,21 +477,18 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { } // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) +func SetMaxIdleConns(al *alias, maxIdleConns int) { al.MaxIdleConns = maxIdleConns al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) +func SetMaxOpenConns(al *alias, maxOpenConns int) { al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) } -func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { - al := getDbAlias(aliasName) +func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { al.ConnMaxLifetime = lifeTime al.DB.DB.SetConnMaxLifetime(lifeTime) } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index c8b4aad1..ebf93a86 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -37,10 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: -1, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -50,10 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 0, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -63,10 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 1, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -76,10 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 841, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/orm/db_hints.go b/pkg/orm/db_hints.go index 8900d599..551c7357 100644 --- a/pkg/orm/db_hints.go +++ b/pkg/orm/db_hints.go @@ -40,6 +40,7 @@ const ( maxIdleConnectionsKey = "MaxIdleConnections" maxOpenConnectionsKey = "MaxOpenConnections" connMaxLifetimeKey = "ConnMaxLifetime" + maxStmtCacheSizeKey = "MaxStmtCacheSize" ) var _ common.KV = new(Hint) @@ -59,6 +60,11 @@ func ConnMaxLifetime(v time.Duration) *Hint { return NewHint(connMaxLifetimeKey, v) } +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) *Hint { + return NewHint(maxStmtCacheSizeKey, v) +} + // NewHint return a hint func NewHint(key interface{}, value interface{}) *Hint { return &Hint{ diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go index 9b62a730..13f8ccde 100644 --- a/pkg/orm/db_hints_test.go +++ b/pkg/orm/db_hints_test.go @@ -67,3 +67,10 @@ func TestMaxIdleConnections(t *testing.T) { assert.Equal(t, hint.GetValue(), i) assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) } + +func TestMaxStmtCacheSize(t *testing.T) { + i := 94157 + hint := MaxStmtCacheSize(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) +} \ No newline at end of file diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 441fcfc0..b2f1e693 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -59,10 +59,8 @@ import ( "errors" "fmt" "github.com/astaxie/beego/pkg/common" - lru "github.com/hashicorp/golang-lru" "os" "reflect" - "sync" "time" "github.com/astaxie/beego/logs" @@ -612,51 +610,11 @@ func NewOrm() Ormer { // NewOrmWithDB create a new ormer object with specify *sql.DB for query func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err } - kvs := common.NewKVs(params...) - - var stmtCache *lru.Cache - var stmtCacheSize int - - maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) - if errC != nil { - return nil, errC - } else { - stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize - } - } - - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } - - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - o := new(orm) o.alias = al