1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-24 18:40:56 +00:00

Refactor RegisterDatabase

This commit is contained in:
Ming Deng 2020-07-20 15:23:17 +00:00
parent 41feb3a711
commit 44460bc457
8 changed files with 279 additions and 99 deletions

View File

@ -12,16 +12,21 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Deprecated: we will remove this package, please using pkg/orm
package orm package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
lru "github.com/hashicorp/golang-lru"
"reflect" "reflect"
"sync" "sync"
"time" "time"
lru "github.com/hashicorp/golang-lru"
"github.com/astaxie/beego/pkg/common"
orm2 "github.com/astaxie/beego/pkg/orm"
) )
// DriverType database driver constant int. // DriverType database driver constant int.
@ -63,7 +68,7 @@ var (
"tidb": DRTiDB, "tidb": DRTiDB,
"oracle": DROracle, "oracle": DROracle,
"oci8": DROracle, // github.com/mattn/go-oci8 "oci8": DROracle, // github.com/mattn/go-oci8
"ora": DROracle, //https://github.com/rana/ora "ora": DROracle, // https://github.com/rana/ora
} }
dbBasers = map[DriverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(), DRMySQL: newdbBaseMysql(),
@ -119,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
return d.DB.BeginTx(ctx, opts) return d.DB.BeginTx(ctx, opts)
} }
//su must call release to release *sql.Stmt after using // su must call release to release *sql.Stmt after using
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
d.RLock() d.RLock()
c, ok := d.stmtDecorators.Get(query) c, ok := d.stmtDecorators.Get(query)
@ -289,82 +294,26 @@ func detectTZ(al *alias) {
} }
} }
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmtDecorators: newStmtDecoratorLruWithEvict(),
}
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
if !dataBaseCache.add(aliasName, al) {
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
}
// AddAliasWthDB add a aliasName for the drivename // AddAliasWthDB add a aliasName for the drivename
// Deprecated: please using pkg/orm
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db) return orm2.AddAliasWthDB(aliasName, driverName, db)
return err
} }
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var ( kvs := make([]common.KV, 0, 2)
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db)
if err != nil {
goto end
}
al.DataSource = dataSource
detectTZ(al)
for i, v := range params { for i, v := range params {
switch i { switch i {
case 0: case 0:
SetMaxIdleConns(al.Name, v) kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v})
case 1: case 1:
SetMaxOpenConns(al.Name, v) kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v})
case 2:
kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond})
} }
} }
return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...)
end:
if err != nil {
if db != nil {
db.Close()
}
DebugLog.Println(err.Error())
}
return err
} }
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
@ -424,7 +373,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} }
type stmtDecorator struct { type stmtDecorator struct {
wg sync.WaitGroup wg sync.WaitGroup
stmt *sql.Stmt stmt *sql.Stmt
} }
@ -444,7 +393,7 @@ func (s *stmtDecorator) release() {
s.wg.Done() s.wg.Done()
} }
//garbage recycle for stmt // garbage recycle for stmt
func (s *stmtDecorator) destroy() { func (s *stmtDecorator) destroy() {
go func() { go func() {
s.wg.Wait() s.wg.Wait()

View File

@ -0,0 +1,46 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"os"
"testing"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
)
var DBARGS = struct {
Driver string
Source string
Debug string
}{
os.Getenv("ORM_DRIVER"),
os.Getenv("ORM_SOURCE"),
os.Getenv("ORM_DEBUG"),
}
func TestRegisterDataBase(t *testing.T) {
err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source)
assert.Nil(t, err)
err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20)
assert.Nil(t, err)
err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300)
assert.Nil(t, err)
err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000)
assert.Nil(t, err)
}

69
pkg/common/kv.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
// KV is common structure to store key-value data.
// when you need something like Pair, you can use this
type KV struct {
Key interface{}
Value interface{}
}
// KVs will store KV collection as map
type KVs struct {
kvs map[interface{}]interface{}
}
// GetValueOr check whether this contains the key,
// if the key not found, the default value will be return
func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} {
v, ok := kvs.kvs[key]
if ok {
return v
}
return defValue
}
// Contains will check whether contains the key
func (kvs *KVs) Contains(key interface{}) bool {
_, ok := kvs.kvs[key]
return ok
}
// IfContains is a functional API that if the key is in KVs, the action will be invoked
func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs {
v, ok := kvs.kvs[key]
if ok {
action(v)
}
return kvs
}
// Put store the value
func (kvs *KVs) Put(key interface{}, value interface{}) *KVs {
kvs.kvs[key] = value
return kvs
}
// NewKVs will create the *KVs instance
func NewKVs(kvs ...KV) *KVs {
res := &KVs{
kvs: make(map[interface{}]interface{}, len(kvs)),
}
for _, kv := range kvs {
res.kvs[kv.Key] = kv.Value
}
return res
}

40
pkg/common/kv_test.go Normal file
View File

@ -0,0 +1,40 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestKVs(t *testing.T) {
key := "my-key"
kvs := NewKVs(KV{
Key: key,
Value: 12,
})
assert.True(t, kvs.Contains(key))
kvs.IfContains(key, func(value interface{}) {
kvs.Put("my-key1", "")
})
assert.True(t, kvs.Contains("my-key1"))
v := kvs.GetValueOr(key, 13)
assert.Equal(t, 12, v)
}

21
pkg/orm/constant.go Normal file
View File

