1
0
mirror of https://github.com/astaxie/beego.git synced 2024-06-01 23:53:28 +00:00

Merge pull request #4147 from jianzhiyao/frt/specify_index_2

specify index
This commit is contained in:
Ming Deng 2020-08-11 16:13:13 +08:00 committed by GitHub
commit a1b7fd3c93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 526 additions and 200 deletions

View File

@ -36,14 +36,23 @@ func (s *SimpleKV) GetValue() interface{} {
return s.Value return s.Value
} }
// KVs will store SimpleKV collection as map // KVs interface
type KVs struct { type KVs interface {
GetValueOr(key interface{}, defValue interface{}) interface{}
Contains(key interface{}) bool
IfContains(key interface{}, action func(value interface{})) KVs
}
// SimpleKVs will store SimpleKV collection as map
type SimpleKVs struct {
kvs map[interface{}]interface{} kvs map[interface{}]interface{}
} }
var _ KVs = new(SimpleKVs)
// GetValueOr returns the value for a given key, if non-existant // GetValueOr returns the value for a given key, if non-existant
// it returns defValue // it returns defValue
func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { func (kvs *SimpleKVs) GetValueOr(key interface{}, defValue interface{}) interface{} {
v, ok := kvs.kvs[key] v, ok := kvs.kvs[key]
if ok { if ok {
return v return v
@ -52,13 +61,13 @@ func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} {
} }
// Contains checks if a key exists // Contains checks if a key exists
func (kvs *KVs) Contains(key interface{}) bool { func (kvs *SimpleKVs) Contains(key interface{}) bool {
_, ok := kvs.kvs[key] _, ok := kvs.kvs[key]
return ok return ok
} }
// IfContains invokes the action on a key if it exists // IfContains invokes the action on a key if it exists
func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{})) KVs {
v, ok := kvs.kvs[key] v, ok := kvs.kvs[key]
if ok { if ok {
action(v) action(v)
@ -66,15 +75,9 @@ func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs
return kvs return kvs
} }
// Put stores the value
func (kvs *KVs) Put(key interface{}, value interface{}) *KVs {
kvs.kvs[key] = value
return kvs
}
// NewKVs creates the *KVs instance // NewKVs creates the *KVs instance
func NewKVs(kvs ...KV) *KVs { func NewKVs(kvs ...KV) KVs {
res := &KVs{ res := &SimpleKVs{
kvs: make(map[interface{}]interface{}, len(kvs)), kvs: make(map[interface{}]interface{}, len(kvs)),
} }
for _, kv := range kvs { for _, kv := range kvs {

View File

@ -29,12 +29,10 @@ func TestKVs(t *testing.T) {
assert.True(t, kvs.Contains(key)) assert.True(t, kvs.Contains(key))
kvs.IfContains(key, func(value interface{}) {
kvs.Put("my-key1", "")
})
assert.True(t, kvs.Contains("my-key1"))
v := kvs.GetValueOr(key, 13) v := kvs.GetValueOr(key, 13)
assert.Equal(t, 12, v) assert.Equal(t, 12, v)
v = kvs.GetValueOr(`key-not-exists`, 8546)
assert.Equal(t, 8546, v)
} }

View File

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -738,8 +739,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
var specifyIndexes string
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
} }
where, args := tables.getCondSQL(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
@ -790,9 +793,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
sets := strings.Join(cols, ", ") + " " sets := strings.Join(cols, ", ") + " "
if d.ins.SupportUpdateJoin() { if d.ins.SupportUpdateJoin() {
query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where)
} else { } else {
supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s",
Q, mi.fields.pk.column, Q,
Q, mi.table, Q,
specifyIndexes, join, where)
query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery)
} }
@ -843,8 +849,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true tables.skipEnd = true
var specifyIndexes string
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
} }
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
@ -857,7 +865,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
join := tables.getJoinSQL() join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1002,6 +1010,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
orderBy := tables.getOrderSQL(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, offset, rlimit) limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
for _, tbl := range tables.tables { for _, tbl := range tables.tables {
if tbl.sel { if tbl.sel {
@ -1015,9 +1024,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if qs.distinct { if qs.distinct {
sqlSelect += " DISTINCT" sqlSelect += " DISTINCT"
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
if qs.forupdate { if qs.forUpdate {
query += " FOR UPDATE" query += " FOR UPDATE"
} }
@ -1153,10 +1164,13 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
groupBy := tables.getGroupSQL(qs.groups) groupBy := tables.getGroupSQL(qs.groups)
tables.getOrderSQL(qs.orders) tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s",
Q, mi.table, Q,
specifyIndexes, join, where, groupBy)
if groupBy != "" { if groupBy != "" {
query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query)
@ -1680,6 +1694,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
orderBy := tables.getOrderSQL(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, qs.offset, qs.limit) limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
sels := strings.Join(cols, ", ") sels := strings.Join(cols, ", ")
@ -1687,7 +1702,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
if qs.distinct { if qs.distinct {
sqlSelect += " DISTINCT" sqlSelect += " DISTINCT"
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels,
Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1781,10 +1799,6 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil return cnt, nil
} }
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
return 0, nil
}
// flag of update joined record. // flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool { func (d *dbBase) SupportUpdateJoin() bool {
return true return true
@ -1900,3 +1914,31 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
func (d *dbBase) IndexExists(dbQuerier, string, string) bool { func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
var useWay string
switch useIndex {
case hints.KeyUseIndex:
useWay = `USE`
case hints.KeyForceIndex:
useWay = `FORCE`
case hints.KeyIgnoreIndex:
useWay = `IGNORE`
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`))
}

View File

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"sync" "sync"
"time" "time"
@ -363,7 +364,7 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K
var stmtCache *lru.Cache var stmtCache *lru.Cache
var stmtCacheSize int var stmtCacheSize int
maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int)
if maxStmtCacheSize > 0 { if maxStmtCacheSize > 0 {
_stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize)
if errC != nil { if errC != nil {
@ -398,15 +399,15 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K
detectTZ(al) detectTZ(al)
kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) {
if m, ok := value.(int); ok { if m, ok := value.(int); ok {
SetMaxIdleConns(al, m) SetMaxIdleConns(al, m)
} }
}).IfContains(maxOpenConnectionsKey, func(value interface{}) { }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) {
if m, ok := value.(int); ok { if m, ok := value.(int); ok {
SetMaxOpenConns(al, m) SetMaxOpenConns(al, m)
} }
}).IfContains(connMaxLifetimeKey, func(value interface{}) { }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) {
if m, ok := value.(time.Duration); ok { if m, ok := value.(time.Duration); ok {
SetConnMaxLifetime(al, m) SetConnMaxLifetime(al, m)
} }
@ -422,7 +423,7 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV
} }
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error {
var ( var (
err error err error
db *sql.DB db *sql.DB
@ -436,7 +437,7 @@ func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.
goto end goto end
} }
al, err = addAliasWthDB(aliasName, driverName, db, hints...) al, err = addAliasWthDB(aliasName, driverName, db, params...)
if err != nil { if err != nil {
goto end goto end
} }

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"github.com/astaxie/beego/pkg/orm/hints"
"testing" "testing"
"time" "time"
@ -23,9 +24,9 @@ import (
func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase(t *testing.T) {
err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source,
MaxIdleConnections(20), hints.MaxIdleConnections(20),
MaxOpenConnections(300), hints.MaxOpenConnections(300),
ConnMaxLifetime(time.Minute)) hints.ConnMaxLifetime(time.Minute))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias("test-params") al := getDbAlias("test-params")
@ -37,7 +38,7 @@ func TestRegisterDataBase(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -47,7 +48,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" aliasName := "TestRegisterDataBase_MaxStmtCacheSize0"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -57,7 +58,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" aliasName := "TestRegisterDataBase_MaxStmtCacheSize1"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -67,7 +68,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" aliasName := "TestRegisterDataBase_MaxStmtCacheSize841"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)

View File

@ -1,76 +0,0 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewHint_time(t *testing.T) {
key := "qweqwe"
value := time.Second
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_int(t *testing.T) {
key := "qweqwe"
value := 281230
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_float(t *testing.T) {
key := "qweqwe"
value := 21.2459753
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestMaxOpenConnections(t *testing.T) {
i := 887423
hint := MaxOpenConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey)
}
func TestConnMaxLifetime(t *testing.T) {
i := time.Hour
hint := ConnMaxLifetime(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), connMaxLifetimeKey)
}
func TestMaxIdleConnections(t *testing.T) {
i := 42316
hint := MaxIdleConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey)
}
func TestMaxStmtCacheSize(t *testing.T) {
i := 94157
hint := MaxStmtCacheSize(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey)
}

View File

@ -16,6 +16,7 @@ package orm
import ( import (
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"strings" "strings"
) )
@ -96,6 +97,29 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0 return cnt > 0
} }
func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
var hint string
switch useIndex {
case hints.KeyUseIndex, hints.KeyForceIndex:
hint = `INDEX`
case hints.KeyIgnoreIndex:
hint = `NO_INDEX`
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`))
}
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
// insert the given values, not the field values in struct. // insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {

View File

@ -92,6 +92,7 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0 return 0
} }
// postgresql quote is ". // postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string { func (d *dbBasePostgres) TableQuote() string {
return `"` return `"`
@ -181,6 +182,12 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0 return cnt > 0
} }
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
return ``
}
// create new postgresql dbBaser. // create new postgresql dbBaser.
func newdbBasePostgres() dbBaser { func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres) b := new(dbBasePostgres)

