From 44460bc4570b58cefd7b4b1c65e8a1610ceefcbc Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 20 Jul 2020 15:23:17 +0000 Subject: [PATCH] Refactor RegisterDatabase --- orm/db_alias.go | 87 ++++++++----------------------------- orm/orm_alias_adapt_test.go | 46 ++++++++++++++++++++ pkg/common/kv.go | 69 +++++++++++++++++++++++++++++ pkg/common/kv_test.go | 40 +++++++++++++++++ pkg/orm/constant.go | 21 +++++++++ pkg/orm/db_alias.go | 64 ++++++++++++++------------- pkg/orm/db_alias_test.go | 44 +++++++++++++++++++ pkg/orm/models_test.go | 7 ++- 8 files changed, 279 insertions(+), 99 deletions(-) create mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/common/kv.go create mode 100644 pkg/common/kv_test.go create mode 100644 pkg/orm/constant.go create mode 100644 pkg/orm/db_alias_test.go diff --git a/orm/db_alias.go b/orm/db_alias.go index bf6c350c..a84070b4 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,16 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" - lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/common" + orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. @@ -63,7 +68,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora + "ora": DROracle, // https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -119,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -289,82 +294,26 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - err := db.Ping() - if err != nil { - return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) - } - - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - - return al, nil -} - // AddAliasWthDB add a aliasName for the drivename +// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) - return err + return orm2.AddAliasWthDB(aliasName, driverName, db) } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - var ( - err error - db *sql.DB - al *alias - ) - - db, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - - al, err = addAliasWthDB(aliasName, driverName, db) - if err != nil { - goto end - } - - al.DataSource = dataSource - - detectTZ(al) - + kvs := make([]common.KV, 0, 2) for i, v := range params { switch i { case 0: - SetMaxIdleConns(al.Name, v) + kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) case 1: - SetMaxOpenConns(al.Name, v) + kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) + case 2: + kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) } } - -end: - if err != nil { - if db != nil { - db.Close() - } - DebugLog.Println(err.Error()) - } - - return err + return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. @@ -424,7 +373,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -444,7 +393,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -//garbage recycle for stmt +// garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go new file mode 100644 index 00000000..d7724527 --- /dev/null +++ b/orm/orm_alias_adapt_test.go @@ -0,0 +1,46 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" +) + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) + assert.Nil(t, err) + err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) + assert.Nil(t, err) +} diff --git a/pkg/common/kv.go b/pkg/common/kv.go new file mode 100644 index 00000000..508e6b5c --- /dev/null +++ b/pkg/common/kv.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +// KV is common structure to store key-value data. +// when you need something like Pair, you can use this +type KV struct { + Key interface{} + Value interface{} +} + +// KVs will store KV collection as map +type KVs struct { + kvs map[interface{}]interface{} +} + +// GetValueOr check whether this contains the key, +// if the key not found, the default value will be return +func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { + v, ok := kvs.kvs[key] + if ok { + return v + } + return defValue +} + +// Contains will check whether contains the key +func (kvs *KVs) Contains(key interface{}) bool { + _, ok := kvs.kvs[key] + return ok +} + +// IfContains is a functional API that if the key is in KVs, the action will be invoked +func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { + v, ok := kvs.kvs[key] + if ok { + action(v) + } + return kvs +} + +// Put store the value +func (kvs *KVs) Put(key interface{}, value interface{}) *KVs { + kvs.kvs[key] = value + return kvs +} + +// NewKVs will create the *KVs instance +func NewKVs(kvs ...KV) *KVs { + res := &KVs{ + kvs: make(map[interface{}]interface{}, len(kvs)), + } + for _, kv := range kvs { + res.kvs[kv.Key] = kv.Value + } + return res +} diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go new file mode 100644 index 00000000..ed7dc7ef --- /dev/null +++ b/pkg/common/kv_test.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKVs(t *testing.T) { + key := "my-key" + kvs := NewKVs(KV{ + Key: key, + Value: 12, + }) + + assert.True(t, kvs.Contains(key)) + + kvs.IfContains(key, func(value interface{}) { + kvs.Put("my-key1", "") + }) + + assert.True(t, kvs.Contains("my-key1")) + + v := kvs.GetValueOr(key, 13) + assert.Equal(t, 12, v) +} diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go new file mode 100644 index 00000000..14f40a7b --- /dev/null +++ b/pkg/orm/constant.go @@ -0,0 +1,21 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +const ( + MaxIdleConnsKey = "MaxIdleConns" + MaxOpenConnsKey = "MaxOpenConns" + ConnMaxLifetimeKey = "ConnMaxLifetime" +) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index b2a72f56..90c5de3c 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,10 +18,12 @@ import ( "context" "database/sql" "fmt" - lru "github.com/hashicorp/golang-lru" - "reflect" "sync" "time" + + lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/common" ) // DriverType database driver constant int. @@ -63,7 +65,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora + "ora": DROracle, // https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -122,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -//su must call release to release *sql.Stmt after using +// su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -274,16 +276,17 @@ func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interf } type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - DB *DB - DbBaser dbBaser - TZ *time.Location - Engine string + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration + DB *DB + DbBaser dbBaser + TZ *time.Location + Engine string } func detectTZ(al *alias) { @@ -378,13 +381,15 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { var ( err error db *sql.DB al *alias ) + kvs := common.NewKVs(params...) + db, err = sql.Open(driverName, dataSource) if err != nil { err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) @@ -400,14 +405,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) e detectTZ(al) - for i, v := range params { - switch i { - case 0: - SetMaxIdleConns(al.Name, v) - case 1: - SetMaxOpenConns(al.Name, v) - } - } + kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { + SetMaxIdleConns(al.Name, value.(int)) + }).IfContains(MaxOpenConnsKey, func(value interface{}) { + SetMaxOpenConns(al.Name, value.(int)) + }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { + SetConnMaxLifetime(al.Name, value.(time.Duration)) + }) end: if err != nil { @@ -454,10 +458,12 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) { al := getDbAlias(aliasName) al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) - // for tip go 1.2 - if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { - fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) - } +} + +func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { + al := getDbAlias(aliasName) + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) } // GetDB Get *sql.DB from registered database by db alias name. @@ -477,7 +483,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -497,7 +503,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -//garbage recycle for stmt +// garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go new file mode 100644 index 00000000..a0cdcd44 --- /dev/null +++ b/pkg/orm/db_alias_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/common" +) + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{ + Key: MaxIdleConnsKey, + Value: 20, + }, common.KV{ + Key: MaxOpenConnsKey, + Value: 300, + }, common.KV{ + Key: ConnMaxLifetimeKey, + Value: time.Minute, + }) + assert.Nil(t, err) + + al := getDbAlias("test-params") + assert.NotNil(t, al) + assert.Equal(t, al.MaxIdleConns, 20) + assert.Equal(t, al.MaxOpenConns, 300) + assert.Equal(t, al.ConnMaxLifetime, time.Minute) +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 79e926d3..f14ee9cf 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -27,6 +27,8 @@ import ( _ "github.com/mattn/go-sqlite3" // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" + + "github.com/astaxie/beego/pkg/common" ) // A slice string field. @@ -487,7 +489,10 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ + Key:MaxIdleConnsKey, + Value:20, + }) if err != nil{ panic(fmt.Sprintf("can not register database: %v", err))