@ -0,0 +1,21 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
const (
MaxIdleConnsKey = "MaxIdleConns"
MaxOpenConnsKey = "MaxOpenConns"
ConnMaxLifetimeKey = "ConnMaxLifetime"
)

View File

@ -18,10 +18,12 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
lru "github.com/hashicorp/golang-lru"
"reflect"
"sync" "sync"
"time" "time"
lru "github.com/hashicorp/golang-lru"
"github.com/astaxie/beego/pkg/common"
) )
// DriverType database driver constant int. // DriverType database driver constant int.
@ -63,7 +65,7 @@ var (
"tidb": DRTiDB, "tidb": DRTiDB,
"oracle": DROracle, "oracle": DROracle,
"oci8": DROracle, // github.com/mattn/go-oci8 "oci8": DROracle, // github.com/mattn/go-oci8
"ora": DROracle, //https://github.com/rana/ora "ora": DROracle, // https://github.com/rana/ora
} }
dbBasers = map[DriverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(), DRMySQL: newdbBaseMysql(),
@ -122,7 +124,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
return d.DB.BeginTx(ctx, opts) return d.DB.BeginTx(ctx, opts)
} }
//su must call release to release *sql.Stmt after using // su must call release to release *sql.Stmt after using
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
d.RLock() d.RLock()
c, ok := d.stmtDecorators.Get(query) c, ok := d.stmtDecorators.Get(query)
@ -274,16 +276,17 @@ func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interf
} }
type alias struct { type alias struct {
Name string Name string
Driver DriverType Driver DriverType
DriverName string DriverName string
DataSource string DataSource string
MaxIdleConns int MaxIdleConns int
MaxOpenConns int MaxOpenConns int
DB *DB ConnMaxLifetime time.Duration
DbBaser dbBaser DB *DB
TZ *time.Location DbBaser dbBaser
Engine string TZ *time.Location
Engine string
} }
func detectTZ(al *alias) { func detectTZ(al *alias) {
@ -378,13 +381,15 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
} }
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error {
var ( var (
err error err error
db *sql.DB db *sql.DB
al *alias al *alias
) )
kvs := common.NewKVs(params...)
db, err = sql.Open(driverName, dataSource) db, err = sql.Open(driverName, dataSource)
if err != nil { if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
@ -400,14 +405,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) e
detectTZ(al) detectTZ(al)
for i, v := range params { kvs.IfContains(MaxIdleConnsKey, func(value interface{}) {
switch i { SetMaxIdleConns(al.Name, value.(int))
case 0: }).IfContains(MaxOpenConnsKey, func(value interface{}) {
SetMaxIdleConns(al.Name, v) SetMaxOpenConns(al.Name, value.(int))
case 1: }).IfContains(ConnMaxLifetimeKey, func(value interface{}) {
SetMaxOpenConns(al.Name, v) SetConnMaxLifetime(al.Name, value.(time.Duration))
} })
}
end: end:
if err != nil { if err != nil {
@ -454,10 +458,12 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns al.MaxOpenConns = maxOpenConns
al.DB.DB.SetMaxOpenConns(maxOpenConns) al.DB.DB.SetMaxOpenConns(maxOpenConns)
// for tip go 1.2 }
if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) {
} al := getDbAlias(aliasName)
al.ConnMaxLifetime = lifeTime
al.DB.DB.SetConnMaxLifetime(lifeTime)
} }
// GetDB Get *sql.DB from registered database by db alias name. // GetDB Get *sql.DB from registered database by db alias name.
@ -477,7 +483,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} }
type stmtDecorator struct { type stmtDecorator struct {
wg sync.WaitGroup wg sync.WaitGroup
stmt *sql.Stmt stmt *sql.Stmt
} }
@ -497,7 +503,7 @@ func (s *stmtDecorator) release() {
s.wg.Done() s.wg.Done()
} }
//garbage recycle for stmt // garbage recycle for stmt
func (s *stmtDecorator) destroy() { func (s *stmtDecorator) destroy() {
go func() { go func() {
s.wg.Wait() s.wg.Wait()

44
pkg/orm/db_alias_test.go Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/common"
)
func TestRegisterDataBase(t *testing.T) {
err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{
Key: MaxIdleConnsKey,
Value: 20,
}, common.KV{
Key: MaxOpenConnsKey,
Value: 300,
}, common.KV{
Key: ConnMaxLifetimeKey,
Value: time.Minute,
})
assert.Nil(t, err)
al := getDbAlias("test-params")
assert.NotNil(t, al)
assert.Equal(t, al.MaxIdleConns, 20)
assert.Equal(t, al.MaxOpenConns, 300)
assert.Equal(t, al.ConnMaxLifetime, time.Minute)
}

View File

@ -27,6 +27,8 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
// As tidb can't use go get, so disable the tidb testing now // As tidb can't use go get, so disable the tidb testing now
// _ "github.com/pingcap/tidb" // _ "github.com/pingcap/tidb"
"github.com/astaxie/beego/pkg/common"
) )
// A slice string field. // A slice string field.
@ -487,7 +489,10 @@ func init() {
os.Exit(2) os.Exit(2)
} }
err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{
Key:MaxIdleConnsKey,
Value:20,
})
if err != nil{ if err != nil{
panic(fmt.Sprintf("can not register database: %v", err)) panic(fmt.Sprintf("can not register database: %v", err))