From cfff0f3b46e18d1af7f7b4e9a2a9dad41e451f86 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sat, 25 Jul 2020 00:00:34 +0800 Subject: [PATCH 1/6] fix memory leak of request context --- context/input.go | 32 ++++++++++++++++++++++++++++++-- context/output.go | 6 ++++++ router.go | 4 ++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/context/input.go b/context/input.go index 7b522c36..da066a86 100644 --- a/context/input.go +++ b/context/input.go @@ -323,8 +323,36 @@ func (input *BeegoInput) SetParam(key, val string) { // This function is used to clear parameters so they may be reset between filter // passes. func (input *BeegoInput) ResetParams() { - input.pnames = input.pnames[:0] - input.pvalues = input.pvalues[:0] + if len(input.pnames) > 0 { + input.pnames = input.pnames[:0] + } + if len(input.pvalues) > 0 { + input.pvalues = input.pvalues[:0] + } +} + +// ResetData: reset data +func (input *BeegoInput) ResetData() { + input.dataLock.Lock() + if input.data != nil { + input.data = nil + } + input.dataLock.Unlock() +} + +// ResetBody: reset body +func (input *BeegoInput) ResetBody() { + if len(input.RequestBody) > 0 { + input.RequestBody = []byte{} + } +} + +// Clear: clear all data in input +func (input *BeegoInput) Clear() { + input.ResetParams() + input.ResetData() + input.ResetBody() + } // Query returns input data item string by a given string. diff --git a/context/output.go b/context/output.go index 238dcf45..eaa75720 100644 --- a/context/output.go +++ b/context/output.go @@ -50,9 +50,15 @@ func NewOutput() *BeegoOutput { // Reset init BeegoOutput func (output *BeegoOutput) Reset(ctx *Context) { output.Context = ctx + output.Clear() +} + +// Clear: clear all data in output +func (output *BeegoOutput) Clear() { output.Status = 0 } + // Header sets response header item string via given key. func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) diff --git a/router.go b/router.go index a993a1af..af0a7ceb 100644 --- a/router.go +++ b/router.go @@ -319,6 +319,10 @@ func (p *ControllerRegister) GetContext() *beecontext.Context { // GiveBackContext put the ctx into pool so that it could be reuse func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + // clear input cached data + ctx.Input.Clear() + // clear output cached data + ctx.Output.Clear() p.pool.Put(ctx) } From 16d71893cdd4e914d740223af2bd983e28a1f96c Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 26 Jul 2020 19:40:13 +0800 Subject: [PATCH 2/6] orm.rawPrepare support FlatParams --- pkg/orm/orm_raw.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 5e05eded..2f214f93 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -32,7 +32,8 @@ func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { if o.closed { return nil, ErrStmtClosed } - return o.stmt.Exec(args...) + flatParams := getFlatParams(nil, args, o.rs.orm.alias.TZ) + return o.stmt.Exec(flatParams...) } func (o *rawPrepare) Close() error { From 2386c9c80d1d38f44f23c134c748810a9f1add0a Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 26 Jul 2020 22:37:42 +0800 Subject: [PATCH 3/6] delete useless if-stmt --- context/input.go | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/context/input.go b/context/input.go index da066a86..fb01648f 100644 --- a/context/input.go +++ b/context/input.go @@ -323,28 +323,20 @@ func (input *BeegoInput) SetParam(key, val string) { // This function is used to clear parameters so they may be reset between filter // passes. func (input *BeegoInput) ResetParams() { - if len(input.pnames) > 0 { - input.pnames = input.pnames[:0] - } - if len(input.pvalues) > 0 { - input.pvalues = input.pvalues[:0] - } + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] } // ResetData: reset data func (input *BeegoInput) ResetData() { input.dataLock.Lock() - if input.data != nil { - input.data = nil - } + input.data = nil input.dataLock.Unlock() } // ResetBody: reset body func (input *BeegoInput) ResetBody() { - if len(input.RequestBody) > 0 { - input.RequestBody = []byte{} - } + input.RequestBody = []byte{} } // Clear: clear all data in input From 756df9385ff7d6dabed53110931d4d8d3f8e5d09 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 28 Jul 2020 12:57:19 +0800 Subject: [PATCH 4/6] make stmt cache size configurable --- pkg/orm/constant.go | 1 + pkg/orm/db_alias.go | 105 ++++++++++++++++++++++----------------- pkg/orm/db_alias_test.go | 53 ++++++++++++++++++++ pkg/orm/orm.go | 35 +++++++++++-- 4 files changed, 144 insertions(+), 50 deletions(-) diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go index 14f40a7b..54550492 100644 --- a/pkg/orm/constant.go +++ b/pkg/orm/constant.go @@ -18,4 +18,5 @@ const ( MaxIdleConnsKey = "MaxIdleConns" MaxOpenConnsKey = "MaxOpenConns" ConnMaxLifetimeKey = "ConnMaxLifetime" + MaxStmtCacheSize = "MaxStmtCacheSize" ) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index a3f2a0b9..a9961649 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -109,8 +109,9 @@ func (ac *_dbCache) getDefault() (al *alias) { type DB struct { *sync.RWMutex - DB *sql.DB - stmtDecorators *lru.Cache + DB *sql.DB + stmtDecorators *lru.Cache + stmtDecoratorsLimit int } var _ dbQuerier = new(DB) @@ -165,16 +166,14 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error } func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Exec(args...) + return d.ExecContext(context.Background(), query, args...) } func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if d.stmtDecorators == nil { + return d.DB.ExecContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -185,16 +184,14 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) } func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Query(args...) + return d.QueryContext(context.Background(), query, args...) } func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if d.stmtDecorators == nil { + return d.DB.QueryContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -205,24 +202,21 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} } func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRow(args...) - + return d.QueryRowContext(context.Background(), query, args...) } func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + if d.stmtDecorators == nil { + return d.DB.QueryRowContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { panic(err) } stmt := sd.getStmt() defer sd.release() - return stmt.QueryRowContext(ctx, args) + return stmt.QueryRowContext(ctx, args...) } type TxDB struct { @@ -345,14 +339,31 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al := new(alias) al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } if dr, ok := drivers[driverName]; ok { @@ -371,12 +382,22 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) } + detectTZ(al) + + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) + return al, nil } // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) error { + _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } @@ -388,7 +409,6 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common al *alias ) - kvs := common.NewKVs(params...) db, err = sql.Open(driverName, dataSource) if err != nil { @@ -396,23 +416,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common goto end } - al, err = addAliasWthDB(aliasName, driverName, db) + al, err = addAliasWthDB(aliasName, driverName, db, params...) if err != nil { goto end } al.DataSource = dataSource - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - end: if err != nil { if db != nil { @@ -517,9 +527,12 @@ func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { } } -func newStmtDecoratorLruWithEvict() *lru.Cache { - cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { +func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { + cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { value.(*stmtDecorator).destroy() }) - return cache + if err != nil { + return nil, err + } + return cache, nil } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index a0cdcd44..85cdd82f 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -42,3 +42,56 @@ func TestRegisterDataBase(t *testing.T) { assert.Equal(t, al.MaxOpenConns, 300) assert.Equal(t, al.ConnMaxLifetime, time.Minute) } + +func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: -1, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 0, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 1, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 1) +} + +func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxStmtCacheSize, + Value: 841, + }) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) +} + diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 8ef761f4..441fcfc0 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -58,6 +58,8 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/pkg/common" + lru "github.com/hashicorp/golang-lru" "os" "reflect" "sync" @@ -609,7 +611,7 @@ func NewOrm() Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { var al *alias if dr, ok := drivers[driverName]; ok { @@ -620,16 +622,41 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { return nil, fmt.Errorf("driver name `%s` have not registered", driverName) } + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } detectTZ(al) + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) + o := new(orm) o.alias = al From e8facd28f544bb4edbe2a533ca62e6878e025958 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 28 Jul 2020 17:37:36 +0800 Subject: [PATCH 5/6] wrap kv --- pkg/common/kv.go | 23 +++++++++++--- pkg/common/kv_test.go | 2 +- pkg/orm/constant.go | 21 ------------ pkg/orm/db_alias.go | 11 ++++--- pkg/orm/db_alias_test.go | 16 +++------- pkg/orm/db_hints.go | 68 +++++++++++++++++++++++++++++++++++++++ pkg/orm/db_hints_test.go | 69 ++++++++++++++++++++++++++++++++++++++++ pkg/orm/models_test.go | 6 +--- 8 files changed, 168 insertions(+), 48 deletions(-) delete mode 100644 pkg/orm/constant.go create mode 100644 pkg/orm/db_hints.go create mode 100644 pkg/orm/db_hints_test.go diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 508e6b5c..86a50132 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -14,14 +14,29 @@ package common -// KV is common structure to store key-value data. +type KV interface { + GetKey() interface{} + GetValue() interface{} +} + +// SimpleKV is common structure to store key-value data. // when you need something like Pair, you can use this -type KV struct { +type SimpleKV struct { Key interface{} Value interface{} } -// KVs will store KV collection as map +var _ KV = new(SimpleKV) + +func (s *SimpleKV) GetKey() interface{} { + return s.Key +} + +func (s *SimpleKV) GetValue() interface{} { + return s.Value +} + +// KVs will store SimpleKV collection as map type KVs struct { kvs map[interface{}]interface{} } @@ -63,7 +78,7 @@ func NewKVs(kvs ...KV) *KVs { kvs: make(map[interface{}]interface{}, len(kvs)), } for _, kv := range kvs { - res.kvs[kv.Key] = kv.Value + res.kvs[kv.GetKey()] = kv.GetValue() } return res } diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index 45adf5ff..275c6753 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -22,7 +22,7 @@ import ( func TestKVs(t *testing.T) { key := "my-key" - kvs := NewKVs(KV{ + kvs := NewKVs(&SimpleKV{ Key: key, Value: 12, }) diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go deleted file mode 100644 index 14f40a7b..00000000 --- a/pkg/orm/constant.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -const ( - MaxIdleConnsKey = "MaxIdleConns" - MaxOpenConnsKey = "MaxOpenConns" - ConnMaxLifetimeKey = "ConnMaxLifetime" -) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index a3f2a0b9..53c668ae 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -381,14 +382,14 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { var ( err error db *sql.DB al *alias ) - kvs := common.NewKVs(params...) + kvs := common.NewKVs(hints...) db, err = sql.Open(driverName, dataSource) if err != nil { @@ -405,11 +406,11 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common detectTZ(al) - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { + }).IfContains(maxOpenConnectionsKey, func(value interface{}) { SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + }).IfContains(connMaxLifetimeKey, func(value interface{}) { SetConnMaxLifetime(al.Name, value.(time.Duration)) }) diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index a0cdcd44..28dd2e6b 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -19,21 +19,13 @@ import ( "time" "github.com/stretchr/testify/assert" - - "github.com/astaxie/beego/pkg/common" ) func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxIdleConnsKey, - Value: 20, - }, common.KV{ - Key: MaxOpenConnsKey, - Value: 300, - }, common.KV{ - Key: ConnMaxLifetimeKey, - Value: time.Minute, - }) + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") diff --git a/pkg/orm/db_hints.go b/pkg/orm/db_hints.go new file mode 100644 index 00000000..8900d599 --- /dev/null +++ b/pkg/orm/db_hints.go @@ -0,0 +1,68 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/common" + "time" +) + +type Hint struct { + key interface{} + value interface{} +} + +var _ common.KV = new(Hint) + +// GetKey return key +func (s *Hint) GetKey() interface{} { + return s.key +} + +// GetValue return value +func (s *Hint) GetValue() interface{} { + return s.value +} + +const ( + maxIdleConnectionsKey = "MaxIdleConnections" + maxOpenConnectionsKey = "MaxOpenConnections" + connMaxLifetimeKey = "ConnMaxLifetime" +) + +var _ common.KV = new(Hint) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(v int) *Hint { + return NewHint(maxIdleConnectionsKey, v) +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(v int) *Hint { + return NewHint(maxOpenConnectionsKey, v) +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) *Hint { + return NewHint(connMaxLifetimeKey, v) +} + +// NewHint return a hint +func NewHint(key interface{}, value interface{}) *Hint { + return &Hint{ + key: key, + value: value, + } +} diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go new file mode 100644 index 00000000..9b62a730 --- /dev/null +++ b/pkg/orm/db_hints_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewHint_time(t *testing.T) { + key := "qweqwe" + value := time.Second + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_int(t *testing.T) { + key := "qweqwe" + value := 281230 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_float(t *testing.T) { + key := "qweqwe" + value := 21.2459753 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestMaxOpenConnections(t *testing.T) { + i := 887423 + hint := MaxOpenConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey) +} + +func TestConnMaxLifetime(t *testing.T) { + i := time.Hour + hint := ConnMaxLifetime(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), connMaxLifetimeKey) +} + +func TestMaxIdleConnections(t *testing.T) { + i := 42316 + hint := MaxIdleConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 4c00050d..ae166dc7 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -28,7 +28,6 @@ import ( // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" - "github.com/astaxie/beego/pkg/common" ) // A slice string field. @@ -489,10 +488,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxIdleConnsKey, - Value: 20, - }) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) From e87de70c6deb0eb13f5b48596511240da9e2551e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Wed, 29 Jul 2020 00:45:41 +0800 Subject: [PATCH 6/6] adapt wrapping kv --- pkg/orm/db_alias.go | 52 ++++++++++++++++++++++++++-------------- pkg/orm/db_alias_test.go | 20 ++++------------ pkg/orm/db_hints.go | 6 +++++ pkg/orm/db_hints_test.go | 7 ++++++ pkg/orm/orm.go | 48 +++---------------------------------- 5 files changed, 54 insertions(+), 79 deletions(-) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 336ec54b..5f1e3ea3 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -341,12 +340,30 @@ func detectTZ(al *alias) { } func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + if _, ok := dataBaseCache.get(aliasName); ok { + return nil, existErr + } + + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + if !dataBaseCache.add(aliasName, al) { + return nil, existErr + } + + return al, nil +} + +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV)(*alias, error){ kvs := common.NewKVs(params...) var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) + maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) if maxStmtCacheSize > 0 { _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) if errC != nil { @@ -379,18 +396,20 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) } - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - detectTZ(al) - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) + kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxIdleConns(al, m) + } + }).IfContains(maxOpenConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxOpenConns(al, m) + } + }).IfContains(connMaxLifetimeKey, func(value interface{}) { + if m, ok := value.(time.Duration); ok { + SetConnMaxLifetime(al, m) + } }) return al, nil @@ -458,21 +477,18 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { } // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) +func SetMaxIdleConns(al *alias, maxIdleConns int) { al.MaxIdleConns = maxIdleConns al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) +func SetMaxOpenConns(al *alias, maxOpenConns int) { al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) } -func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { - al := getDbAlias(aliasName) +func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { al.ConnMaxLifetime = lifeTime al.DB.DB.SetConnMaxLifetime(lifeTime) } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index c8b4aad1..ebf93a86 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -37,10 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: -1, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -50,10 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 0, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -63,10 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 1, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -76,10 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxStmtCacheSize, - Value: 841, - }) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/orm/db_hints.go b/pkg/orm/db_hints.go index 8900d599..551c7357 100644 --- a/pkg/orm/db_hints.go +++ b/pkg/orm/db_hints.go @@ -40,6 +40,7 @@ const ( maxIdleConnectionsKey = "MaxIdleConnections" maxOpenConnectionsKey = "MaxOpenConnections" connMaxLifetimeKey = "ConnMaxLifetime" + maxStmtCacheSizeKey = "MaxStmtCacheSize" ) var _ common.KV = new(Hint) @@ -59,6 +60,11 @@ func ConnMaxLifetime(v time.Duration) *Hint { return NewHint(connMaxLifetimeKey, v) } +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) *Hint { + return NewHint(maxStmtCacheSizeKey, v) +} + // NewHint return a hint func NewHint(key interface{}, value interface{}) *Hint { return &Hint{ diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go index 9b62a730..13f8ccde 100644 --- a/pkg/orm/db_hints_test.go +++ b/pkg/orm/db_hints_test.go @@ -67,3 +67,10 @@ func TestMaxIdleConnections(t *testing.T) { assert.Equal(t, hint.GetValue(), i) assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) } + +func TestMaxStmtCacheSize(t *testing.T) { + i := 94157 + hint := MaxStmtCacheSize(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) +} \ No newline at end of file diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 441fcfc0..b2f1e693 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -59,10 +59,8 @@ import ( "errors" "fmt" "github.com/astaxie/beego/pkg/common" - lru "github.com/hashicorp/golang-lru" "os" "reflect" - "sync" "time" "github.com/astaxie/beego/logs" @@ -612,51 +610,11 @@ func NewOrm() Ormer { // NewOrmWithDB create a new ormer object with specify *sql.DB for query func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err } - kvs := common.NewKVs(params...) - - var stmtCache *lru.Cache - var stmtCacheSize int - - maxStmtCacheSize := kvs.GetValueOr(MaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) - if errC != nil { - return nil, errC - } else { - stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize - } - } - - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } - - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - o := new(orm) o.alias = al