1
0
mirror of https://github.com/astaxie/beego.git synced 2025-06-13 08:50:39 +00:00

Bean: Support autowire by tag

Orm: Support default value filter
This commit is contained in:
Ming Deng
2020-08-13 14:14:10 +08:00
parent a1b7fd3c93
commit bdec93986b
14 changed files with 712 additions and 0 deletions

View File

@ -0,0 +1,136 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
"reflect"
"strings"
"github.com/astaxie/beego/pkg/bean"
"github.com/astaxie/beego/pkg/logs"
"github.com/astaxie/beego/pkg/orm"
)
// DefaultValueFilterChainBuilder only works for InsertXXX method,
// But InsertOrUpdate and InsertOrUpdateWithCtx is more dangerous than other methods.
// so we won't handle those two methods unless you set includeInsertOrUpdate to true
// And if the element is not pointer, this filter doesn't work
type DefaultValueFilterChainBuilder struct {
factory bean.AutoWireBeanFactory
compatibleWithOldStyle bool
// only the includeInsertOrUpdate is true, this filter will handle those two methods
includeInsertOrUpdate bool
}
// NewDefaultValueFilterChainBuilder will create an instance of DefaultValueFilterChainBuilder
// In beego v1.x, the default value config looks like orm:default(xxxx)
// But the default value in 2.x is default:xxx
// so if you want to be compatible with v1.x, please pass true as compatibleWithOldStyle
func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter,
includeInsertOrUpdate bool,
compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder {
factory := bean.NewTagAutoWireBeanFactory()
if compatibleWithOldStyle {
newParser := factory.FieldTagParser
factory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata {
if newParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" {
return newParser(field)
} else {
res := &bean.FieldMetadata{}
ormMeta := field.Tag.Get("orm")
ormMetaParts := strings.Split(ormMeta, ";")
for _, p := range ormMetaParts {
if strings.HasPrefix(p, "default(") && strings.HasSuffix(p, ")") {
res.DftValue = p[8 : len(p)-1]
}
}
return res
}
}
}
for k, v := range typeAdapters {
factory.Adapters[k] = v
}
return &DefaultValueFilterChainBuilder{
factory: factory,
compatibleWithOldStyle: compatibleWithOldStyle,
includeInsertOrUpdate: includeInsertOrUpdate,
}
}
func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
return func(ctx context.Context, inv *orm.Invocation) {
switch inv.Method {
case "Insert", "InsertWithCtx":
d.handleInsert(ctx, inv)
break
case "InsertOrUpdate", "InsertOrUpdateWithCtx":
d.handleInsertOrUpdate(ctx, inv)
break
case "InsertMulti", "InsertMultiWithCtx":
d.handleInsertMulti(ctx, inv)
break
}
next(ctx, inv)
}
}
func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) {
d.setDefaultValue(ctx, inv.Args[0])
}
func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) {
if d.includeInsertOrUpdate {
ins := inv.Args[0]
if ins == nil {
return
}
pkName := inv.GetPkFieldName()
pkField := reflect.Indirect(reflect.ValueOf(ins)).FieldByName(pkName)
if pkField.IsZero() {
d.setDefaultValue(ctx, ins)
}
}
}
func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) {
mds := inv.Args[1]
if t := reflect.TypeOf(mds).Kind(); t != reflect.Array && t != reflect.Slice {
// do nothing
return
}
mdsArr := reflect.Indirect(reflect.ValueOf(mds))
for i := 0; i < mdsArr.Len(); i++ {
d.setDefaultValue(ctx, mdsArr.Index(i).Interface())
}
logs.Warn("%v", mdsArr.Index(0).Interface())
}
func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) {
err := d.factory.AutoWire(ctx, nil, ins)
if err != nil {
logs.Error("try to wire the bean for orm.Insert failed. "+
"the default value is not set: %v, ", err)
}
}

View File

@ -0,0 +1,73 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/orm"
)
func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) {
builder := NewDefaultValueFilterChainBuilder(nil, true, true)
o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain)
// test insert
entity := &DefaultValueTestEntity{}
_, _ = o.Insert(entity)
assert.Equal(t, 12, entity.Age)
assert.Equal(t, 13, entity.AgeInOldStyle)
assert.Equal(t, 0, entity.AgeIgnore)
// test InsertOrUpdate
entity = &DefaultValueTestEntity{}
orm.RegisterModel(entity)
_, _ = o.InsertOrUpdate(entity)
assert.Equal(t, 12, entity.Age)
assert.Equal(t, 13, entity.AgeInOldStyle)
// we won't set the default value because we find the pk is not Zero value
entity.Id = 3
entity.AgeInOldStyle = 0
_, _ = o.InsertOrUpdate(entity)
assert.Equal(t, 0, entity.AgeInOldStyle)
entity = &DefaultValueTestEntity{}
// the entity is not array, it will be ignored
_, _ = o.InsertMulti(3, entity)
assert.Equal(t, 0, entity.Age)
assert.Equal(t, 0, entity.AgeInOldStyle)
_, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity})
assert.Equal(t, 12, entity.Age)
assert.Equal(t, 13, entity.AgeInOldStyle)
}
type defaultValueTestOrm struct {
orm.DoNothingOrm
}
type DefaultValueTestEntity struct {
Id int`orm:pk`
Age int `default:"12"`
AgeInOldStyle int `orm:"default(13);bee()"`
AgeIgnore int
}

View File

@ -46,3 +46,12 @@ func (inv *Invocation) GetTableName() string {
func (inv *Invocation) execute() {
inv.f()
}
// GetPkFieldName return the primary key of this table
// if not found, "" is returned
func (inv *Invocation) GetPkFieldName() string {
if inv.mi.fields.pk != nil {
return inv.mi.fields.pk.name
}
return ""
}