View File

@ -17,7 +17,9 @@ package orm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"reflect" "reflect"
"strings"
"time" "time"
) )
@ -153,6 +155,25 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false return false
} }
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
switch useIndex {
case hints.KeyUseIndex, hints.KeyForceIndex:
return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`))
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
}
// create new sqlite dbBaser. // create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser { func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite) b := new(dbBaseSqlite)

View File

@ -472,6 +472,15 @@ func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits
return return
} }
// getIndexSql generate index sql.
func (t *dbTables) getIndexSql(tableName string,useIndex int, indexes []string) (clause string) {
if len(indexes) == 0 {
return
}
return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes)
}
// crete new tables collection. // crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables { func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{} tables := &dbTables{}

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/astaxie/beego/pkg/common"
) )
// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation
@ -52,11 +53,11 @@ func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{},
return false, 0, nil return false, 0, nil
} }
func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) {
return 0, nil return 0, nil
} }
func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
return 0, nil return 0, nil
} }

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/astaxie/beego/pkg/common"
"reflect" "reflect"
"time" "time"
) )
@ -133,11 +134,11 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa
return ok, res, err return ok, res, err
} }
func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) {
return f.LoadRelatedWithCtx(context.Background(), md, name, args...) return f.LoadRelatedWithCtx(context.Background(), md, name, args...)
} }
func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
var ( var (
res int64 res int64
err error err error

View File

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/astaxie/beego/pkg/common"
"sync" "sync"
"testing" "testing"
@ -360,7 +361,7 @@ func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}
return errors.New("read for update error") return errors.New("read for update error")
} }
func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
return 99, errors.New("load related error") return 99, errors.New("load related error")
} }

View File

@ -12,13 +12,31 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package orm package hints
import ( import (
"github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/common"
"time" "time"
) )
const (
//db level
KeyMaxIdleConnections = iota
KeyMaxOpenConnections
KeyConnMaxLifetime
KeyMaxStmtCacheSize
//query level
KeyForceIndex
KeyUseIndex
KeyIgnoreIndex
KeyForUpdate
KeyLimit
KeyOffset
KeyOrderBy
KeyRelDepth
)
type Hint struct { type Hint struct {
key interface{} key interface{}
value interface{} value interface{}
@ -36,33 +54,71 @@ func (s *Hint) GetValue() interface{} {
return s.value return s.value
} }
const (
maxIdleConnectionsKey = "MaxIdleConnections"
maxOpenConnectionsKey = "MaxOpenConnections"
connMaxLifetimeKey = "ConnMaxLifetime"
maxStmtCacheSizeKey = "MaxStmtCacheSize"
)
var _ common.KV = new(Hint) var _ common.KV = new(Hint)
// MaxIdleConnections return a hint about MaxIdleConnections // MaxIdleConnections return a hint about MaxIdleConnections
func MaxIdleConnections(v int) *Hint { func MaxIdleConnections(v int) *Hint {
return NewHint(maxIdleConnectionsKey, v) return NewHint(KeyMaxIdleConnections, v)
} }
// MaxOpenConnections return a hint about MaxOpenConnections // MaxOpenConnections return a hint about MaxOpenConnections
func MaxOpenConnections(v int) *Hint { func MaxOpenConnections(v int) *Hint {
return NewHint(maxOpenConnectionsKey, v) return NewHint(KeyMaxOpenConnections, v)
} }
// ConnMaxLifetime return a hint about ConnMaxLifetime // ConnMaxLifetime return a hint about ConnMaxLifetime
func ConnMaxLifetime(v time.Duration) *Hint { func ConnMaxLifetime(v time.Duration) *Hint {
return NewHint(connMaxLifetimeKey, v) return NewHint(KeyConnMaxLifetime, v)
} }
// MaxStmtCacheSize return a hint about MaxStmtCacheSize // MaxStmtCacheSize return a hint about MaxStmtCacheSize
func MaxStmtCacheSize(v int) *Hint { func MaxStmtCacheSize(v int) *Hint {
return NewHint(maxStmtCacheSizeKey, v) return NewHint(KeyMaxStmtCacheSize, v)
}
// ForceIndex return a hint about ForceIndex
func ForceIndex(indexes ...string) *Hint {
return NewHint(KeyForceIndex, indexes)
}
// UseIndex return a hint about UseIndex
func UseIndex(indexes ...string) *Hint {
return NewHint(KeyUseIndex, indexes)
}
// IgnoreIndex return a hint about IgnoreIndex
func IgnoreIndex(indexes ...string) *Hint {
return NewHint(KeyIgnoreIndex, indexes)
}
// ForUpdate return a hint about ForUpdate
func ForUpdate() *Hint {
return NewHint(KeyForUpdate, true)
}
// DefaultRelDepth return a hint about DefaultRelDepth
func DefaultRelDepth() *Hint {
return NewHint(KeyRelDepth, true)
}
// RelDepth return a hint about RelDepth
func RelDepth(d int) *Hint {
return NewHint(KeyRelDepth, d)
}
// Limit return a hint about Limit
func Limit(d int64) *Hint {
return NewHint(KeyLimit, d)
}
// Offset return a hint about Offset
func Offset(d int64) *Hint {
return NewHint(KeyOffset, d)
}
// OrderBy return a hint about OrderBy
func OrderBy(s string) *Hint {
return NewHint(KeyOrderBy, s)
} }
// NewHint return a hint // NewHint return a hint

View File

@ -0,0 +1,154 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package hints
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewHint_time(t *testing.T) {
key := "qweqwe"
value := time.Second
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_int(t *testing.T) {
key := "qweqwe"
value := 281230
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_float(t *testing.T) {
key := "qweqwe"
value := 21.2459753
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestMaxOpenConnections(t *testing.T) {
i := 887423
hint := MaxOpenConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections)
}
func TestConnMaxLifetime(t *testing.T) {
i := time.Hour
hint := ConnMaxLifetime(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime)
}
func TestMaxIdleConnections(t *testing.T) {
i := 42316
hint := MaxIdleConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections)
}
func TestMaxStmtCacheSize(t *testing.T) {
i := 94157
hint := MaxStmtCacheSize(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize)
}
func TestForceIndex(t *testing.T) {
s := []string{`f_index1`, `f_index2`, `f_index3`}
hint := ForceIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyForceIndex)
}
func TestForceIndex_0(t *testing.T) {
var s []string
hint := ForceIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyForceIndex)
}
func TestIgnoreIndex(t *testing.T) {
s := []string{`i_index1`, `i_index2`, `i_index3`}
hint := IgnoreIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyIgnoreIndex)
}
func TestIgnoreIndex_0(t *testing.T) {
var s []string
hint := IgnoreIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyIgnoreIndex)
}
func TestUseIndex(t *testing.T) {
s := []string{`u_index1`, `u_index2`, `u_index3`}
hint := UseIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyUseIndex)
}
func TestUseIndex_0(t *testing.T) {
var s []string
hint := UseIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyUseIndex)
}
func TestForUpdate(t *testing.T) {
hint := ForUpdate()
assert.Equal(t, hint.GetValue(), true)
assert.Equal(t, hint.GetKey(), KeyForUpdate)
}
func TestDefaultRelDepth(t *testing.T) {
hint := DefaultRelDepth()
assert.Equal(t, hint.GetValue(), true)
assert.Equal(t, hint.GetKey(), KeyRelDepth)
}
func TestRelDepth(t *testing.T) {
hint := RelDepth(157965)
assert.Equal(t, hint.GetValue(), 157965)
assert.Equal(t, hint.GetKey(), KeyRelDepth)
}
func TestLimit(t *testing.T) {
hint := Limit(1579625)
assert.Equal(t, hint.GetValue(), int64(1579625))
assert.Equal(t, hint.GetKey(), KeyLimit)
}
func TestOffset(t *testing.T) {
hint := Offset(int64(1572123965))
assert.Equal(t, hint.GetValue(), int64(1572123965))
assert.Equal(t, hint.GetKey(), KeyOffset)
}
func TestOrderBy(t *testing.T) {
hint := OrderBy(`-ID`)
assert.Equal(t, hint.GetValue(), `-ID`)
assert.Equal(t, hint.GetKey(), KeyOrderBy)
}

View File

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"os" "os"
"strings" "strings"
"time" "time"
@ -381,6 +382,15 @@ type InLine struct {
Email string Email string
} }
type Index struct {
// Common Fields
Id int `orm:"column(id)"`
// Other Fields
F1 int `orm:"column(f1);index"`
F2 int `orm:"column(f2);index"`
}
func NewInLine() *InLine { func NewInLine() *InLine {
return new(InLine) return new(InLine)
} }
@ -488,7 +498,7 @@ func init() {
os.Exit(2) os.Exit(2)
} }
err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20))
if err != nil { if err != nil {
panic(fmt.Sprintf("can not register database: %v", err)) panic(fmt.Sprintf("can not register database: %v", err))

View File

@ -59,6 +59,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/common"
"github.com/astaxie/beego/pkg/orm/hints"
"os" "os"
"reflect" "reflect"
"time" "time"
@ -99,6 +100,7 @@ type ormBase struct {
var _ DQL = new(ormBase) var _ DQL = new(ormBase)
var _ DML = new(ormBase) var _ DML = new(ormBase)
var _ DriverGetter = new(ormBase)
// get model info and model reflect value // get model info and model reflect value
func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
@ -302,11 +304,10 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
// for _,tag := range post.Tags{...} // for _,tag := range post.Tags{...}
// //
// make sure the relation is defined in model struct tags. // make sure the relation is defined in model struct tags.
func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (o *ormBase) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) {
return o.LoadRelatedWithCtx(context.Background(), md, name, args...) return o.LoadRelatedWithCtx(context.Background(), md, name, args...)
} }
func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name) _, fi, ind, qseter := o.queryRelated(md, name)
qs := qseter.(*querySet) qs := qseter.(*querySet)
@ -314,24 +315,29 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s
var relDepth int var relDepth int
var limit, offset int64 var limit, offset int64
var order string var order string
for i, arg := range args {
switch i { kvs := common.NewKVs(args...)
case 0: kvs.IfContains(hints.KeyRelDepth, func(value interface{}) {
if v, ok := arg.(bool); ok { if v, ok := value.(bool); ok {
if v { if v {
relDepth = DefaultRelsDepth relDepth = DefaultRelsDepth
}
} else if v, ok := arg.(int); ok {
relDepth = v
} }
case 1: } else if v, ok := value.(int); ok {
limit = ToInt64(arg) relDepth = v
case 2:
offset = ToInt64(arg)
case 3:
order, _ = arg.(string)
} }
} }).IfContains(hints.KeyLimit, func(value interface{}) {
if v, ok := value.(int64); ok {
limit = v
}
}).IfContains(hints.KeyOffset, func(value interface{}) {
if v, ok := value.(int64); ok {
offset = v
}
}).IfContains(hints.KeyOrderBy, func(value interface{}) {
if v, ok := value.(string); ok {
order = v
}
})
switch fi.fieldType { switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelReverseOne: case RelOneToOne, RelForeignKey, RelReverseOne:

View File

@ -127,10 +127,7 @@ var _ txer = new(dbQueryLog)
var _ txEnder = new(dbQueryLog) var _ txEnder = new(dbQueryLog)
func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
a := time.Now() return d.PrepareContext(context.Background(), query)
stmt, err := d.db.Prepare(query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
} }
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
@ -141,10 +138,7 @@ func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stm
} }
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now() return d.ExecContext(context.Background(), query, args...)
res, err := d.db.Exec(query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
} }
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
@ -155,10 +149,7 @@ func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...inte
} }
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now() return d.QueryContext(context.Background(), query, args...)
res, err := d.db.Query(query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
} }
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
@ -169,10 +160,7 @@ func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...int
} }
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now() return d.QueryRowContext(context.Background(), query, args...)
res := d.db.QueryRow(query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
} }
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
@ -183,10 +171,7 @@ func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...
} }
func (d *dbQueryLog) Begin() (*sql.Tx, error) { func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now() return d.BeginTx(context.Background(), nil)
tx, err := d.db.(txer).Begin()
debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err)
return tx, err
} }
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
) )
type colValue struct { type colValue struct {
@ -71,7 +72,9 @@ type querySet struct {
groups []string groups []string
orders []string orders []string
distinct bool distinct bool
forupdate bool forUpdate bool
useIndex int
indexes []string
orm *ormBase orm *ormBase
ctx context.Context ctx context.Context
forContext bool forContext bool
@ -148,7 +151,28 @@ func (o querySet) Distinct() QuerySeter {
// add FOR UPDATE to SELECT // add FOR UPDATE to SELECT
func (o querySet) ForUpdate() QuerySeter { func (o querySet) ForUpdate() QuerySeter {
o.forupdate = true o.forUpdate = true
return &o
}
// ForceIndex force index for query
func (o querySet) ForceIndex(indexes ...string) QuerySeter {
o.useIndex = hints.KeyForceIndex
o.indexes = indexes
return &o
}
// UseIndex use index for query
func (o querySet) UseIndex(indexes ...string) QuerySeter {
o.useIndex = hints.KeyUseIndex
o.indexes = indexes
return &o
}
// IgnoreIndex ignore index for query
func (o querySet) IgnoreIndex(indexes ...string) QuerySeter {
o.useIndex = hints.KeyIgnoreIndex
o.indexes = indexes
return &o return &o
} }

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"io/ioutil" "io/ioutil"
"math" "math"
"os" "os"
@ -200,6 +201,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk)) RegisterModel(new(PtrPk))
RegisterModel(new(Index))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
@ -224,6 +226,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk)) RegisterModel(new(PtrPk))
RegisterModel(new(Index))
BootStrap() BootStrap()
@ -793,6 +796,32 @@ func TestExpr(t *testing.T) {
// throwFail(t, AssertIs(num, 3)) // throwFail(t, AssertIs(num, 3))
} }
func TestSpecifyIndex(t *testing.T) {
var index *Index
index = &Index{
F1: 1,
F2: 2,
}
_, _ = dORM.Insert(index)
throwFailNow(t, AssertIs(index.Id, 1))
index = &Index{
F1: 3,
F2: 4,
}
_, _ = dORM.Insert(index)
throwFailNow(t, AssertIs(index.Id, 2))
_ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`index_f1`).One(index)
throwFailNow(t, AssertIs(index.F2, 2))
_ = dORM.QueryTable(&Index{}).Filter(`f2`, `4`).UseIndex(`index_f2`).One(index)
throwFailNow(t, AssertIs(index.F1, 3))
_ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`index_f1`, `index_f2`).One(index)
throwFailNow(t, AssertIs(index.F2, 2))
}
func TestOperators(t *testing.T) { func TestOperators(t *testing.T) {
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")
num, err := qs.Filter("user_name", "slene").Count() num, err := qs.Filter("user_name", "slene").Count()
@ -1279,24 +1308,32 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true) num, err = dORM.LoadRelated(&user, "Posts", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1) num, err = dORM.LoadRelated(&user, "Posts",
hints.DefaultRelDepth(),
hints.Limit(1))
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(len(user.Posts), 1))
num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") num, err = dORM.LoadRelated(&user, "Posts",
hints.DefaultRelDepth(),
hints.OrderBy("-Id"))
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") num, err = dORM.LoadRelated(&user, "Posts",
hints.DefaultRelDepth(),
hints.Limit(1),
hints.Offset(1),
hints.OrderBy("Id"))
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(len(user.Posts), 1))
@ -1318,7 +1355,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User == nil, false))
throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) throwFailNow(t, AssertIs(profile.User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&profile, "User", true) num, err = dORM.LoadRelated(&profile, "User", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User == nil, false))
@ -1335,7 +1372,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile == nil, false))
throwFailNow(t, AssertIs(user.Profile.Age, 30)) throwFailNow(t, AssertIs(user.Profile.Age, 30))
num, err = dORM.LoadRelated(&user, "Profile", true) num, err = dORM.LoadRelated(&user, "Profile", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile == nil, false))
@ -1355,7 +1392,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User == nil, false))
throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) throwFailNow(t, AssertIs(post.User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&post, "User", true) num, err = dORM.LoadRelated(&post, "User", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User == nil, false))
@ -1375,7 +1412,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(len(post.Tags), 2))
throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) throwFailNow(t, AssertIs(post.Tags[0].Name, "golang"))
num, err = dORM.LoadRelated(&post, "Tags", true) num, err = dORM.LoadRelated(&post, "Tags", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(len(post.Tags), 2))
@ -1396,7 +1433,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true) num, err = dORM.LoadRelated(&tag, "Posts", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/astaxie/beego/pkg/common"
"reflect" "reflect"
"time" "time"
) )
@ -175,14 +176,14 @@ type DQL interface {
// example: // example:
// Ormer.LoadRelated(post,"Tags") // Ormer.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...} // for _,tag := range post.Tags{...}
// args[0] bool true useDefaultRelsDepth ; false depth 0 // hints.DefaultRelDepth useDefaultRelsDepth ; or depth 0
// args[0] int loadRelationDepth // hints.RelDepth loadRelationDepth
// args[1] int limit default limit 1000 // hints.Limit limit default limit 1000
// args[2] int offset default offset 0 // hints.Offset int offset default offset 0
// args[3] string order for example : "-Id" // hints.OrderBy string order for example : "-Id"
// make sure the relation is defined in model struct tags. // make sure the relation is defined in model struct tags.
LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error)
LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error)
// create a models to models queryer // create a models to models queryer
// for example: // for example:
@ -282,6 +283,21 @@ type QuerySeter interface {
// for example: // for example:
// qs.OrderBy("-status") // qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter OrderBy(exprs ...string) QuerySeter
// add FORCE INDEX expression.
// for example:
// qs.ForceIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
ForceIndex(indexes ...string) QuerySeter
// add USE INDEX expression.
// for example:
// qs.UseIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
UseIndex(indexes ...string) QuerySeter
// add IGNORE INDEX expression.
// for example:
// qs.IgnoreIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
IgnoreIndex(indexes ...string) QuerySeter
// set relation model to query together. // set relation model to query together.
// it will query relation models and assign to parent model. // it will query relation models and assign to parent model.
// for example: // for example:
@ -527,24 +543,27 @@ type txEnder interface {
// base database struct // base database struct
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool
OperatorSQL(string) string OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)
@ -559,4 +578,6 @@ type dbBaser interface {
IndexExists(dbQuerier, string, string) bool IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error setval(dbQuerier, *modelInfo, []string) error
GenerateSpecifyIndex(tableName string,useIndex int ,indexes []string) string
} }