diff --git a/context/input.go b/context/input.go index 385549c1..c2c1c63d 100644 --- a/context/input.go +++ b/context/input.go @@ -327,6 +327,26 @@ func (input *BeegoInput) ResetParams() { input.pvalues = input.pvalues[:0] } +// ResetData: reset data +func (input *BeegoInput) ResetData() { + input.dataLock.Lock() + input.data = nil + input.dataLock.Unlock() +} + +// ResetBody: reset body +func (input *BeegoInput) ResetBody() { + input.RequestBody = []byte{} +} + +// Clear: clear all data in input +func (input *BeegoInput) Clear() { + input.ResetParams() + input.ResetData() + input.ResetBody() + +} + // Query returns input data item string by a given string. func (input *BeegoInput) Query(key string) string { if val := input.Param(key); val != "" { diff --git a/context/output.go b/context/output.go index 238dcf45..eaa75720 100644 --- a/context/output.go +++ b/context/output.go @@ -50,9 +50,15 @@ func NewOutput() *BeegoOutput { // Reset init BeegoOutput func (output *BeegoOutput) Reset(ctx *Context) { output.Context = ctx + output.Clear() +} + +// Clear: clear all data in output +func (output *BeegoOutput) Clear() { output.Status = 0 } + // Header sets response header item string via given key. func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 508e6b5c..86a50132 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -14,14 +14,29 @@ package common -// KV is common structure to store key-value data. +type KV interface { + GetKey() interface{} + GetValue() interface{} +} + +// SimpleKV is common structure to store key-value data. // when you need something like Pair, you can use this -type KV struct { +type SimpleKV struct { Key interface{} Value interface{} } -// KVs will store KV collection as map +var _ KV = new(SimpleKV) + +func (s *SimpleKV) GetKey() interface{} { + return s.Key +} + +func (s *SimpleKV) GetValue() interface{} { + return s.Value +} + +// KVs will store SimpleKV collection as map type KVs struct { kvs map[interface{}]interface{} } @@ -63,7 +78,7 @@ func NewKVs(kvs ...KV) *KVs { kvs: make(map[interface{}]interface{}, len(kvs)), } for _, kv := range kvs { - res.kvs[kv.Key] = kv.Value + res.kvs[kv.GetKey()] = kv.GetValue() } return res } diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index 45adf5ff..275c6753 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -22,7 +22,7 @@ import ( func TestKVs(t *testing.T) { key := "my-key" - kvs := NewKVs(KV{ + kvs := NewKVs(&SimpleKV{ Key: key, Value: 12, }) diff --git a/pkg/orm/constant.go b/pkg/orm/constant.go deleted file mode 100644 index 14f40a7b..00000000 --- a/pkg/orm/constant.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -const ( - MaxIdleConnsKey = "MaxIdleConns" - MaxOpenConnsKey = "MaxOpenConns" - ConnMaxLifetimeKey = "ConnMaxLifetime" -) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index a3f2a0b9..5f1e3ea3 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -109,8 +109,9 @@ func (ac *_dbCache) getDefault() (al *alias) { type DB struct { *sync.RWMutex - DB *sql.DB - stmtDecorators *lru.Cache + DB *sql.DB + stmtDecorators *lru.Cache + stmtDecoratorsLimit int } var _ dbQuerier = new(DB) @@ -165,16 +166,14 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error } func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Exec(args...) + return d.ExecContext(context.Background(), query, args...) } func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if d.stmtDecorators == nil { + return d.DB.ExecContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -185,16 +184,14 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) } func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Query(args...) + return d.QueryContext(context.Background(), query, args...) } func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if d.stmtDecorators == nil { + return d.DB.QueryContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { return nil, err @@ -205,24 +202,21 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} } func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRow(args...) - + return d.QueryRowContext(context.Background(), query, args...) } func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + if d.stmtDecorators == nil { + return d.DB.QueryRowContext(ctx, query, args...) + } + sd, err := d.getStmtDecorator(query) if err != nil { panic(err) } stmt := sd.getStmt() defer sd.release() - return stmt.QueryRowContext(ctx, args) + return stmt.QueryRowContext(ctx, args...) } type TxDB struct { @@ -345,14 +339,49 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { + existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + if _, ok := dataBaseCache.get(aliasName); ok { + return nil, existErr + } + + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + if !dataBaseCache.add(aliasName, al) { + return nil, existErr + } + + return al, nil +} + +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV)(*alias, error){ + kvs := common.NewKVs(params...) + + var stmtCache *lru.Cache + var stmtCacheSize int + + maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) + if maxStmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = maxStmtCacheSize + } + } + al := new(alias) al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: stmtCache, + stmtDecoratorsLimit: stmtCacheSize, } if dr, ok := drivers[driverName]; ok { @@ -367,28 +396,39 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) } - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } + detectTZ(al) + + kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxIdleConns(al, m) + } + }).IfContains(maxOpenConnectionsKey, func(value interface{}) { + if m, ok := value.(int); ok { + SetMaxOpenConns(al, m) + } + }).IfContains(connMaxLifetimeKey, func(value interface{}) { + if m, ok := value.(time.Duration); ok { + SetConnMaxLifetime(al, m) + } + }) return al, nil } // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) error { + _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { var ( err error db *sql.DB al *alias ) - kvs := common.NewKVs(params...) db, err = sql.Open(driverName, dataSource) if err != nil { @@ -396,23 +436,13 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...common goto end } - al, err = addAliasWthDB(aliasName, driverName, db) + al, err = addAliasWthDB(aliasName, driverName, db, hints...) if err != nil { goto end } al.DataSource = dataSource - detectTZ(al) - - kvs.IfContains(MaxIdleConnsKey, func(value interface{}) { - SetMaxIdleConns(al.Name, value.(int)) - }).IfContains(MaxOpenConnsKey, func(value interface{}) { - SetMaxOpenConns(al.Name, value.(int)) - }).IfContains(ConnMaxLifetimeKey, func(value interface{}) { - SetConnMaxLifetime(al.Name, value.(time.Duration)) - }) - end: if err != nil { if db != nil { @@ -447,21 +477,18 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { } // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) +func SetMaxIdleConns(al *alias, maxIdleConns int) { al.MaxIdleConns = maxIdleConns al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) +func SetMaxOpenConns(al *alias, maxOpenConns int) { al.MaxOpenConns = maxOpenConns al.DB.DB.SetMaxOpenConns(maxOpenConns) } -func SetConnMaxLifetime(aliasName string, lifeTime time.Duration) { - al := getDbAlias(aliasName) +func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { al.ConnMaxLifetime = lifeTime al.DB.DB.SetConnMaxLifetime(lifeTime) } @@ -517,9 +544,12 @@ func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { } } -func newStmtDecoratorLruWithEvict() *lru.Cache { - cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { +func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { + cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { value.(*stmtDecorator).destroy() }) - return cache + if err != nil { + return nil, err + } + return cache, nil } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index 81b623c8..111657d7 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -19,21 +19,13 @@ import ( "time" "github.com/stretchr/testify/assert" - - "github.com/astaxie/beego/pkg/common" ) func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxIdleConnsKey, - Value: 20, - }, common.KV{ - Key: MaxOpenConnsKey, - Value: 300, - }, common.KV{ - Key: ConnMaxLifetimeKey, - Value: time.Minute, - }) + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -43,6 +35,47 @@ func TestRegisterDataBase(t *testing.T) { assert.Equal(t, al.ConnMaxLifetime, time.Minute) } +func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 1) +} + +func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { + aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) + assert.Nil(t, err) + + al := getDbAlias(aliasName) + assert.NotNil(t, al) + assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) +} + + func TestDBCache(t *testing.T) { dataBaseCache.add("test1", &alias{}) dataBaseCache.add("default", &alias{}) diff --git a/pkg/orm/db_hints.go b/pkg/orm/db_hints.go new file mode 100644 index 00000000..551c7357 --- /dev/null +++ b/pkg/orm/db_hints.go @@ -0,0 +1,74 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/common" + "time" +) + +type Hint struct { + key interface{} + value interface{} +} + +var _ common.KV = new(Hint) + +// GetKey return key +func (s *Hint) GetKey() interface{} { + return s.key +} + +// GetValue return value +func (s *Hint) GetValue() interface{} { + return s.value +} + +const ( + maxIdleConnectionsKey = "MaxIdleConnections" + maxOpenConnectionsKey = "MaxOpenConnections" + connMaxLifetimeKey = "ConnMaxLifetime" + maxStmtCacheSizeKey = "MaxStmtCacheSize" +) + +var _ common.KV = new(Hint) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(v int) *Hint { + return NewHint(maxIdleConnectionsKey, v) +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(v int) *Hint { + return NewHint(maxOpenConnectionsKey, v) +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) *Hint { + return NewHint(connMaxLifetimeKey, v) +} + +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) *Hint { + return NewHint(maxStmtCacheSizeKey, v) +} + +// NewHint return a hint +func NewHint(key interface{}, value interface{}) *Hint { + return &Hint{ + key: key, + value: value, + } +} diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go new file mode 100644 index 00000000..13f8ccde --- /dev/null +++ b/pkg/orm/db_hints_test.go @@ -0,0 +1,76 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewHint_time(t *testing.T) { + key := "qweqwe" + value := time.Second + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_int(t *testing.T) { + key := "qweqwe" + value := 281230 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_float(t *testing.T) { + key := "qweqwe" + value := 21.2459753 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestMaxOpenConnections(t *testing.T) { + i := 887423 + hint := MaxOpenConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey) +} + +func TestConnMaxLifetime(t *testing.T) { + i := time.Hour + hint := ConnMaxLifetime(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), connMaxLifetimeKey) +} + +func TestMaxIdleConnections(t *testing.T) { + i := 42316 + hint := MaxIdleConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) +} + +func TestMaxStmtCacheSize(t *testing.T) { + i := 94157 + hint := MaxStmtCacheSize(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) +} \ No newline at end of file diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 4c00050d..ae166dc7 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -28,7 +28,6 @@ import ( // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" - "github.com/astaxie/beego/pkg/common" ) // A slice string field. @@ -489,10 +488,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ - Key: MaxIdleConnsKey, - Value: 20, - }) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 3b94ab6c..7cbfec09 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -58,9 +58,9 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/pkg/common" "os" "reflect" - "sync" "time" "github.com/astaxie/beego/logs" @@ -580,7 +580,7 @@ func NewOrm() Ormer { return NewOrmUsingDB(`default`) } -// NewOrm create new orm with the name +// NewOrmUsingDB create new orm with the name func NewOrmUsingDB(aliasName string) Ormer { o := new(orm) if al, ok := dataBaseCache.get(aliasName); ok { @@ -597,27 +597,12 @@ func NewOrmUsingDB(aliasName string) Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { + al, err := newAliasWithDb(aliasName, driverName, db, params...) + if err != nil { + return nil, err } - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - detectTZ(al) - o := new(orm) o.alias = al diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 5e05eded..2f214f93 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -32,7 +32,8 @@ func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { if o.closed { return nil, ErrStmtClosed } - return o.stmt.Exec(args...) + flatParams := getFlatParams(nil, args, o.rs.orm.alias.TZ) + return o.stmt.Exec(flatParams...) } func (o *rawPrepare) Close() error { diff --git a/router.go b/router.go index 6a8ac6f7..92316480 100644 --- a/router.go +++ b/router.go @@ -319,6 +319,10 @@ func (p *ControllerRegister) GetContext() *beecontext.Context { // GiveBackContext put the ctx into pool so that it could be reuse func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + // clear input cached data + ctx.Input.Clear() + // clear output cached data + ctx.Output.Clear() p.pool.Put(ctx) }