1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 01:00:54 +00:00

Orm filter support

This commit is contained in:
Ming Deng 2020-08-07 13:45:24 +00:00
parent 3382a5baa1
commit 08cec9178f
10 changed files with 1391 additions and 4 deletions

View File

@ -0,0 +1,134 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDoNothingOrm(t *testing.T) {
o := &DoNothingOrm{}
err := o.DoTxWithCtxAndOpts(nil, nil, nil)
assert.Nil(t, err)
err = o.DoTxWithCtx(nil, nil)
assert.Nil(t, err)
err = o.DoTx(nil)
assert.Nil(t, err)
err = o.DoTxWithOpts(nil, nil)
assert.Nil(t, err)
assert.Nil(t, o.Driver())
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
assert.Nil(t, o.QueryM2M(nil, ""))
assert.Nil(t, o.ReadWithCtx(nil, nil))
assert.Nil(t, o.Read(nil))
txOrm, err := o.BeginWithCtxAndOpts(nil, nil)
assert.Nil(t, err)
assert.Nil(t, txOrm)
txOrm, err = o.BeginWithCtx(nil)
assert.Nil(t, err)
assert.Nil(t, txOrm)
txOrm, err = o.BeginWithOpts(nil)
assert.Nil(t, err)
assert.Nil(t, txOrm)
txOrm, err = o.Begin()
assert.Nil(t, err)
assert.Nil(t, txOrm)
assert.Nil(t, o.RawWithCtx(nil, ""))
assert.Nil(t, o.Raw(""))
i, err := o.InsertMulti(0, nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.Insert(nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.InsertWithCtx(nil, nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.InsertOrUpdateWithCtx(nil, nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.InsertOrUpdate(nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.InsertMultiWithCtx(nil, 0, nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.LoadRelatedWithCtx(nil, nil, "")
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.LoadRelated(nil, "")
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
assert.Nil(t, o.QueryTable(nil))
assert.Nil(t, o.Read(nil))
assert.Nil(t, o.ReadWithCtx(nil, nil))
assert.Nil(t, o.ReadForUpdateWithCtx(nil, nil))
assert.Nil(t, o.ReadForUpdate(nil))
ok, i, err := o.ReadOrCreate(nil, "")
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
assert.False(t, ok)
ok, i, err = o.ReadOrCreateWithCtx(nil, nil, "")
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
assert.False(t, ok)
i, err = o.Delete(nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.DeleteWithCtx(nil, nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.Update(nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
i, err = o.UpdateWithCtx(nil, nil)
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
assert.Nil(t, o.DBStats())
to := &DoNothingTxOrm{}
assert.Nil(t, to.Commit())
assert.Nil(t, to.Rollback())
}

178
pkg/orm/do_nothing_orm.go Normal file
View File

@ -0,0 +1,178 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
)
// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation
// I think golang mocking interface is hard to use
// this may help you to integrate with Ormer
var _ Ormer = new(DoNothingOrm)
type DoNothingOrm struct {
}
func (d *DoNothingOrm) Read(md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadForUpdate(md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
return false, 0, nil
}
func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
return false, 0, nil
}
func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
return nil
}
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
return nil
}
func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return nil
}
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
return nil
}
func (d *DoNothingOrm) DBStats() *sql.DBStats {
return nil
}
func (d *DoNothingOrm) Insert(md interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertMulti(bulk int, mds interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) Update(md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) Delete(md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) Raw(query string, args ...interface{}) RawSeter {
return nil
}
func (d *DoNothingOrm) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter {
return nil
}
func (d *DoNothingOrm) Driver() Driver {
return nil
}
func (d *DoNothingOrm) Begin() (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) BeginWithCtx(ctx context.Context) (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) DoTx(task func(txOrm TxOrmer) error) error {
return nil
}
func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error {
return nil
}
func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error {
return nil
}
func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error {
return nil
}
// DoNothingTxOrm is similar with DoNothingOrm, usually you use it to test
type DoNothingTxOrm struct {
DoNothingOrm
}
func (d *DoNothingTxOrm) Commit() error {
return nil
}
func (d *DoNothingTxOrm) Rollback() error {
return nil
}

32
pkg/orm/filter.go Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
)
type FilterChain func(next Filter) Filter
type Filter func(ctx context.Context, inv *Invocation)
var globalFilterChains = make([]FilterChain, 0, 4)
// AddGlobalFilterChain adds a new FilterChain
// All orm instances built after this invocation will use this filterChain,
// but instances built before this invocation will not be affected
func AddGlobalFilterChain(filterChain FilterChain) {
globalFilterChains = append(globalFilterChains, filterChain)
}

View File

@ -0,0 +1,519 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"reflect"
"time"
)
const TxNameKey = "TxName"
type filterOrmDecorator struct {
ormer
TxBeginner
TxCommitter
root Filter
insideTx bool
txStartTime time.Time
txName string
}
func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer {
res := &filterOrmDecorator{
ormer: delegate,
TxBeginner: delegate,
root: func(ctx context.Context, inv *Invocation) {
inv.execute()
},
}
for i := len(filterChains) - 1; i >= 0; i-- {
node := filterChains[i]
res.root = node(res.root)
}
return res
}
func NewFilterTxOrmDecorator(delegate TxOrmer, root Filter, txName string) TxOrmer {
res := &filterOrmDecorator{
ormer: delegate,
TxCommitter: delegate,
root: root,
insideTx: true,
txStartTime: time.Now(),
txName: txName,
}
return res
}
func (f *filterOrmDecorator) Read(md interface{}, cols ...string) error {
return f.ReadWithCtx(context.Background(), md, cols...)
}
func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) (err error) {
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "ReadWithCtx",
Args: []interface{}{md, cols},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
err = f.ormer.ReadWithCtx(ctx, md, cols...)
},
}
f.root(ctx, inv)
return err
}
func (f *filterOrmDecorator) ReadForUpdate(md interface{}, cols ...string) error {
return f.ReadForUpdateWithCtx(context.Background(), md, cols...)
}
func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
var err error
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "ReadForUpdateWithCtx",
Args: []interface{}{md, cols},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
err = f.ormer.ReadForUpdateWithCtx(ctx, md, cols...)
},
}
f.root(ctx, inv)
return err
}
func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
return f.ReadOrCreateWithCtx(context.Background(), md, col1, cols...)
}
func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
var (
ok bool
res int64
err error
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "ReadOrCreateWithCtx",
Args: []interface{}{md, col1, cols},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
ok, res, err = f.ormer.ReadOrCreateWithCtx(ctx, md, col1, cols...)
},
}
f.root(ctx, inv)
return ok, res, err
}
func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
return f.LoadRelatedWithCtx(context.Background(), md, name, args...)
}
func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) {
var (
res int64
err error
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "LoadRelatedWithCtx",
Args: []interface{}{md, name, args},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.ormer.LoadRelatedWithCtx(ctx, md, name, args...)
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
return f.QueryM2MWithCtx(context.Background(), md, name)
}
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
var (
res QueryM2Mer
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "QueryM2MWithCtx",
Args: []interface{}{md, name},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res = f.ormer.QueryM2MWithCtx(ctx, md, name)
},
}
f.root(ctx, inv)
return res
}
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
}
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
var (
res QuerySeter
name string
md interface{}
mi *modelInfo
)
if table, ok := ptrStructOrTableName.(string); ok {
name = table
} else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
md = ptrStructOrTableName
}
if m, ok := modelCache.getByFullName(name); ok {
mi = m
}
inv := &Invocation{
Method: "QueryTableWithCtx",
Args: []interface{}{ptrStructOrTableName},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
Md: md,
mi: mi,
f: func() {
res = f.ormer.QueryTableWithCtx(ctx, ptrStructOrTableName)
},
}
f.root(ctx, inv)
return res
}
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
var (
res *sql.DBStats
)
inv := &Invocation{
Method: "DBStats",
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res = f.ormer.DBStats()
},
}
f.root(context.Background(), inv)
return res
}
func (f *filterOrmDecorator) Insert(md interface{}) (int64, error) {
return f.InsertWithCtx(context.Background(), md)
}
func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
var (
res int64
err error
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "InsertWithCtx",
Args: []interface{}{md},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.ormer.InsertWithCtx(ctx, md)
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
return f.InsertOrUpdateWithCtx(context.Background(), md, colConflitAndArgs...)
}
func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
var (
res int64
err error
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "InsertOrUpdateWithCtx",
Args: []interface{}{md, colConflitAndArgs},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.ormer.InsertOrUpdateWithCtx(ctx, md, colConflitAndArgs...)
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, error) {
return f.InsertMultiWithCtx(context.Background(), bulk, mds)
}
// InsertMultiWithCtx uses the first element's model info
func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
var (
res int64
err error
md interface{}
mi *modelInfo
)
sind := reflect.Indirect(reflect.ValueOf(mds))
if (sind.Kind() == reflect.Array || sind.Kind() == reflect.Slice) && sind.Len() > 0 {
ind := reflect.Indirect(sind.Index(0))
md = ind.Interface()
mi, _ = modelCache.getByMd(md)
}
inv := &Invocation{
Method: "InsertMultiWithCtx",
Args: []interface{}{bulk, mds},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.ormer.InsertMultiWithCtx(ctx, bulk, mds)
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) Update(md interface{}, cols ...string) (int64, error) {
return f.UpdateWithCtx(context.Background(), md, cols...)
}
func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
var (
res int64
err error
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "UpdateWithCtx",
Args: []interface{}{md, cols},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.ormer.UpdateWithCtx(ctx, md, cols...)
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) Delete(md interface{}, cols ...string) (int64, error) {
return f.DeleteWithCtx(context.Background(), md, cols...)
}
func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
var (
res int64
err error
)
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "DeleteWithCtx",
Args: []interface{}{md, cols},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.ormer.DeleteWithCtx(ctx, md, cols...)
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) Raw(query string, args ...interface{}) RawSeter {
return f.RawWithCtx(context.Background(), query, args...)
}
func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter {
var (
res RawSeter
)
inv := &Invocation{
Method: "RawWithCtx",
Args: []interface{}{query, args},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res = f.ormer.RawWithCtx(ctx, query, args...)
},
}
f.root(ctx, inv)
return res
}
func (f *filterOrmDecorator) Driver() Driver {
var (
res Driver
)
inv := &Invocation{
Method: "Driver",
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res = f.ormer.Driver()
},
}
f.root(context.Background(), inv)
return res
}
func (f *filterOrmDecorator) Begin() (TxOrmer, error) {
return f.BeginWithCtxAndOpts(context.Background(), nil)
}
func (f *filterOrmDecorator) BeginWithCtx(ctx context.Context) (TxOrmer, error) {
return f.BeginWithCtxAndOpts(ctx, nil)
}
func (f *filterOrmDecorator) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) {
return f.BeginWithCtxAndOpts(context.Background(), opts)
}
func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
var (
res TxOrmer
err error
)
inv := &Invocation{
Method: "BeginWithCtxAndOpts",
Args: []interface{}{opts},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func() {
res, err = f.TxBeginner.BeginWithCtxAndOpts(ctx, opts)
res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(ctx))
},
}
f.root(ctx, inv)
return res, err
}
func (f *filterOrmDecorator) DoTx(task func(txOrm TxOrmer) error) error {
return f.DoTxWithCtxAndOpts(context.Background(), nil, task)
}
func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error {
return f.DoTxWithCtxAndOpts(ctx, nil, task)
}
func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error {
return f.DoTxWithCtxAndOpts(context.Background(), opts, task)
}
func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error {
var (
err error
)
inv := &Invocation{
Method: "DoTxWithCtxAndOpts",
Args: []interface{}{opts, task},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
TxName: getTxNameFromCtx(ctx),
f: func() {
err = f.TxBeginner.DoTxWithCtxAndOpts(ctx, opts, task)
},
}
f.root(ctx, inv)
return err
}
func (f *filterOrmDecorator) Commit() error {
var (
err error
)
inv := &Invocation{
Method: "Commit",
Args: []interface{}{},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
TxName: f.txName,
f: func() {
err = f.TxCommitter.Commit()
},
}
f.root(context.Background(), inv)
return err
}
func (f *filterOrmDecorator) Rollback() error {
var (
err error
)
inv := &Invocation{
Method: "Rollback",
Args: []interface{}{},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
TxName: f.txName,
f: func() {
err = f.TxCommitter.Rollback()
},
}
f.root(context.Background(), inv)
return err
}
func getTxNameFromCtx(ctx context.Context) string {
txName := ""
if n, ok := ctx.Value(TxNameKey).(string); ok {
txName = n
}
return txName
}

