// Beego/pkg/orm/filter/bean/default_value_filter.go
//
// 137 lines
// 4.2 KiB
// Go

// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bean
import (
"context"
"reflect"
"strings"
"github.com/astaxie/beego/pkg/bean"
"github.com/astaxie/beego/pkg/logs"
"github.com/astaxie/beego/pkg/orm"
)
// DefaultValueFilterChainBuilder builds an orm filter that sets default values
// on the models passed to the InsertXXX methods (Insert, InsertMulti and their
// WithCtx variants).
// InsertOrUpdate and InsertOrUpdateWithCtx are more dangerous than the other
// methods, so those two are only handled when includeInsertOrUpdate is true.
// If the element is not a pointer, this filter doesn't work.
type DefaultValueFilterChainBuilder struct {
// factory is the bean factory whose AutoWire call sets the default values.
factory bean.AutoWireBeanFactory
// compatibleWithOldStyle records the compatibleWithOldStyle flag given to
// NewDefaultValueFilterChainBuilder.
compatibleWithOldStyle bool
// InsertOrUpdate and InsertOrUpdateWithCtx are handled by this filter only
// when includeInsertOrUpdate is true.
includeInsertOrUpdate bool
}
// NewDefaultValueFilterChainBuilder creates a DefaultValueFilterChainBuilder.
//
// In beego v1.x the default value is configured as orm:default(xxxx), while in
// 2.x it is configured as default:xxx. Pass true as compatibleWithOldStyle to
// stay compatible with v1.x: a field that carries a non-empty
// bean.DefaultValueTagKey tag is still handled by the factory's own parser, and
// any other field has its default read from a default(...) entry of the orm tag.
//
// Every entry of typeAdapters is copied into the factory's adapters, and
// includeInsertOrUpdate decides whether InsertOrUpdate and
// InsertOrUpdateWithCtx are handled by the resulting filter.
func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter,
	includeInsertOrUpdate bool,
	compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder {

	const (
		oldStylePrefix = "default("
		oldStyleSuffix = ")"
	)

	beanFactory := bean.NewTagAutoWireBeanFactory()

	if compatibleWithOldStyle {
		newStyleParser := beanFactory.FieldTagParser
		beanFactory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata {
			// Fields tagged in the new style go through the original parser.
			if newStyleParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" {
				return newStyleParser(field)
			}

			// Otherwise scan the ';'-separated orm tag for default(...);
			// when several entries match, the last one wins.
			meta := &bean.FieldMetadata{}
			for _, entry := range strings.Split(field.Tag.Get("orm"), ";") {
				if !strings.HasPrefix(entry, oldStylePrefix) || !strings.HasSuffix(entry, oldStyleSuffix) {
					continue
				}
				meta.DftValue = entry[len(oldStylePrefix) : len(entry)-len(oldStyleSuffix)]
			}
			return meta
		}
	}

	for name, adapter := range typeAdapters {
		beanFactory.Adapters[name] = adapter
	}

	builder := &DefaultValueFilterChainBuilder{
		factory:                beanFactory,
		compatibleWithOldStyle: compatibleWithOldStyle,
		includeInsertOrUpdate:  includeInsertOrUpdate,
	}
	return builder
}
// FilterChain returns a filter that, depending on the invoked method name,
// runs the matching default-value handler and then always delegates to next,
// returning whatever next returns. Methods other than the Insert,
// InsertOrUpdate and InsertMulti families go straight to next.
func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
	return func(ctx context.Context, invocation *orm.Invocation) []interface{} {
		// Go switch cases do not fall through, so no explicit break is needed.
		switch invocation.Method {
		case "Insert", "InsertWithCtx":
			d.handleInsert(ctx, invocation)
		case "InsertOrUpdate", "InsertOrUpdateWithCtx":
			d.handleInsertOrUpdate(ctx, invocation)
		case "InsertMulti", "InsertMultiWithCtx":
			d.handleInsertMulti(ctx, invocation)
		}
		return next(ctx, invocation)
	}
}
// handleInsert sets default values on the first argument of the invocation.
func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) {
	model := inv.Args[0]
	d.setDefaultValue(ctx, model)
}
// handleInsertOrUpdate sets default values on the first argument of the
// invocation, but only when includeInsertOrUpdate is enabled and the model's
// primary-key field (named by inv.GetPkFieldName) holds its zero value.
//
// The model is left untouched, without panicking, when it is nil, when it
// does not resolve to a struct (including a nil pointer), or when the struct
// has no field with the primary-key name.
func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) {
	if !d.includeInsertOrUpdate {
		return
	}

	ins := inv.Args[0]
	if ins == nil {
		return
	}

	pkName := inv.GetPkFieldName()

	// reflect.Indirect yields an invalid Value for a nil pointer, and
	// FieldByName panics on anything whose Kind is not Struct.
	model := reflect.Indirect(reflect.ValueOf(ins))
	if model.Kind() != reflect.Struct {
		return
	}

	// FieldByName returns the zero Value when the field is missing, and
	// IsZero panics on such an invalid Value.
	pkField := model.FieldByName(pkName)
	if !pkField.IsValid() {
		return
	}

	if pkField.IsZero() {
		d.setDefaultValue(ctx, ins)
	}
}
// handleInsertMulti sets default values on every element of the second
// argument of the invocation. It does nothing when that argument is nil or is
// not an array or slice; an empty array or slice is a no-op.
func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) {
	mds := inv.Args[1]

	// reflect.TypeOf(nil) is a nil Type; asking it for its Kind would panic.
	if mds == nil {
		return
	}

	mdsArr := reflect.ValueOf(mds)
	if k := mdsArr.Kind(); k != reflect.Array && k != reflect.Slice {
		// do nothing
		return
	}

	for i := 0; i < mdsArr.Len(); i++ {
		d.setDefaultValue(ctx, mdsArr.Index(i).Interface())
	}
	// The former trailing warning log of element 0 was removed: it indexed
	// unconditionally and therefore panicked on an empty slice.
}
// setDefaultValue asks the bean factory to auto-wire ins. A wiring failure is
// only logged; it is not returned to the caller.
func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) {
	wireErr := d.factory.AutoWire(ctx, nil, ins)
	if wireErr == nil {
		return
	}
	logs.Error("try to wire the bean for orm.Insert failed. "+
		"the default value is not set: %v, ", wireErr)
}