mirror of
https://github.com/astaxie/beego.git
synced 2025-06-13 08:50:39 +00:00
Bean: Support autowire by tag
Orm: Support default value filter
This commit is contained in:
136
pkg/orm/filter/bean/default_value_filter.go
Normal file
136
pkg/orm/filter/bean/default_value_filter.go
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bean
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego/pkg/bean"
|
||||
"github.com/astaxie/beego/pkg/logs"
|
||||
"github.com/astaxie/beego/pkg/orm"
|
||||
)
|
||||
|
||||
// DefaultValueFilterChainBuilder only works for InsertXXX method,
|
||||
// But InsertOrUpdate and InsertOrUpdateWithCtx is more dangerous than other methods.
|
||||
// so we won't handle those two methods unless you set includeInsertOrUpdate to true
|
||||
// And if the element is not pointer, this filter doesn't work
|
||||
type DefaultValueFilterChainBuilder struct {
|
||||
factory bean.AutoWireBeanFactory
|
||||
compatibleWithOldStyle bool
|
||||
|
||||
// only the includeInsertOrUpdate is true, this filter will handle those two methods
|
||||
includeInsertOrUpdate bool
|
||||
}
|
||||
|
||||
// NewDefaultValueFilterChainBuilder will create an instance of DefaultValueFilterChainBuilder
|
||||
// In beego v1.x, the default value config looks like orm:default(xxxx)
|
||||
// But the default value in 2.x is default:xxx
|
||||
// so if you want to be compatible with v1.x, please pass true as compatibleWithOldStyle
|
||||
func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter,
|
||||
includeInsertOrUpdate bool,
|
||||
compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder {
|
||||
factory := bean.NewTagAutoWireBeanFactory()
|
||||
|
||||
if compatibleWithOldStyle {
|
||||
newParser := factory.FieldTagParser
|
||||
factory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata {
|
||||
if newParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" {
|
||||
return newParser(field)
|
||||
} else {
|
||||
res := &bean.FieldMetadata{}
|
||||
ormMeta := field.Tag.Get("orm")
|
||||
ormMetaParts := strings.Split(ormMeta, ";")
|
||||
for _, p := range ormMetaParts {
|
||||
if strings.HasPrefix(p, "default(") && strings.HasSuffix(p, ")") {
|
||||
res.DftValue = p[8 : len(p)-1]
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range typeAdapters {
|
||||
factory.Adapters[k] = v
|
||||
}
|
||||
|
||||
return &DefaultValueFilterChainBuilder{
|
||||
factory: factory,
|
||||
compatibleWithOldStyle: compatibleWithOldStyle,
|
||||
includeInsertOrUpdate: includeInsertOrUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
||||
return func(ctx context.Context, inv *orm.Invocation) {
|
||||
switch inv.Method {
|
||||
case "Insert", "InsertWithCtx":
|
||||
d.handleInsert(ctx, inv)
|
||||
break
|
||||
case "InsertOrUpdate", "InsertOrUpdateWithCtx":
|
||||
d.handleInsertOrUpdate(ctx, inv)
|
||||
break
|
||||
case "InsertMulti", "InsertMultiWithCtx":
|
||||
d.handleInsertMulti(ctx, inv)
|
||||
break
|
||||
}
|
||||
next(ctx, inv)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) {
|
||||
d.setDefaultValue(ctx, inv.Args[0])
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) {
|
||||
if d.includeInsertOrUpdate {
|
||||
ins := inv.Args[0]
|
||||
if ins == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkName := inv.GetPkFieldName()
|
||||
pkField := reflect.Indirect(reflect.ValueOf(ins)).FieldByName(pkName)
|
||||
|
||||
if pkField.IsZero() {
|
||||
d.setDefaultValue(ctx, ins)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) {
|
||||
mds := inv.Args[1]
|
||||
|
||||
if t := reflect.TypeOf(mds).Kind(); t != reflect.Array && t != reflect.Slice {
|
||||
// do nothing
|
||||
return
|
||||
}
|
||||
|
||||
mdsArr := reflect.Indirect(reflect.ValueOf(mds))
|
||||
for i := 0; i < mdsArr.Len(); i++ {
|
||||
d.setDefaultValue(ctx, mdsArr.Index(i).Interface())
|
||||
}
|
||||
logs.Warn("%v", mdsArr.Index(0).Interface())
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) {
|
||||
err := d.factory.AutoWire(ctx, nil, ins)
|
||||
if err != nil {
|
||||
logs.Error("try to wire the bean for orm.Insert failed. "+
|
||||
"the default value is not set: %v, ", err)
|
||||
}
|
||||
}
|
73
pkg/orm/filter/bean/default_value_filter_test.go
Normal file
73
pkg/orm/filter/bean/default_value_filter_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bean
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/pkg/orm"
|
||||
)
|
||||
|
||||
func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) {
|
||||
builder := NewDefaultValueFilterChainBuilder(nil, true, true)
|
||||
o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain)
|
||||
|
||||
// test insert
|
||||
entity := &DefaultValueTestEntity{}
|
||||
_, _ = o.Insert(entity)
|
||||
assert.Equal(t, 12, entity.Age)
|
||||
assert.Equal(t, 13, entity.AgeInOldStyle)
|
||||
assert.Equal(t, 0, entity.AgeIgnore)
|
||||
|
||||
// test InsertOrUpdate
|
||||
entity = &DefaultValueTestEntity{}
|
||||
orm.RegisterModel(entity)
|
||||
|
||||
|
||||
_, _ = o.InsertOrUpdate(entity)
|
||||
assert.Equal(t, 12, entity.Age)
|
||||
assert.Equal(t, 13, entity.AgeInOldStyle)
|
||||
|
||||
// we won't set the default value because we find the pk is not Zero value
|
||||
entity.Id = 3
|
||||
entity.AgeInOldStyle = 0
|
||||
_, _ = o.InsertOrUpdate(entity)
|
||||
assert.Equal(t, 0, entity.AgeInOldStyle)
|
||||
|
||||
entity = &DefaultValueTestEntity{}
|
||||
|
||||
// the entity is not array, it will be ignored
|
||||
_, _ = o.InsertMulti(3, entity)
|
||||
assert.Equal(t, 0, entity.Age)
|
||||
assert.Equal(t, 0, entity.AgeInOldStyle)
|
||||
|
||||
_, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity})
|
||||
assert.Equal(t, 12, entity.Age)
|
||||
assert.Equal(t, 13, entity.AgeInOldStyle)
|
||||
|
||||
}
|
||||
|
||||
type defaultValueTestOrm struct {
|
||||
orm.DoNothingOrm
|
||||
}
|
||||
|
||||
type DefaultValueTestEntity struct {
|
||||
Id int`orm:pk`
|
||||
Age int `default:"12"`
|
||||
AgeInOldStyle int `orm:"default(13);bee()"`
|
||||
AgeIgnore int
|
||||
}
|
@ -46,3 +46,12 @@ func (inv *Invocation) GetTableName() string {
|
||||
func (inv *Invocation) execute() {
|
||||
inv.f()
|
||||
}
|
||||
|
||||
// GetPkFieldName return the primary key of this table
|
||||
// if not found, "" is returned
|
||||
func (inv *Invocation) GetPkFieldName() string {
|
||||
if inv.mi.fields.pk != nil {
|
||||
return inv.mi.fields.pk.name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
Reference in New Issue
Block a user