View File

@ -0,0 +1,432 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"errors"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilterOrmDecorator_Read(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "ReadWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
next(ctx, inv)
}
})
fte := &FilterTestEntity{}
err := od.Read(fte)
assert.NotNil(t, err)
assert.Equal(t, "read error", err.Error())
}
func TestFilterOrmDecorator_BeginTx(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
if inv.Method == "BeginWithCtxAndOpts" {
assert.Equal(t, 1, len(inv.Args))
assert.Equal(t, "", inv.GetTableName())
assert.False(t, inv.InsideTx)
} else if inv.Method == "Commit" {
assert.Equal(t, 0, len(inv.Args))
assert.Equal(t, "Commit_tx", inv.TxName)
assert.Equal(t, "", inv.GetTableName())
assert.True(t, inv.InsideTx)
} else if inv.Method == "Rollback" {
assert.Equal(t, 0, len(inv.Args))
assert.Equal(t, "Rollback_tx", inv.TxName)
assert.Equal(t, "", inv.GetTableName())
assert.True(t, inv.InsideTx)
} else {
t.Fail()
}
next(ctx, inv)
}
})
to, err := od.Begin()
assert.True(t, validateBeginResult(t, to, err))
to, err = od.BeginWithOpts(nil)
assert.True(t, validateBeginResult(t, to, err))
ctx := context.WithValue(context.Background(), TxNameKey, "Commit_tx")
to, err = od.BeginWithCtx(ctx)
assert.True(t, validateBeginResult(t, to, err))
err = to.Commit()
assert.NotNil(t, err)
assert.Equal(t, "commit", err.Error())
ctx = context.WithValue(context.Background(), TxNameKey, "Rollback_tx")
to, err = od.BeginWithCtxAndOpts(ctx, nil)
assert.True(t, validateBeginResult(t, to, err))
err = to.Rollback()
assert.NotNil(t, err)
assert.Equal(t, "rollback", err.Error())
}
func TestFilterOrmDecorator_DBStats(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "DBStats", inv.Method)
assert.Equal(t, 0, len(inv.Args))
assert.Equal(t, "", inv.GetTableName())
next(ctx, inv)
}
})
res := od.DBStats()
assert.NotNil(t, res)
assert.Equal(t, -1, res.MaxOpenConnections)
}
func TestFilterOrmDecorator_Delete(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "DeleteWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
next(ctx, inv)
}
})
res, err := od.Delete(&FilterTestEntity{})
assert.NotNil(t, err)
assert.Equal(t, "delete error", err.Error())
assert.Equal(t, int64(-2), res)
}
func TestFilterOrmDecorator_DoTx(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
err := od.DoTx(func(txOrm TxOrmer) error {
return errors.New("tx error")
})
assert.NotNil(t, err)
assert.Equal(t, "tx error", err.Error())
err = od.DoTxWithCtx(context.Background(), func(txOrm TxOrmer) error {
return errors.New("tx ctx error")
})
assert.NotNil(t, err)
assert.Equal(t, "tx ctx error", err.Error())
err = od.DoTxWithOpts(nil, func(txOrm TxOrmer) error {
return errors.New("tx opts error")
})
assert.NotNil(t, err)
assert.Equal(t, "tx opts error", err.Error())
od = NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "", inv.GetTableName())
assert.Equal(t, "do tx name", inv.TxName)
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
ctx := context.WithValue(context.Background(), TxNameKey, "do tx name")
err = od.DoTxWithCtxAndOpts(ctx, nil, func(txOrm TxOrmer) error {
return errors.New("tx ctx opts error")
})
assert.NotNil(t, err)
assert.Equal(t, "tx ctx opts error", err.Error())
}
func TestFilterOrmDecorator_Driver(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "Driver", inv.Method)
assert.Equal(t, 0, len(inv.Args))
assert.Equal(t, "", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
res := od.Driver()
assert.Nil(t, res)
}
func TestFilterOrmDecorator_Insert(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "InsertWithCtx", inv.Method)
assert.Equal(t, 1, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
i, err := od.Insert(&FilterTestEntity{})
assert.NotNil(t, err)
assert.Equal(t, "insert error", err.Error())
assert.Equal(t, int64(100), i)
}
func TestFilterOrmDecorator_InsertMulti(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "InsertMultiWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
bulk := []*FilterTestEntity{&FilterTestEntity{}, &FilterTestEntity{}}
i, err := od.InsertMulti(2, bulk)
assert.NotNil(t, err)
assert.Equal(t, "insert multi error", err.Error())
assert.Equal(t, int64(2), i)
}
func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "InsertOrUpdateWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
i, err := od.InsertOrUpdate(&FilterTestEntity{})
assert.NotNil(t, err)
assert.Equal(t, "insert or update error", err.Error())
assert.Equal(t, int64(1), i)
}
func TestFilterOrmDecorator_LoadRelated(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "LoadRelatedWithCtx", inv.Method)
assert.Equal(t, 3, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
i, err := od.LoadRelated(&FilterTestEntity{}, "hello")
assert.NotNil(t, err)
assert.Equal(t, "load related error", err.Error())
assert.Equal(t, int64(99), i)
}
func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "QueryM2MWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
res := od.QueryM2M(&FilterTestEntity{}, "hello")
assert.Nil(t, res)
}
func TestFilterOrmDecorator_QueryTable(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "QueryTableWithCtx", inv.Method)
assert.Equal(t, 1, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
res := od.QueryTable(&FilterTestEntity{})
assert.Nil(t, res)
}
func TestFilterOrmDecorator_Raw(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "RawWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
res := od.Raw("hh")
assert.Nil(t, res)
}
func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "ReadForUpdateWithCtx", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
err := od.ReadForUpdate(&FilterTestEntity{})
assert.NotNil(t, err)
assert.Equal(t, "read for update error", err.Error())
}
func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) {
register()
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "ReadOrCreateWithCtx", inv.Method)
assert.Equal(t, 3, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
next(ctx, inv)
}
})
ok, i, err := od.ReadOrCreate(&FilterTestEntity{}, "name")
assert.NotNil(t, err)
assert.Equal(t, "read or create error", err.Error())
assert.True(t, ok)
assert.Equal(t, int64(13), i)
}
// filterMockOrm is only used in this test file
type filterMockOrm struct {
DoNothingOrm
}
func (f *filterMockOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
return true, 13, errors.New("read or create error")
}
func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
return errors.New("read for update error")
}
func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) {
return 99, errors.New("load related error")
}
func (f *filterMockOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
return 1, errors.New("insert or update error")
}
func (f *filterMockOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
return 2, errors.New("insert multi error")
}
func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
return 100, errors.New("insert error")
}
func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error {
return task(nil)
}
func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
return -2, errors.New("delete error")
}
func (f *filterMockOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
return &filterMockOrm{}, errors.New("begin tx")
}
func (f *filterMockOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
return errors.New("read error")
}
func (f *filterMockOrm) Commit() error {
return errors.New("commit")
}
func (f *filterMockOrm) Rollback() error {
return errors.New("rollback")
}
func (f *filterMockOrm) DBStats() *sql.DBStats {
return &sql.DBStats{
MaxOpenConnections: -1,
}
}
func validateBeginResult(t *testing.T, to TxOrmer, err error) bool {
assert.NotNil(t, err)
assert.Equal(t, "begin tx", err.Error())
_, ok := to.(*filterOrmDecorator).TxCommitter.(*filterMockOrm)
assert.True(t, ok)
return true
}
var filterTestEntityRegisterOnce sync.Once
type FilterTestEntity struct {
ID int
Name string
}
func register() {
filterTestEntityRegisterOnce.Do(func() {
RegisterModel(&FilterTestEntity{})
})
}
func (f *FilterTestEntity) TableName() string {
return "FILTER_TEST"
}

