From 5a1fa4e1ec36e874b48e29afa9ab03eee39c2d80 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 10 Aug 2020 18:46:16 +0800 Subject: [PATCH 1/5] specify index --- pkg/common/kv.go | 38 +++++-- pkg/orm/db.go | 64 +++++++++-- pkg/orm/db_alias.go | 13 +-- pkg/orm/db_alias_test.go | 15 +-- pkg/orm/db_hints_test.go | 76 ------------- pkg/orm/db_oracle.go | 24 +++++ pkg/orm/db_postgres.go | 7 ++ pkg/orm/db_sqlite.go | 21 ++++ pkg/orm/db_tables.go | 9 ++ pkg/orm/do_nothing_orm.go | 5 +- pkg/orm/filter_orm_decorator.go | 5 +- pkg/orm/filter_orm_decorator_test.go | 3 +- pkg/orm/{ => hints}/db_hints.go | 80 +++++++++++--- pkg/orm/hints/db_hints_test.go | 154 +++++++++++++++++++++++++++ pkg/orm/models_test.go | 3 +- pkg/orm/orm.go | 44 ++++---- pkg/orm/orm_log.go | 25 +---- pkg/orm/orm_queryset.go | 28 ++++- pkg/orm/orm_test.go | 27 +++-- pkg/orm/types.go | 47 +++++--- 20 files changed, 499 insertions(+), 189 deletions(-) delete mode 100644 pkg/orm/db_hints_test.go rename pkg/orm/{ => hints}/db_hints.go (50%) create mode 100644 pkg/orm/hints/db_hints_test.go diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 8468f4fe..26e786f9 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -36,14 +36,25 @@ func (s *SimpleKV) GetValue() interface{} { return s.Value } -// KVs will store SimpleKV collection as map -type KVs struct { +// KVs interface +type KVs interface { + GetValueOr(key interface{}, defValue interface{}) interface{} + Contains(key interface{}) bool + IfContains(key interface{}, action func(value interface{})) KVs + Put(key interface{}, value interface{}) KVs + Clone() KVs +} + +// SimpleKVs will store SimpleKV collection as map +type SimpleKVs struct { kvs map[interface{}]interface{} } +var _ KVs = new(SimpleKVs) + // GetValueOr returns the value for a given key, if non-existant // it returns defValue -func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { +func (kvs *SimpleKVs) GetValueOr(key interface{}, defValue interface{}) interface{} { v, ok := kvs.kvs[key] if ok { return v @@ -52,13 +63,13 @@ func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { } // Contains checks if a key exists -func (kvs *KVs) Contains(key interface{}) bool { +func (kvs *SimpleKVs) Contains(key interface{}) bool { _, ok := kvs.kvs[key] return ok } // IfContains invokes the action on a key if it exists -func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { +func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{})) KVs { v, ok := kvs.kvs[key] if ok { action(v) @@ -67,14 +78,25 @@ func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs } // Put stores the value -func (kvs *KVs) Put(key interface{}, value interface{}) *KVs { +func (kvs *SimpleKVs) Put(key interface{}, value interface{}) KVs { kvs.kvs[key] = value return kvs } +// Clone +func (kvs *SimpleKVs) Clone() KVs { + newKVs := new(SimpleKVs) + + for key, value := range kvs.kvs { + newKVs.Put(key, value) + } + + return newKVs +} + // NewKVs creates the *KVs instance -func NewKVs(kvs ...KV) *KVs { - res := &KVs{ +func NewKVs(kvs ...KV) KVs { + res := &SimpleKVs{ kvs: make(map[interface{}]interface{}, len(kvs)), } for _, kv := range kvs { diff --git a/pkg/orm/db.go b/pkg/orm/db.go index 9a1827e8..573247f0 100644 --- a/pkg/orm/db.go +++ b/pkg/orm/db.go @@ -18,6 +18,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "reflect" "strings" "time" @@ -738,8 +739,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } tables := newDbTables(mi, d.ins) + var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) + specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) } where, args := tables.getCondSQL(cond, false, tz) @@ -790,9 +793,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con sets := strings.Join(cols, ", ") + " " if d.ins.SupportUpdateJoin() { - query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) + query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where) } else { - supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) + supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s", + Q, mi.fields.pk.column, Q, + Q, mi.table, Q, + specifyIndexes, join, where) query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) } @@ -843,8 +849,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con tables := newDbTables(mi, d.ins) tables.skipEnd = true + var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) + specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) } if cond == nil || cond.IsEmpty() { @@ -857,7 +865,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con join := tables.getJoinSQL() cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where) d.ins.ReplaceMarks(&query) @@ -1002,6 +1010,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi orderBy := tables.getOrderSQL(qs.orders) limit := tables.getLimitSQL(mi, offset, rlimit) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) for _, tbl := range tables.tables { if tbl.sel { @@ -1015,9 +1024,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if qs.distinct { sqlSelect += " DISTINCT" } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", + sqlSelect, sels, Q, mi.table, Q, + specifyIndexes, join, where, groupBy, orderBy, limit) - if qs.forupdate { + if qs.forUpdate { query += " FOR UPDATE" } @@ -1153,10 +1164,13 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition groupBy := tables.getGroupSQL(qs.groups) tables.getOrderSQL(qs.orders) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) Q := d.ins.TableQuote() - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s", + Q, mi.table, Q, + specifyIndexes, join, where, groupBy) if groupBy != "" { query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) @@ -1680,6 +1694,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond orderBy := tables.getOrderSQL(qs.orders) limit := tables.getLimitSQL(mi, qs.offset, qs.limit) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) sels := strings.Join(cols, ", ") @@ -1687,7 +1702,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond if qs.distinct { sqlSelect += " DISTINCT" } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", + sqlSelect, sels, + Q, mi.table, Q, + specifyIndexes, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) @@ -1781,10 +1799,6 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } -func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { - return 0, nil -} - // flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true @@ -1900,3 +1914,31 @@ func (d *dbBase) ShowColumnsQuery(table string) string { func (d *dbBase) IndexExists(dbQuerier, string, string) bool { panic(ErrNotImplement) } + +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + var useWay string + + switch useIndex { + case hints.KeyUseIndex: + useWay = `USE` + case hints.KeyForceIndex: + useWay = `FORCE` + case hints.KeyIgnoreIndex: + useWay = `IGNORE` + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } + + return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`)) +} + + diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 5f1e3ea3..93f282af 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "sync" "time" @@ -363,7 +364,7 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) + maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int) if maxStmtCacheSize > 0 { _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) if errC != nil { @@ -398,15 +399,15 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K detectTZ(al) - kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { + kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { if m, ok := value.(int); ok { SetMaxIdleConns(al, m) } - }).IfContains(maxOpenConnectionsKey, func(value interface{}) { + }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { if m, ok := value.(int); ok { SetMaxOpenConns(al, m) } - }).IfContains(connMaxLifetimeKey, func(value interface{}) { + }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { if m, ok := value.(time.Duration); ok { SetConnMaxLifetime(al, m) } @@ -422,7 +423,7 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { var ( err error db *sql.DB @@ -436,7 +437,7 @@ func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common. goto end } - al, err = addAliasWthDB(aliasName, driverName, db, hints...) + al, err = addAliasWthDB(aliasName, driverName, db, params...) if err != nil { goto end } diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index 111657d7..576214fc 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -15,6 +15,7 @@ package orm import ( + "github.com/astaxie/beego/pkg/orm/hints" "testing" "time" @@ -23,9 +24,9 @@ import ( func TestRegisterDataBase(t *testing.T) { err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, - MaxIdleConnections(20), - MaxOpenConnections(300), - ConnMaxLifetime(time.Minute)) + hints.MaxIdleConnections(20), + hints.MaxOpenConnections(300), + hints.ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -37,7 +38,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -47,7 +48,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -57,7 +58,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -67,7 +68,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go deleted file mode 100644 index 13f8ccde..00000000 --- a/pkg/orm/db_hints_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -func TestNewHint_time(t *testing.T) { - key := "qweqwe" - value := time.Second - hint := NewHint(key, value) - - assert.Equal(t, hint.GetKey(), key) - assert.Equal(t, hint.GetValue(), value) -} - -func TestNewHint_int(t *testing.T) { - key := "qweqwe" - value := 281230 - hint := NewHint(key, value) - - assert.Equal(t, hint.GetKey(), key) - assert.Equal(t, hint.GetValue(), value) -} - -func TestNewHint_float(t *testing.T) { - key := "qweqwe" - value := 21.2459753 - hint := NewHint(key, value) - - assert.Equal(t, hint.GetKey(), key) - assert.Equal(t, hint.GetValue(), value) -} - -func TestMaxOpenConnections(t *testing.T) { - i := 887423 - hint := MaxOpenConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey) -} - -func TestConnMaxLifetime(t *testing.T) { - i := time.Hour - hint := ConnMaxLifetime(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), connMaxLifetimeKey) -} - -func TestMaxIdleConnections(t *testing.T) { - i := 42316 - hint := MaxIdleConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey) -} - -func TestMaxStmtCacheSize(t *testing.T) { - i := 94157 - hint := MaxStmtCacheSize(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) -} \ No newline at end of file diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go index 5d121f83..fa49e16b 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/orm/db_oracle.go @@ -16,6 +16,7 @@ package orm import ( "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "strings" ) @@ -96,6 +97,29 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } +func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + var hint string + + switch useIndex { + case hints.KeyUseIndex, hints.KeyForceIndex: + hint = `INDEX` + case hints.KeyIgnoreIndex: + hint = `NO_INDEX` + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } + + return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`)) +} + // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { diff --git a/pkg/orm/db_postgres.go b/pkg/orm/db_postgres.go index c488fb38..cf1a3413 100644 --- a/pkg/orm/db_postgres.go +++ b/pkg/orm/db_postgres.go @@ -92,6 +92,7 @@ func (d *dbBasePostgres) MaxLimit() uint64 { return 0 } + // postgresql quote is ". func (d *dbBasePostgres) TableQuote() string { return `"` @@ -181,6 +182,12 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo return cnt > 0 } +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored") + return `` +} + // create new postgresql dbBaser. func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) diff --git a/pkg/orm/db_sqlite.go b/pkg/orm/db_sqlite.go index 1d62ee34..244aae7a 100644 --- a/pkg/orm/db_sqlite.go +++ b/pkg/orm/db_sqlite.go @@ -17,7 +17,9 @@ package orm import ( "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "reflect" + "strings" "time" ) @@ -153,6 +155,25 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool return false } +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + switch useIndex { + case hints.KeyUseIndex, hints.KeyForceIndex: + return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`)) + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } +} + + // create new sqlite dbBaser. func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) diff --git a/pkg/orm/db_tables.go b/pkg/orm/db_tables.go index 4b21a6fc..d7e99639 100644 --- a/pkg/orm/db_tables.go +++ b/pkg/orm/db_tables.go @@ -472,6 +472,15 @@ func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits return } +// getIndexSql generate index sql. +func (t *dbTables) getIndexSql(tableName string,useIndex int, indexes []string) (clause string) { + if len(indexes) == 0 { + return + } + + return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes) +} + // crete new tables collection. func newDbTables(mi *modelInfo, base dbBaser) *dbTables { tables := &dbTables{} diff --git a/pkg/orm/do_nothing_orm.go b/pkg/orm/do_nothing_orm.go index 87b0a2ae..686b7752 100644 --- a/pkg/orm/do_nothing_orm.go +++ b/pkg/orm/do_nothing_orm.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" ) // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation @@ -52,11 +53,11 @@ func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, return false, 0, nil } -func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return 0, nil } -func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { return 0, nil } diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/orm/filter_orm_decorator.go index eb26ea68..2f32d8c6 100644 --- a/pkg/orm/filter_orm_decorator.go +++ b/pkg/orm/filter_orm_decorator.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" "reflect" "time" ) @@ -133,11 +134,11 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa return ok, res, err } -func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return f.LoadRelatedWithCtx(context.Background(), md, name, args...) } -func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { var ( res int64 err error diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go index d1099eaf..abb8322c 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/orm/filter_orm_decorator_test.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "errors" + "github.com/astaxie/beego/pkg/common" "sync" "testing" @@ -360,7 +361,7 @@ func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{} return errors.New("read for update error") } -func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { return 99, errors.New("load related error") } diff --git a/pkg/orm/db_hints.go b/pkg/orm/hints/db_hints.go similarity index 50% rename from pkg/orm/db_hints.go rename to pkg/orm/hints/db_hints.go index 551c7357..f708f310 100644 --- a/pkg/orm/db_hints.go +++ b/pkg/orm/hints/db_hints.go @@ -12,13 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package hints import ( "github.com/astaxie/beego/pkg/common" "time" ) +const ( + //db level + KeyMaxIdleConnections = iota + KeyMaxOpenConnections + KeyConnMaxLifetime + KeyMaxStmtCacheSize + + //query level + KeyForceIndex + KeyUseIndex + KeyIgnoreIndex + KeyForUpdate + KeyLimit + KeyOffset + KeyOrderBy + KeyRelDepth +) + type Hint struct { key interface{} value interface{} @@ -36,33 +54,71 @@ func (s *Hint) GetValue() interface{} { return s.value } -const ( - maxIdleConnectionsKey = "MaxIdleConnections" - maxOpenConnectionsKey = "MaxOpenConnections" - connMaxLifetimeKey = "ConnMaxLifetime" - maxStmtCacheSizeKey = "MaxStmtCacheSize" -) - var _ common.KV = new(Hint) // MaxIdleConnections return a hint about MaxIdleConnections func MaxIdleConnections(v int) *Hint { - return NewHint(maxIdleConnectionsKey, v) + return NewHint(KeyMaxIdleConnections, v) } // MaxOpenConnections return a hint about MaxOpenConnections func MaxOpenConnections(v int) *Hint { - return NewHint(maxOpenConnectionsKey, v) + return NewHint(KeyMaxOpenConnections, v) } // ConnMaxLifetime return a hint about ConnMaxLifetime func ConnMaxLifetime(v time.Duration) *Hint { - return NewHint(connMaxLifetimeKey, v) + return NewHint(KeyConnMaxLifetime, v) } // MaxStmtCacheSize return a hint about MaxStmtCacheSize func MaxStmtCacheSize(v int) *Hint { - return NewHint(maxStmtCacheSizeKey, v) + return NewHint(KeyMaxStmtCacheSize, v) +} + +// ForceIndex return a hint about ForceIndex +func ForceIndex(indexes ...string) *Hint { + return NewHint(KeyForceIndex, indexes) +} + +// UseIndex return a hint about UseIndex +func UseIndex(indexes ...string) *Hint { + return NewHint(KeyUseIndex, indexes) +} + +// IgnoreIndex return a hint about IgnoreIndex +func IgnoreIndex(indexes ...string) *Hint { + return NewHint(KeyIgnoreIndex, indexes) +} + +// ForUpdate return a hint about ForUpdate +func ForUpdate() *Hint { + return NewHint(KeyForUpdate, true) +} + +// DefaultRelDepth return a hint about DefaultRelDepth +func DefaultRelDepth() *Hint { + return NewHint(KeyRelDepth, true) +} + +// RelDepth return a hint about RelDepth +func RelDepth(d int) *Hint { + return NewHint(KeyRelDepth, d) +} + +// Limit return a hint about Limit +func Limit(d int64) *Hint { + return NewHint(KeyLimit, d) +} + +// Offset return a hint about Offset +func Offset(d int64) *Hint { + return NewHint(KeyOffset, d) +} + +// OrderBy return a hint about OrderBy +func OrderBy(s string) *Hint { + return NewHint(KeyOrderBy, s) } // NewHint return a hint diff --git a/pkg/orm/hints/db_hints_test.go b/pkg/orm/hints/db_hints_test.go new file mode 100644 index 00000000..5ab44b08 --- /dev/null +++ b/pkg/orm/hints/db_hints_test.go @@ -0,0 +1,154 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hints + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewHint_time(t *testing.T) { + key := "qweqwe" + value := time.Second + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_int(t *testing.T) { + key := "qweqwe" + value := 281230 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHint_float(t *testing.T) { + key := "qweqwe" + value := 21.2459753 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestMaxOpenConnections(t *testing.T) { + i := 887423 + hint := MaxOpenConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections) +} + +func TestConnMaxLifetime(t *testing.T) { + i := time.Hour + hint := ConnMaxLifetime(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime) +} + +func TestMaxIdleConnections(t *testing.T) { + i := 42316 + hint := MaxIdleConnections(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections) +} + +func TestMaxStmtCacheSize(t *testing.T) { + i := 94157 + hint := MaxStmtCacheSize(i) + assert.Equal(t, hint.GetValue(), i) + assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize) +} + +func TestForceIndex(t *testing.T) { + s := []string{`f_index1`, `f_index2`, `f_index3`} + hint := ForceIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyForceIndex) +} + +func TestForceIndex_0(t *testing.T) { + var s []string + hint := ForceIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyForceIndex) +} + +func TestIgnoreIndex(t *testing.T) { + s := []string{`i_index1`, `i_index2`, `i_index3`} + hint := IgnoreIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) +} + +func TestIgnoreIndex_0(t *testing.T) { + var s []string + hint := IgnoreIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) +} + +func TestUseIndex(t *testing.T) { + s := []string{`u_index1`, `u_index2`, `u_index3`} + hint := UseIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyUseIndex) +} + +func TestUseIndex_0(t *testing.T) { + var s []string + hint := UseIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyUseIndex) +} + +func TestForUpdate(t *testing.T) { + hint := ForUpdate() + assert.Equal(t, hint.GetValue(), true) + assert.Equal(t, hint.GetKey(), KeyForUpdate) +} + +func TestDefaultRelDepth(t *testing.T) { + hint := DefaultRelDepth() + assert.Equal(t, hint.GetValue(), true) + assert.Equal(t, hint.GetKey(), KeyRelDepth) +} + +func TestRelDepth(t *testing.T) { + hint := RelDepth(157965) + assert.Equal(t, hint.GetValue(), 157965) + assert.Equal(t, hint.GetKey(), KeyRelDepth) +} + +func TestLimit(t *testing.T) { + hint := Limit(1579625) + assert.Equal(t, hint.GetValue(), int64(1579625)) + assert.Equal(t, hint.GetKey(), KeyLimit) +} + +func TestOffset(t *testing.T) { + hint := Offset(int64(1572123965)) + assert.Equal(t, hint.GetValue(), int64(1572123965)) + assert.Equal(t, hint.GetKey(), KeyOffset) +} + +func TestOrderBy(t *testing.T) { + hint := OrderBy(`-ID`) + assert.Equal(t, hint.GetValue(), `-ID`) + assert.Equal(t, hint.GetKey(), KeyOrderBy) +} \ No newline at end of file diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index ae166dc7..935c2073 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -18,6 +18,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "os" "strings" "time" @@ -488,7 +489,7 @@ func init() { os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index d79053af..fb63d4e5 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -59,6 +59,7 @@ import ( "errors" "fmt" "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/orm/hints" "os" "reflect" "time" @@ -99,6 +100,7 @@ type ormBase struct { var _ DQL = new(ormBase) var _ DML = new(ormBase) +var _ DriverGetter = new(ormBase) // get model info and model reflect value func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { @@ -302,11 +304,10 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. -func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } - -func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) qs := qseter.(*querySet) @@ -314,24 +315,29 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s var relDepth int var limit, offset int64 var order string - for i, arg := range args { - switch i { - case 0: - if v, ok := arg.(bool); ok { - if v { - relDepth = DefaultRelsDepth - } - } else if v, ok := arg.(int); ok { - relDepth = v + + kvs := common.NewKVs(args...) + kvs.IfContains(hints.KeyRelDepth, func(value interface{}) { + if v, ok := value.(bool); ok { + if v { + relDepth = DefaultRelsDepth } - case 1: - limit = ToInt64(arg) - case 2: - offset = ToInt64(arg) - case 3: - order, _ = arg.(string) + } else if v, ok := value.(int); ok { + relDepth = v } - } + }).IfContains(hints.KeyLimit, func(value interface{}) { + if v, ok := value.(int64); ok { + limit = v + } + }).IfContains(hints.KeyOffset, func(value interface{}) { + if v, ok := value.(int64); ok { + offset = v + } + }).IfContains(hints.KeyOrderBy, func(value interface{}) { + if v, ok := value.(string); ok { + order = v + } + }) switch fi.fieldType { case RelOneToOne, RelForeignKey, RelReverseOne: diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go index 5bb3a24f..d8df7e36 100644 --- a/pkg/orm/orm_log.go +++ b/pkg/orm/orm_log.go @@ -127,10 +127,7 @@ var _ txer = new(dbQueryLog) var _ txEnder = new(dbQueryLog) func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.Prepare(query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err + return d.PrepareContext(context.Background(), query) } func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { @@ -141,10 +138,7 @@ func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stm } func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.Exec(query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err + return d.ExecContext(context.Background(), query, args...) } func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { @@ -155,10 +149,7 @@ func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...inte } func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.Query(query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err + return d.QueryContext(context.Background(), query, args...) } func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { @@ -169,10 +160,7 @@ func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...int } func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRow(query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res + return d.QueryRowContext(context.Background(), query, args...) } func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { @@ -183,10 +171,7 @@ func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ... } func (d *dbQueryLog) Begin() (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).Begin() - debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) - return tx, err + return d.BeginTx(context.Background(), nil) } func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go index 83168de7..734fc738 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/orm/orm_queryset.go @@ -17,6 +17,7 @@ package orm import ( "context" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" ) type colValue struct { @@ -71,7 +72,9 @@ type querySet struct { groups []string orders []string distinct bool - forupdate bool + forUpdate bool + useIndex int + indexes []string orm *ormBase ctx context.Context forContext bool @@ -148,7 +151,28 @@ func (o querySet) Distinct() QuerySeter { // add FOR UPDATE to SELECT func (o querySet) ForUpdate() QuerySeter { - o.forupdate = true + o.forUpdate = true + return &o +} + +// ForceIndex force index for query +func (o querySet) ForceIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyForceIndex + o.indexes = indexes + return &o +} + +// UseIndex use index for query +func (o querySet) UseIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyUseIndex + o.indexes = indexes + return &o +} + +// IgnoreIndex ignore index for query +func (o querySet) IgnoreIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyIgnoreIndex + o.indexes = indexes return &o } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index f5242a46..1d173426 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/pkg/orm/hints" "io/ioutil" "math" "os" @@ -1279,24 +1280,32 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - num, err = dORM.LoadRelated(&user, "Posts", true) + num, err = dORM.LoadRelated(&user, "Posts", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&user, "Posts", true, 1) + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.Limit(1)) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.OrderBy("-Id")) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.Limit(1), + hints.Offset(1), + hints.OrderBy("Id")) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) @@ -1318,7 +1327,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&profile, "User", true) + num, err = dORM.LoadRelated(&profile, "User", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(profile.User == nil, false)) @@ -1335,7 +1344,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile.Age, 30)) - num, err = dORM.LoadRelated(&user, "Profile", true) + num, err = dORM.LoadRelated(&user, "Profile", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(user.Profile == nil, false)) @@ -1355,7 +1364,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&post, "User", true) + num, err = dORM.LoadRelated(&post, "User", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(post.User == nil, false)) @@ -1375,7 +1384,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - num, err = dORM.LoadRelated(&post, "Tags", true) + num, err = dORM.LoadRelated(&post, "Tags", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(post.Tags), 2)) @@ -1396,7 +1405,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - num, err = dORM.LoadRelated(&tag, "Posts", true) + num, err = dORM.LoadRelated(&tag, "Posts", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 9624fd94..0be2b809 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/pkg/common" "reflect" "time" ) @@ -175,14 +176,14 @@ type DQL interface { // example: // Ormer.LoadRelated(post,"Tags") // for _,tag := range post.Tags{...} - // args[0] bool true useDefaultRelsDepth ; false depth 0 - // args[0] int loadRelationDepth - // args[1] int limit default limit 1000 - // args[2] int offset default offset 0 - // args[3] string order for example : "-Id" + // hints.DefaultRelDepth useDefaultRelsDepth ; or depth 0 + // hints.RelDepth loadRelationDepth + // hints.Limit limit default limit 1000 + // hints.Offset int offset default offset 0 + // hints.OrderBy string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) - LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) // create a models to models queryer // for example: @@ -282,6 +283,21 @@ type QuerySeter interface { // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter + // add FORCE INDEX expression. + // for example: + // qs.ForceIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + ForceIndex(indexes ...string) QuerySeter + // add USE INDEX expression. + // for example: + // qs.UseIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + UseIndex(indexes ...string) QuerySeter + // add IGNORE INDEX expression. + // for example: + // qs.IgnoreIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + IgnoreIndex(indexes ...string) QuerySeter // set relation model to query together. // it will query relation models and assign to parent model. // for example: @@ -527,24 +543,27 @@ type txEnder interface { // base database struct type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - SupportUpdateJoin() bool UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + + SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -559,4 +578,6 @@ type dbBaser interface { IndexExists(dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) setval(dbQuerier, *modelInfo, []string) error + + GenerateSpecifyIndex(tableName string,useIndex int ,indexes []string) string } From 882f1273c8e9afc26984802cc42a027358087a0d Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 10 Aug 2020 23:27:03 +0800 Subject: [PATCH 2/5] add UT for specifying indexes --- pkg/orm/models_test.go | 9 +++++++++ pkg/orm/orm_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 935c2073..52524501 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -382,6 +382,15 @@ type InLine struct { Email string } +type Index struct { + // Common Fields + Id int `orm:"column(id)"` + + // Other Fields + F1 int `orm:"column(f1);unique"` + F2 int `orm:"column(f2);unique"` +} + func NewInLine() *InLine { return new(InLine) } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 1d173426..58447adb 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -201,6 +201,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(Index)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -225,6 +226,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(Index)) BootStrap() @@ -794,6 +796,32 @@ func TestExpr(t *testing.T) { // throwFail(t, AssertIs(num, 3)) } +func TestSpecifyIndex(t *testing.T) { + var index *Index + index = &Index{ + F1: 1, + F2: 2, + } + _, _ = dORM.Insert(index) + throwFailNow(t, AssertIs(index.Id, 1)) + + index = &Index{ + F1: 3, + F2: 4, + } + _, _ = dORM.Insert(index) + throwFailNow(t, AssertIs(index.Id, 2)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`f1`).One(index) + throwFailNow(t, AssertIs(index.F2, 2)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `3`).UseIndex(`f1`, `f2`).One(index) + throwFailNow(t, AssertIs(index.F2, 4)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`f1`, `f2`).One(index) + throwFailNow(t, AssertIs(index.F2, 2)) +} + func TestOperators(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.Filter("user_name", "slene").Count() From f8c0e6fec56100a0290d0fd51427e97ef99781fc Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Tue, 11 Aug 2020 00:06:36 +0800 Subject: [PATCH 3/5] fix UT --- pkg/orm/models_test.go | 4 ++-- pkg/orm/orm_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index 52524501..85815edd 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -387,8 +387,8 @@ type Index struct { Id int `orm:"column(id)"` // Other Fields - F1 int `orm:"column(f1);unique"` - F2 int `orm:"column(f2);unique"` + F1 int `orm:"column(f1);index"` + F2 int `orm:"column(f2);index"` } func NewInLine() *InLine { diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index 58447adb..e08b1b12 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -812,13 +812,13 @@ func TestSpecifyIndex(t *testing.T) { _, _ = dORM.Insert(index) throwFailNow(t, AssertIs(index.Id, 2)) - _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`f1`).One(index) + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`index_f1`).One(index) throwFailNow(t, AssertIs(index.F2, 2)) - _ = dORM.QueryTable(&Index{}).Filter(`f1`, `3`).UseIndex(`f1`, `f2`).One(index) - throwFailNow(t, AssertIs(index.F2, 4)) + _ = dORM.QueryTable(&Index{}).Filter(`f2`, `4`).UseIndex(`index_f2`).One(index) + throwFailNow(t, AssertIs(index.F1, 3)) - _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`f1`, `f2`).One(index) + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`index_f1`, `index_f2`).One(index) throwFailNow(t, AssertIs(index.F2, 2)) } From c22af4c61199ed1e6c664fbf3ff9919201d2c6fa Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 10 Aug 2020 23:04:57 +0800 Subject: [PATCH 4/5] Fix Tracing and prometheus bug --- build_info.go | 8 +- config.go | 4 +- config/config.go | 6 +- config/fake.go | 16 +++ context/output.go | 1 - go.sum | 4 + pkg/admin_test.go | 18 +++- pkg/app.go | 1 + pkg/filter.go | 3 + pkg/filter_chain_test.go | 3 +- pkg/hooks.go | 2 +- pkg/httplib/filter.go | 2 +- pkg/httplib/filter/opentracing/filter.go | 40 ++++---- pkg/httplib/filter/opentracing/filter_test.go | 2 +- pkg/httplib/filter/prometheus/filter.go | 8 +- pkg/httplib/filter/prometheus/filter_test.go | 2 +- pkg/orm/db_alias.go | 3 +- pkg/orm/db_alias_test.go | 1 - pkg/orm/db_hints_test.go | 2 +- pkg/orm/do_nothing_orm.go | 11 ++- ...ing_omr_test.go => do_nothing_orm_test.go} | 2 +- pkg/orm/filter.go | 10 +- pkg/orm/filter/opentracing/filter.go | 38 ++++--- pkg/orm/filter/opentracing/filter_test.go | 4 +- pkg/orm/filter/prometheus/filter.go | 2 +- pkg/orm/filter/prometheus/filter_test.go | 2 +- pkg/orm/filter_orm_decorator.go | 98 ++++++++++--------- pkg/orm/filter_orm_decorator_test.go | 50 +++++----- pkg/orm/filter_test.go | 3 +- pkg/orm/invocation.go | 19 ++-- pkg/orm/model_utils_test.go | 2 +- pkg/orm/models_test.go | 1 - pkg/orm/orm.go | 38 +++---- pkg/orm/orm_test.go | 2 - pkg/orm/types.go | 8 +- pkg/router.go | 5 +- pkg/session/sess_file_test.go | 2 +- pkg/web/doc.go | 2 +- pkg/web/filter/opentracing/filter.go | 23 +++-- pkg/web/filter/opentracing/filter_test.go | 2 +- pkg/web/filter/prometheus/filter.go | 2 +- pkg/web/filter/prometheus/filter_test.go | 2 +- 42 files changed, 257 insertions(+), 197 deletions(-) rename pkg/orm/{do_nothing_omr_test.go => do_nothing_orm_test.go} (99%) diff --git a/build_info.go b/build_info.go index 896bbdf3..59e78127 100644 --- a/build_info.go +++ b/build_info.go @@ -16,15 +16,15 @@ package beego var ( // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildVersion string + BuildVersion string // Deprecated: using pkg/, we will delete this in v2.1.0 BuildGitRevision string // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildStatus string + BuildStatus string // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTag string + BuildTag string // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTime string + BuildTime string // Deprecated: using pkg/, we will delete this in v2.1.0 GoVersion string diff --git a/config.go b/config.go index d707542a..7917528e 100644 --- a/config.go +++ b/config.go @@ -15,13 +15,13 @@ package beego import ( + "crypto/tls" "fmt" "os" "path/filepath" "reflect" "runtime" "strings" - "crypto/tls" "github.com/astaxie/beego/config" "github.com/astaxie/beego/context" @@ -163,7 +163,7 @@ func init() { } appConfigPath = filepath.Join(WorkPath, "conf", filename) if configPath := os.Getenv("BEEGO_CONFIG_PATH"); configPath != "" { - appConfigPath = configPath + appConfigPath = configPath } if !utils.FileExists(appConfigPath) { appConfigPath = filepath.Join(AppPath, "conf", filename) diff --git a/config/config.go b/config/config.go index f46f862b..db2e96f6 100644 --- a/config/config.go +++ b/config/config.go @@ -51,9 +51,9 @@ import ( // Deprecated: using pkg/config, we will delete this in v2.1.0 type Configer interface { // Deprecated: using pkg/config, we will delete this in v2.1.0 - Set(key, val string) error //support section::key type in given key when using ini type. + Set(key, val string) error //support section::key type in given key when using ini type. // Deprecated: using pkg/config, we will delete this in v2.1.0 - String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. // Deprecated: using pkg/config, we will delete this in v2.1.0 Strings(key string) []string //get string slice // Deprecated: using pkg/config, we will delete this in v2.1.0 @@ -65,7 +65,7 @@ type Configer interface { // Deprecated: using pkg/config, we will delete this in v2.1.0 Float(key string) (float64, error) // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. // Deprecated: using pkg/config, we will delete this in v2.1.0 DefaultStrings(key string, defaultVal []string) []string //get string slice // Deprecated: using pkg/config, we will delete this in v2.1.0 diff --git a/config/fake.go b/config/fake.go index 07e56ce2..8093ad61 100644 --- a/config/fake.go +++ b/config/fake.go @@ -27,15 +27,18 @@ type fakeConfigContainer struct { func (c *fakeConfigContainer) getData(key string) string { return c.data[strings.ToLower(key)] } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Set(key, val string) error { c.data[strings.ToLower(key)] = val return nil } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) String(key string) string { return c.getData(key) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) @@ -44,6 +47,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Strings(key string) []string { v := c.String(key) @@ -52,6 +56,7 @@ func (c *fakeConfigContainer) Strings(key string) []string { } return strings.Split(v, ";") } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) @@ -60,10 +65,12 @@ func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) [] } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) @@ -72,10 +79,12 @@ func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) @@ -84,10 +93,12 @@ func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getData(key)) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) @@ -96,10 +107,12 @@ func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) @@ -108,6 +121,7 @@ func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float } return v } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { if v, ok := c.data[strings.ToLower(key)]; ok { @@ -115,10 +129,12 @@ func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { } return nil, errors.New("key not find") } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { return nil, errors.New("not implement in the fakeConfigContainer") } + // Deprecated: using pkg/config, we will delete this in v2.1.0 func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") diff --git a/context/output.go b/context/output.go index eaa75720..7409e4e5 100644 --- a/context/output.go +++ b/context/output.go @@ -58,7 +58,6 @@ func (output *BeegoOutput) Clear() { output.Status = 0 } - // Header sets response header item string via given key. func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) diff --git a/go.sum b/go.sum index 12b76333..75247943 100644 --- a/go.sum +++ b/go.sum @@ -185,6 +185,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -219,6 +220,9 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= +golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/pkg/admin_test.go b/pkg/admin_test.go index e7eae771..5094aeed 100644 --- a/pkg/admin_test.go +++ b/pkg/admin_test.go @@ -6,10 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/toolbox" ) @@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { t.Errorf("invalid response map length: got %d want %d", len(decodedResponseBody), len(expectedResponseBody)) } + assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) + assert.Equal(t, 2, len(decodedResponseBody)) - if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { - t.Errorf("handler returned unexpected body: got %v want %v", - decodedResponseBody, expectedResponseBody) + var database, cache map[string]interface{} + if decodedResponseBody[0]["message"] == "database" { + database = decodedResponseBody[0] + cache = decodedResponseBody[1] + } else { + database = decodedResponseBody[1] + cache = decodedResponseBody[0] } + assert.Equal(t, expectedResponseBody[0], database) + assert.Equal(t, expectedResponseBody[1], cache) + } diff --git a/pkg/app.go b/pkg/app.go index d94d56b5..ea71ce4e 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -498,6 +498,7 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. +// the filter's behavior is like stack func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) return BeeApp diff --git a/pkg/filter.go b/pkg/filter.go index 543d7901..911cb848 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -33,6 +33,7 @@ type FilterFunc func(ctx *context.Context) // when a request with a matching URL arrives. type FilterRouter struct { filterFunc FilterFunc + next *FilterRouter tree *Tree pattern string returnOnOutput bool @@ -81,6 +82,8 @@ func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterPar ctx.Input.SetParam(k, v) } } + } else if f.next != nil { + return f.next.filter(ctx, urlPath, preFilterParams) } if f.returnOnOutput && ctx.ResponseWriter.Started { return true, true diff --git a/pkg/filter_chain_test.go b/pkg/filter_chain_test.go index 42397a60..f1f86088 100644 --- a/pkg/filter_chain_test.go +++ b/pkg/filter_chain_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,7 +39,6 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { ctx.Output.Body([]byte("hello")) }) - r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() diff --git a/pkg/hooks.go b/pkg/hooks.go index f511e216..3f778cdc 100644 --- a/pkg/hooks.go +++ b/pkg/hooks.go @@ -111,4 +111,4 @@ func registerCommentRouter() error { } return nil -} \ No newline at end of file +} diff --git a/pkg/httplib/filter.go b/pkg/httplib/filter.go index 72a497d0..5daed64c 100644 --- a/pkg/httplib/filter.go +++ b/pkg/httplib/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/httplib/filter/opentracing/filter.go b/pkg/httplib/filter/opentracing/filter.go index 5f409c63..6cc4d6b0 100644 --- a/pkg/httplib/filter/opentracing/filter.go +++ b/pkg/httplib/filter/opentracing/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,14 +17,11 @@ package opentracing import ( "context" "net/http" - "strconv" + "github.com/astaxie/beego/pkg/httplib" logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/log" - - "github.com/astaxie/beego/pkg/httplib" ) type FilterChainBuilder struct { @@ -38,14 +35,8 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { method := req.GetRequest().Method - host := req.GetRequest().URL.Host - path := req.GetRequest().URL.Path - proto := req.GetRequest().Proto - - scheme := req.GetRequest().URL.Scheme - - operationName := host + path + "#" + method + operationName := method + "#" + req.GetRequest().URL.String() span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) defer span.Finish() @@ -54,21 +45,24 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt resp, err := next(spanCtx, req) if resp != nil { - span.SetTag("status", strconv.Itoa(resp.StatusCode)) + span.SetTag("http.status_code", resp.StatusCode) } - - span.SetTag("method", method) - span.SetTag("host", host) - span.SetTag("path", path) - span.SetTag("proto", proto) - span.SetTag("scheme", scheme) - - span.LogFields(log.String("url", req.GetRequest().URL.String())) - + span.SetTag("http.method", method) + span.SetTag("peer.hostname", req.GetRequest().URL.Host) + span.SetTag("http.url", req.GetRequest().URL.String()) + span.SetTag("http.scheme", req.GetRequest().URL.Scheme) + span.SetTag("span.kind", "client") + span.SetTag("component", "beego") if err != nil { - span.LogFields(log.String("error", err.Error())) + span.SetTag("error", true) + span.SetTag("message", err.Error()) + } else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) { + span.SetTag("error", true) } + span.SetTag("peer.address", req.GetRequest().RemoteAddr) + span.SetTag("http.proto", req.GetRequest().Proto) + if builder.CustomSpanFunc != nil { builder.CustomSpanFunc(span, ctx, req, resp, err) } diff --git a/pkg/httplib/filter/opentracing/filter_test.go b/pkg/httplib/filter/opentracing/filter_test.go index aa687541..8849a9ad 100644 --- a/pkg/httplib/filter/opentracing/filter_test.go +++ b/pkg/httplib/filter/opentracing/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/httplib/filter/prometheus/filter.go b/pkg/httplib/filter/prometheus/filter.go index a0b24d67..e7f7316f 100644 --- a/pkg/httplib/filter/prometheus/filter.go +++ b/pkg/httplib/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -63,11 +63,13 @@ func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time host := req.GetRequest().URL.Host path := req.GetRequest().URL.Path - status := resp.StatusCode + status := -1 + if resp != nil { + status = resp.StatusCode + } dur := int(endTime.Sub(startTime) / time.Millisecond) - builder.summaryVec.WithLabelValues(proto, scheme, method, host, path, strconv.Itoa(status), strconv.Itoa(dur), strconv.FormatBool(err == nil)) } diff --git a/pkg/httplib/filter/prometheus/filter_test.go b/pkg/httplib/filter/prometheus/filter_test.go index e15d82e5..2964e6c5 100644 --- a/pkg/httplib/filter/prometheus/filter_test.go +++ b/pkg/httplib/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 5f1e3ea3..e9b39a3d 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -357,7 +357,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV return al, nil } -func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV)(*alias, error){ +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { kvs := common.NewKVs(params...) var stmtCache *lru.Cache @@ -429,7 +429,6 @@ func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common. al *alias ) - db, err = sql.Open(driverName, dataSource) if err != nil { err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) diff --git a/pkg/orm/db_alias_test.go b/pkg/orm/db_alias_test.go index 111657d7..6275cb2a 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/orm/db_alias_test.go @@ -75,7 +75,6 @@ func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) } - func TestDBCache(t *testing.T) { dataBaseCache.add("test1", &alias{}) dataBaseCache.add("default", &alias{}) diff --git a/pkg/orm/db_hints_test.go b/pkg/orm/db_hints_test.go index 13f8ccde..bb713171 100644 --- a/pkg/orm/db_hints_test.go +++ b/pkg/orm/db_hints_test.go @@ -73,4 +73,4 @@ func TestMaxStmtCacheSize(t *testing.T) { hint := MaxStmtCacheSize(i) assert.Equal(t, hint.GetValue(), i) assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey) -} \ No newline at end of file +} diff --git a/pkg/orm/do_nothing_orm.go b/pkg/orm/do_nothing_orm.go index 87b0a2ae..f460794c 100644 --- a/pkg/orm/do_nothing_orm.go +++ b/pkg/orm/do_nothing_orm.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import ( var _ Ormer = new(DoNothingOrm) type DoNothingOrm struct { + } func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { @@ -148,19 +149,19 @@ func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOpti return nil, nil } -func (d *DoNothingOrm) DoTx(task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } -func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } -func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } -func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return nil } diff --git a/pkg/orm/do_nothing_omr_test.go b/pkg/orm/do_nothing_orm_test.go similarity index 99% rename from pkg/orm/do_nothing_omr_test.go rename to pkg/orm/do_nothing_orm_test.go index 92cde38b..4d477353 100644 --- a/pkg/orm/do_nothing_omr_test.go +++ b/pkg/orm/do_nothing_orm_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter.go b/pkg/orm/filter.go index d04b8c42..03a30022 100644 --- a/pkg/orm/filter.go +++ b/pkg/orm/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import ( // don't forget to call next(...) inside your Filter type FilterChain func(next Filter) Filter -// Filter's behavior is a little big strang. +// Filter's behavior is a little big strange. // it's only be called when users call methods of Ormer type Filter func(ctx context.Context, inv *Invocation) @@ -31,6 +31,6 @@ var globalFilterChains = make([]FilterChain, 0, 4) // AddGlobalFilterChain adds a new FilterChain // All orm instances built after this invocation will use this filterChain, // but instances built before this invocation will not be affected -func AddGlobalFilterChain(filterChain FilterChain) { - globalFilterChains = append(globalFilterChains, filterChain) -} \ No newline at end of file +func AddGlobalFilterChain(filterChain ...FilterChain) { + globalFilterChains = append(globalFilterChains, filterChain...) +} diff --git a/pkg/orm/filter/opentracing/filter.go b/pkg/orm/filter/opentracing/filter.go index a55ae6d2..405e39ea 100644 --- a/pkg/orm/filter/opentracing/filter.go +++ b/pkg/orm/filter/opentracing/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package opentracing import ( "context" + "strings" "github.com/opentracing/opentracing-go" @@ -27,6 +28,8 @@ import ( // for example: // if we want to trace QuerySetter // actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// the method Begin*, Commit and Rollback are ignored. +// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { // CustomSpanFunc users are able to custom their span CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) @@ -35,25 +38,34 @@ type FilterChainBuilder struct { func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { return func(ctx context.Context, inv *orm.Invocation) { operationName := builder.operationName(ctx, inv) - span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) - defer span.Finish() - - next(spanCtx, inv) - span.SetTag("Method", inv.Method) - span.SetTag("Table", inv.GetTableName()) - span.SetTag("InsideTx", inv.InsideTx) - span.SetTag("TxName", spanCtx.Value(orm.TxNameKey)) - - if builder.CustomSpanFunc != nil { - builder.CustomSpanFunc(span, spanCtx, inv) + if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" { + next(ctx, inv) + return } + span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) + defer span.Finish() + next(spanCtx, inv) + builder.buildSpan(span, spanCtx, inv) + } +} + +func (builder *FilterChainBuilder) buildSpan(span opentracing.Span, ctx context.Context, inv *orm.Invocation) { + span.SetTag("orm.method", inv.Method) + span.SetTag("orm.table", inv.GetTableName()) + span.SetTag("orm.insideTx", inv.InsideTx) + span.SetTag("orm.txName", ctx.Value(orm.TxNameKey)) + span.SetTag("span.kind", "client") + span.SetTag("component", "beego") + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx, inv) } } func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string { if n, ok := ctx.Value(orm.TxNameKey).(string); ok { - return inv.Method + "#" + n + return inv.Method + "#tx(" + n + ")" } return inv.Method + "#" + inv.GetTableName() } diff --git a/pkg/orm/filter/opentracing/filter_test.go b/pkg/orm/filter/opentracing/filter_test.go index 1428df8a..7df12a92 100644 --- a/pkg/orm/filter/opentracing/filter_test.go +++ b/pkg/orm/filter/opentracing/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -40,4 +40,4 @@ func TestFilterChainBuilder_FilterChain(t *testing.T) { TxStartTime: time.Now(), } builder.FilterChain(next)(context.Background(), inv) -} \ No newline at end of file +} diff --git a/pkg/orm/filter/prometheus/filter.go b/pkg/orm/filter/prometheus/filter.go index 33fdf78f..2e67d85c 100644 --- a/pkg/orm/filter/prometheus/filter.go +++ b/pkg/orm/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter/prometheus/filter_test.go b/pkg/orm/filter/prometheus/filter_test.go index a71e8f50..34766fb4 100644 --- a/pkg/orm/filter/prometheus/filter_test.go +++ b/pkg/orm/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/orm/filter_orm_decorator.go index eb26ea68..279d299f 100644 --- a/pkg/orm/filter_orm_decorator.go +++ b/pkg/orm/filter_orm_decorator.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,12 @@ import ( "time" ) -const TxNameKey = "TxName" +const ( + TxNameKey = "TxName" +) + +var _ Ormer = new(filterOrmDecorator) +var _ TxOrmer = new(filterOrmDecorator) type filterOrmDecorator struct { ormer @@ -40,7 +45,7 @@ func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { ormer: delegate, TxBeginner: delegate, root: func(ctx context.Context, inv *Invocation) { - inv.execute() + inv.execute(ctx) }, } @@ -58,7 +63,7 @@ func NewFilterTxOrmDecorator(delegate TxOrmer, root Filter, txName string) TxOrm root: root, insideTx: true, txStartTime: time.Now(), - txName: txName, + txName: txName, } return res } @@ -76,8 +81,8 @@ func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, co mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - err = f.ormer.ReadWithCtx(ctx, md, cols...) + f: func(c context.Context) { + err = f.ormer.ReadWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -98,8 +103,8 @@ func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interf mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - err = f.ormer.ReadForUpdateWithCtx(ctx, md, cols...) + f: func(c context.Context) { + err = f.ormer.ReadForUpdateWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -125,8 +130,8 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - ok, res, err = f.ormer.ReadOrCreateWithCtx(ctx, md, col1, cols...) + f: func(c context.Context) { + ok, res, err = f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...) }, } f.root(ctx, inv) @@ -151,8 +156,8 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.LoadRelatedWithCtx(ctx, md, name, args...) + f: func(c context.Context) { + res, err = f.ormer.LoadRelatedWithCtx(c, md, name, args...) }, } f.root(ctx, inv) @@ -176,8 +181,8 @@ func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{} mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res = f.ormer.QueryM2MWithCtx(ctx, md, name) + f: func(c context.Context) { + res = f.ormer.QueryM2MWithCtx(c, md, name) }, } f.root(ctx, inv) @@ -190,10 +195,10 @@ func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QueryS func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { var ( - res QuerySeter + res QuerySeter name string - md interface{} - mi *modelInfo + md interface{} + mi *modelInfo ) if table, ok := ptrStructOrTableName.(string); ok { @@ -212,10 +217,10 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - Md: md, - mi: mi, - f: func() { - res = f.ormer.QueryTableWithCtx(ctx, ptrStructOrTableName) + Md: md, + mi: mi, + f: func(c context.Context) { + res = f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) }, } f.root(ctx, inv) @@ -230,7 +235,7 @@ func (f *filterOrmDecorator) DBStats() *sql.DBStats { Method: "DBStats", InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { + f: func(c context.Context) { res = f.ormer.DBStats() }, } @@ -255,8 +260,8 @@ func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.InsertWithCtx(ctx, md) + f: func(c context.Context) { + res, err = f.ormer.InsertWithCtx(c, md) }, } f.root(ctx, inv) @@ -280,8 +285,8 @@ func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md inter mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.InsertOrUpdateWithCtx(ctx, md, colConflitAndArgs...) + f: func(c context.Context) { + res, err = f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...) }, } f.root(ctx, inv) @@ -316,8 +321,8 @@ func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, m mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.InsertMultiWithCtx(ctx, bulk, mds) + f: func(c context.Context) { + res, err = f.ormer.InsertMultiWithCtx(c, bulk, mds) }, } f.root(ctx, inv) @@ -341,8 +346,8 @@ func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.UpdateWithCtx(ctx, md, cols...) + f: func(c context.Context) { + res, err = f.ormer.UpdateWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -366,8 +371,8 @@ func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.ormer.DeleteWithCtx(ctx, md, cols...) + f: func(c context.Context) { + res, err = f.ormer.DeleteWithCtx(c, md, cols...) }, } f.root(ctx, inv) @@ -387,8 +392,8 @@ func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args Args: []interface{}{query, args}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res = f.ormer.RawWithCtx(ctx, query, args...) + f: func(c context.Context) { + res = f.ormer.RawWithCtx(c, query, args...) }, } f.root(ctx, inv) @@ -403,7 +408,7 @@ func (f *filterOrmDecorator) Driver() Driver { Method: "Driver", InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { + f: func(c context.Context) { res = f.ormer.Driver() }, } @@ -433,28 +438,28 @@ func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql. Args: []interface{}{opts}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, - f: func() { - res, err = f.TxBeginner.BeginWithCtxAndOpts(ctx, opts) - res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(ctx)) + f: func(c context.Context) { + res, err = f.TxBeginner.BeginWithCtxAndOpts(c, opts) + res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(c)) }, } f.root(ctx, inv) return res, err } -func (f *filterOrmDecorator) DoTx(task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { return f.DoTxWithCtxAndOpts(context.Background(), nil, task) } -func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { return f.DoTxWithCtxAndOpts(ctx, nil, task) } -func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return f.DoTxWithCtxAndOpts(context.Background(), opts, task) } -func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { var ( err error ) @@ -465,8 +470,8 @@ func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.T InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: getTxNameFromCtx(ctx), - f: func() { - err = f.TxBeginner.DoTxWithCtxAndOpts(ctx, opts, task) + f: func(c context.Context) { + err = doTxTemplate(f, c, opts, task) }, } f.root(ctx, inv) @@ -483,7 +488,7 @@ func (f *filterOrmDecorator) Commit() error { InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: f.txName, - f: func() { + f: func(c context.Context) { err = f.TxCommitter.Commit() }, } @@ -501,7 +506,7 @@ func (f *filterOrmDecorator) Rollback() error { InsideTx: f.insideTx, TxStartTime: f.txStartTime, TxName: f.txName, - f: func() { + f: func(c context.Context) { err = f.TxCommitter.Rollback() }, } @@ -516,4 +521,3 @@ func getTxNameFromCtx(ctx context.Context) string { } return txName } - diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/orm/filter_orm_decorator_test.go index d1099eaf..4e837a4e 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/orm/filter_orm_decorator_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -130,49 +130,49 @@ func TestFilterOrmDecorator_DoTx(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) { - assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) - assert.Equal(t, 2, len(inv.Args)) - assert.Equal(t, "", inv.GetTableName()) - assert.False(t, inv.InsideTx) + if inv.Method == "DoTxWithCtxAndOpts" { + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + } + next(ctx, inv) } }) - err := od.DoTx(func(txOrm TxOrmer) error { - return errors.New("tx error") + err := od.DoTx(func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx error", err.Error()) - err = od.DoTxWithCtx(context.Background(), func(txOrm TxOrmer) error { - return errors.New("tx ctx error") + err = od.DoTxWithCtx(context.Background(), func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx ctx error", err.Error()) - err = od.DoTxWithOpts(nil, func(txOrm TxOrmer) error { - return errors.New("tx opts error") + err = od.DoTxWithOpts(nil, func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx opts error", err.Error()) + od = NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) { - assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) - assert.Equal(t, 2, len(inv.Args)) - assert.Equal(t, "", inv.GetTableName()) - assert.Equal(t, "do tx name", inv.TxName) - assert.False(t, inv.InsideTx) + if inv.Method == "DoTxWithCtxAndOpts" { + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.Equal(t, "do tx name", inv.TxName) + assert.False(t, inv.InsideTx) + } next(ctx, inv) } }) ctx := context.WithValue(context.Background(), TxNameKey, "do tx name") - err = od.DoTxWithCtxAndOpts(ctx, nil, func(txOrm TxOrmer) error { - return errors.New("tx ctx opts error") + err = od.DoTxWithCtxAndOpts(ctx, nil, func(c context.Context, txOrm TxOrmer) error { + return nil }) assert.NotNil(t, err) - assert.Equal(t, "tx ctx opts error", err.Error()) } func TestFilterOrmDecorator_Driver(t *testing.T) { @@ -347,6 +347,8 @@ func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) { assert.Equal(t, int64(13), i) } +var _ Ormer = new(filterMockOrm) + // filterMockOrm is only used in this test file type filterMockOrm struct { DoNothingOrm @@ -376,8 +378,8 @@ func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int6 return 100, errors.New("insert error") } -func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { - return task(nil) +func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(c context.Context, txOrm TxOrmer) error) error { + return task(ctx, nil) } func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { diff --git a/pkg/orm/filter_test.go b/pkg/orm/filter_test.go index 0f2944c7..b2ca4ae1 100644 --- a/pkg/orm/filter_test.go +++ b/pkg/orm/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,4 +28,5 @@ func TestAddGlobalFilterChain(t *testing.T) { } }) assert.Equal(t, 1, len(globalFilterChains)) + globalFilterChains = nil } diff --git a/pkg/orm/invocation.go b/pkg/orm/invocation.go index 1c9fee09..e935b7ea 100644 --- a/pkg/orm/invocation.go +++ b/pkg/orm/invocation.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ package orm import ( + "context" "time" ) @@ -22,27 +23,27 @@ import ( type Invocation struct { Method string // Md may be nil in some cases. It depends on method - Md interface{} + Md interface{} // the args are all arguments except context.Context - Args []interface{} + Args []interface{} mi *modelInfo // f is the Orm operation - f func() + f func(ctx context.Context) // insideTx indicates whether this is inside a transaction - InsideTx bool + InsideTx bool TxStartTime time.Time - TxName string + TxName string } func (inv *Invocation) GetTableName() string { - if inv.mi != nil{ + if inv.mi != nil { return inv.mi.table } return "" } -func (inv *Invocation) execute() { - inv.f() +func (inv *Invocation) execute(ctx context.Context) { + inv.f(ctx) } diff --git a/pkg/orm/model_utils_test.go b/pkg/orm/model_utils_test.go index ea38d90a..b65aadcb 100644 --- a/pkg/orm/model_utils_test.go +++ b/pkg/orm/model_utils_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index ae166dc7..09ef4f15 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -27,7 +27,6 @@ import ( _ "github.com/mattn/go-sqlite3" // As tidb can't use go get, so disable the tidb testing now // _ "github.com/pingcap/tidb" - ) // A slice string field. diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index d79053af..cc678fc8 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -522,19 +522,24 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO return taskTxOrm, nil } -func (o *orm) DoTx(task func(txOrm TxOrmer) error) error { +func (o *orm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { return o.DoTxWithCtx(context.Background(), task) } -func (o *orm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { +func (o *orm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { return o.DoTxWithCtxAndOpts(ctx, nil, task) } -func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { return o.DoTxWithCtxAndOpts(context.Background(), opts, task) } -func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { +func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return doTxTemplate(o, ctx, opts, task) +} + +func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions, + task func(ctx context.Context, txOrm TxOrmer) error) error { _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) if err != nil { return err @@ -553,9 +558,8 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task } } }() - var taskTxOrm = _txOrm - err = task(taskTxOrm) + err = task(ctx, taskTxOrm) panicked = false return err } @@ -582,18 +586,11 @@ func NewOrm() Ormer { // NewOrmUsingDB create new orm with the name func NewOrmUsingDB(aliasName string) Ormer { - o := new(orm) if al, ok := dataBaseCache.get(aliasName); ok { - o.alias = al - if Debug { - o.db = newDbQueryLog(al, al.DB) - } else { - o.db = al.DB - } + return newDBWithAlias(al) } else { panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } - return o } // NewOrmWithDB create a new ormer object with specify *sql.DB for query @@ -603,14 +600,21 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) return nil, err } + return newDBWithAlias(al), nil +} + +func newDBWithAlias(al *alias) Ormer { o := new(orm) o.alias = al if Debug { - o.db = newDbQueryLog(o.alias, db) + o.db = newDbQueryLog(al, al.DB) } else { - o.db = db + o.db = al.DB } - return o, nil + if len(globalFilterChains) > 0 { + return NewFilterOrmDecorator(o, globalFilterChains...) + } + return o } diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index f5242a46..e3dafecd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2486,5 +2486,3 @@ func TestInsertOrUpdate(t *testing.T) { throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) } } - - diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 9624fd94..59688588 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -95,10 +95,10 @@ type TxBeginner interface { BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) //closure control transaction - DoTx(task func(txOrm TxOrmer) error) error - DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error - DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error - DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error + DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error } type TxCommitter interface { diff --git a/pkg/router.go b/pkg/router.go index 8caba94a..6b25d7e3 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -468,12 +468,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // // do something // } // } -func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) { +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params ...bool) { root := p.chainRoot filterFunc := chain(root.filterFunc) p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) -} + p.chainRoot.next = root +} // add Filter into func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go index 64b8d94a..a27d30a6 100644 --- a/pkg/session/sess_file_test.go +++ b/pkg/session/sess_file_test.go @@ -57,7 +57,7 @@ func TestFileProvider_SessionExist(t *testing.T) { _ = fp.SessionInit(180, sessionPath) exists, err := fp.SessionExist(sid) - if err != nil{ + if err != nil { t.Error(err) } if exists { diff --git a/pkg/web/doc.go b/pkg/web/doc.go index 2001f4ca..1425a729 100644 --- a/pkg/web/doc.go +++ b/pkg/web/doc.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/web/filter/opentracing/filter.go index 822d5e4d..e6ee9150 100644 --- a/pkg/web/filter/opentracing/filter.go +++ b/pkg/web/filter/opentracing/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ type FilterChainBuilder struct { CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context) } - func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { return func(ctx *beegoCtx.Context) { var ( @@ -55,9 +54,21 @@ func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.Filt next(ctx) // if you think we need to do more things, feel free to create an issue to tell us - span.SetTag("status", ctx.Output.Status) - span.SetTag("method", ctx.Input.Method()) - span.SetTag("route", ctx.Input.GetData("RouterPattern")) + span.SetTag("http.status_code", ctx.ResponseWriter.Status) + span.SetTag("http.method", ctx.Input.Method()) + span.SetTag("peer.hostname", ctx.Request.Host) + span.SetTag("http.url", ctx.Request.URL.String()) + span.SetTag("http.scheme", ctx.Request.URL.Scheme) + span.SetTag("span.kind", "server") + span.SetTag("component", "beego") + if ctx.Output.IsServerError() || ctx.Output.IsClientError() { + span.SetTag("error", true) + } + span.SetTag("peer.address", ctx.Request.RemoteAddr) + span.SetTag("http.proto", ctx.Request.Proto) + + span.SetTag("beego.route", ctx.Input.GetData("RouterPattern")) + if builder.CustomSpanFunc != nil { builder.CustomSpanFunc(span, ctx) } @@ -70,7 +81,7 @@ func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string { // TODO, if we support multiple servers, this need to be changed route, found := beego.BeeApp.Handlers.FindRouter(ctx) if found { - operationName = route.GetPattern() + operationName = ctx.Input.Method() + "#" + route.GetPattern() } return operationName } diff --git a/pkg/web/filter/opentracing/filter_test.go b/pkg/web/filter/opentracing/filter_test.go index 65f1f24e..750ea7a9 100644 --- a/pkg/web/filter/opentracing/filter_test.go +++ b/pkg/web/filter/opentracing/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/web/filter/prometheus/filter.go b/pkg/web/filter/prometheus/filter.go index bd47dcec..8f4b46e3 100644 --- a/pkg/web/filter/prometheus/filter.go +++ b/pkg/web/filter/prometheus/filter.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/web/filter/prometheus/filter_test.go b/pkg/web/filter/prometheus/filter_test.go index 7d2e2acf..822892bc 100644 --- a/pkg/web/filter/prometheus/filter_test.go +++ b/pkg/web/filter/prometheus/filter_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From ce698aacf6a5b7d2ebe3ccc25204b65cc0fdeeb3 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 11 Aug 2020 12:06:02 +0800 Subject: [PATCH 5/5] rm some methods --- pkg/common/kv.go | 19 ------------------- pkg/common/kv_test.go | 10 ++++------ 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/pkg/common/kv.go b/pkg/common/kv.go index 26e786f9..80797aa9 100644 --- a/pkg/common/kv.go +++ b/pkg/common/kv.go @@ -41,8 +41,6 @@ type KVs interface { GetValueOr(key interface{}, defValue interface{}) interface{} Contains(key interface{}) bool IfContains(key interface{}, action func(value interface{})) KVs - Put(key interface{}, value interface{}) KVs - Clone() KVs } // SimpleKVs will store SimpleKV collection as map @@ -77,23 +75,6 @@ func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{}) return kvs } -// Put stores the value -func (kvs *SimpleKVs) Put(key interface{}, value interface{}) KVs { - kvs.kvs[key] = value - return kvs -} - -// Clone -func (kvs *SimpleKVs) Clone() KVs { - newKVs := new(SimpleKVs) - - for key, value := range kvs.kvs { - newKVs.Put(key, value) - } - - return newKVs -} - // NewKVs creates the *KVs instance func NewKVs(kvs ...KV) KVs { res := &SimpleKVs{ diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index 275c6753..7b52a300 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -29,12 +29,10 @@ func TestKVs(t *testing.T) { assert.True(t, kvs.Contains(key)) - kvs.IfContains(key, func(value interface{}) { - kvs.Put("my-key1", "") - }) - - assert.True(t, kvs.Contains("my-key1")) - v := kvs.GetValueOr(key, 13) assert.Equal(t, 12, v) + + v = kvs.GetValueOr(`key-not-exists`, 8546) + assert.Equal(t, 8546, v) + }