31
pkg/orm/filter_test.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddGlobalFilterChain(t *testing.T) {
AddGlobalFilterChain(func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) {
}
})
assert.Equal(t, 1, len(globalFilterChains))
}

48
pkg/orm/invocation.go Normal file
View File

@ -0,0 +1,48 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"time"
)
// Invocation represents an "Orm" invocation
type Invocation struct {
Method string
// Md may be nil in some cases. It depends on method
Md interface{}
// the args are all arguments except context.Context
Args []interface{}
mi *modelInfo
// f is the Orm operation
f func()
// insideTx indicates whether this is inside a transaction
InsideTx bool
TxStartTime time.Time
TxName string
}
func (inv *Invocation) GetTableName() string {
if inv.mi != nil{
return inv.mi.table
}
return ""
}
func (inv *Invocation) execute() {
inv.f()
}

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"reflect"
"sync" "sync"
) )
@ -73,6 +74,14 @@ func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
return return
} }
func (mc *_modelCache) getByMd(md interface{}) (*modelInfo, bool) {
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
typ := ind.Type()
name := getFullName(typ)
return mc.getByFullName(name)
}
// set model info to collection // set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table] mii := mc.cache[table]

View File

@ -2486,3 +2486,5 @@ func TestInsertOrUpdate(t *testing.T) {
throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status))
} }
} }

View File

@ -204,17 +204,19 @@ type DriverGetter interface {
Driver() Driver Driver() Driver
} }
type Ormer interface { type ormer interface {
DQL DQL
DML DML
DriverGetter DriverGetter
}
type Ormer interface {
ormer
TxBeginner TxBeginner
} }
type TxOrmer interface { type TxOrmer interface {
DQL ormer
DML
DriverGetter
TxCommitter TxCommitter
} }