diff --git a/go.mod b/go.mod index 3ad8576a..b3f2c2e7 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( golang.org/x/tools v0.0.0-20200117065230-39095c1d176c google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 + github.com/pkg/errors v0.9.1 ) replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 diff --git a/go.sum b/go.sum index 12b76333..a54afad6 100644 --- a/go.sum +++ b/go.sum @@ -135,6 +135,8 @@ github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -185,6 +187,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -219,6 +222,8 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/pkg/bean/context.go b/pkg/bean/context.go new file mode 100644 index 00000000..93261628 --- /dev/null +++ b/pkg/bean/context.go @@ -0,0 +1,21 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +// ApplicationContext define for future +// when we decide to support DI, IoC, this will be core API +type ApplicationContext interface { + +} diff --git a/pkg/bean/doc.go b/pkg/bean/doc.go new file mode 100644 index 00000000..212e8aaf --- /dev/null +++ b/pkg/bean/doc.go @@ -0,0 +1,17 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// bean is a basic package +// it should not depend on other modules except common module, log module and config module +package bean diff --git a/pkg/bean/factory.go b/pkg/bean/factory.go new file mode 100644 index 00000000..698474c4 --- /dev/null +++ b/pkg/bean/factory.go @@ -0,0 +1,25 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" +) + +// AutoWireBeanFactory wire the bean based on ApplicationContext and context.Context +type AutoWireBeanFactory interface { + // AutoWire will wire the bean. + AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error +} \ No newline at end of file diff --git a/pkg/bean/metadata.go b/pkg/bean/metadata.go new file mode 100644 index 00000000..8c423692 --- /dev/null +++ b/pkg/bean/metadata.go @@ -0,0 +1,28 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +// BeanMetadata, in other words, bean's config. +// it could be read from config file +type BeanMetadata struct { + // Fields: field name => field metadata + Fields map[string]*FieldMetadata +} + +// FieldMetadata contains metadata +type FieldMetadata struct { + // default value in string format + DftValue string +} diff --git a/pkg/bean/tag_auto_wire_bean_factory.go b/pkg/bean/tag_auto_wire_bean_factory.go new file mode 100644 index 00000000..ea8fd907 --- /dev/null +++ b/pkg/bean/tag_auto_wire_bean_factory.go @@ -0,0 +1,231 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "fmt" + "reflect" + "strconv" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/logs" +) + +const DefaultValueTagKey = "default" + +// TagAutoWireBeanFactory wire the bean based on Fields' tag +// if field's value is "zero value", we will execute injection +// see reflect.Value.IsZero() +// If field's kind is one of(reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice +// reflect.UnsafePointer, reflect.Array, reflect.Uintptr, reflect.Complex64, reflect.Complex128 +// reflect.Ptr, reflect.Struct), +// it will be ignored +type TagAutoWireBeanFactory struct { + // we allow user register their TypeAdapter + Adapters map[string]TypeAdapter + + // FieldTagParser is an extension point which means that you can custom how to read field's metadata from tag + FieldTagParser func(field reflect.StructField) *FieldMetadata +} + +// NewTagAutoWireBeanFactory create an instance of TagAutoWireBeanFactory +// by default, we register Time adapter, the time will be parse by using layout "2006-01-02 15:04:05" +// If you need more adapter, you can implement interface TypeAdapter +func NewTagAutoWireBeanFactory() *TagAutoWireBeanFactory { + return &TagAutoWireBeanFactory{ + Adapters: map[string]TypeAdapter{ + "Time": &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"}, + }, + + FieldTagParser: func(field reflect.StructField) *FieldMetadata { + return &FieldMetadata{ + DftValue: field.Tag.Get(DefaultValueTagKey), + } + }, + } +} + +// AutoWire use value from appCtx to wire the bean, or use default value, or do nothing +func (t *TagAutoWireBeanFactory) AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error { + if bean == nil { + return nil + } + + v := reflect.Indirect(reflect.ValueOf(bean)) + + bm := t.getConfig(v) + + // field name, field metadata + for fn, fm := range bm.Fields { + + fValue := v.FieldByName(fn) + if len(fm.DftValue) == 0 || !t.needInject(fValue) || !fValue.CanSet() { + continue + } + + // handle type adapter + typeName := fValue.Type().Name() + if adapter, ok := t.Adapters[typeName]; ok { + dftValue, err := adapter.DefaultValue(ctx, fm.DftValue) + if err == nil { + fValue.Set(reflect.ValueOf(dftValue)) + continue + } else { + return err + } + } + + switch fValue.Kind() { + case reflect.Bool: + if v, err := strconv.ParseBool(fm.DftValue); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to bool value", + fn, fm.DftValue)) + } else { + fValue.SetBool(v) + continue + } + case reflect.Int: + if err := t.setIntXValue(fm.DftValue, 0, fn, fValue); err != nil { + return err + } + continue + case reflect.Int8: + if err := t.setIntXValue(fm.DftValue, 8, fn, fValue); err != nil { + return err + } + continue + case reflect.Int16: + if err := t.setIntXValue(fm.DftValue, 16, fn, fValue); err != nil { + return err + } + continue + + case reflect.Int32: + if err := t.setIntXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + + case reflect.Int64: + if err := t.setIntXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint: + if err := t.setUIntXValue(fm.DftValue, 0, fn, fValue); err != nil { + return err + } + + case reflect.Uint8: + if err := t.setUIntXValue(fm.DftValue, 8, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint16: + if err := t.setUIntXValue(fm.DftValue, 16, fn, fValue); err != nil { + return err + } + continue + case reflect.Uint32: + if err := t.setUIntXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint64: + if err := t.setUIntXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.Float32: + if err := t.setFloatXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + case reflect.Float64: + if err := t.setFloatXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.String: + fValue.SetString(fm.DftValue) + continue + + // case reflect.Ptr: + // case reflect.Struct: + default: + logs.Warn("this field[%s] has default setting, but we don't support this type: %s", + fn, fValue.Kind().String()) + } + } + return nil +} + +func (t *TagAutoWireBeanFactory) setFloatXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseFloat(dftValue, bitSize); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to float%d value", + fn, dftValue, bitSize)) + } else { + fv.SetFloat(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) setUIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseUint(dftValue, 10, bitSize); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to uint%d value", + fn, dftValue, bitSize)) + } else { + fv.SetUint(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) setIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseInt(dftValue, 10, bitSize); err != nil { + return errors.WithMessage(err, + fmt.Sprintf("can not convert the field[%s]'s default value[%s] to int%d value", + fn, dftValue, bitSize)) + } else { + fv.SetInt(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) needInject(fValue reflect.Value) bool { + return fValue.IsZero() +} + +// getConfig never return nil +func (t *TagAutoWireBeanFactory) getConfig(beanValue reflect.Value) *BeanMetadata { + fms := make(map[string]*FieldMetadata, beanValue.NumField()) + for i := 0; i < beanValue.NumField(); i++ { + // f => StructField + f := beanValue.Type().Field(i) + fms[f.Name] = t.FieldTagParser(f) + } + return &BeanMetadata{ + Fields: fms, + } +} diff --git a/pkg/bean/tag_auto_wire_bean_factory_test.go b/pkg/bean/tag_auto_wire_bean_factory_test.go new file mode 100644 index 00000000..2d83c537 --- /dev/null +++ b/pkg/bean/tag_auto_wire_bean_factory_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTagAutoWireBeanFactory_AutoWire(t *testing.T) { + factory := NewTagAutoWireBeanFactory() + bm := &ComplicateStruct{} + err := factory.AutoWire(context.Background(), nil, bm) + assert.Nil(t, err) + assert.Equal(t, 12, bm.IntValue) + assert.Equal(t, "hello, strValue", bm.StrValue) + + assert.Equal(t, int8(8), bm.Int8Value) + assert.Equal(t, int16(16), bm.Int16Value) + assert.Equal(t, int32(32), bm.Int32Value) + assert.Equal(t, int64(64), bm.Int64Value) + + assert.Equal(t, uint(13), bm.UintValue) + assert.Equal(t, uint8(88), bm.Uint8Value) + assert.Equal(t, uint16(1616), bm.Uint16Value) + assert.Equal(t, uint32(3232), bm.Uint32Value) + assert.Equal(t, uint64(6464), bm.Uint64Value) + + assert.Equal(t, float32(32.32), bm.Float32Value) + assert.Equal(t, float64(64.64), bm.Float64Value) + + assert.True(t, bm.BoolValue) + assert.Equal(t, 0, bm.ignoreInt) + + assert.NotNil(t, bm.TimeValue) +} + +type ComplicateStruct struct { + IntValue int `default:"12"` + StrValue string `default:"hello, strValue"` + Int8Value int8 `default:"8"` + Int16Value int16 `default:"16"` + Int32Value int32 `default:"32"` + Int64Value int64 `default:"64"` + + UintValue uint `default:"13"` + Uint8Value uint8 `default:"88"` + Uint16Value uint16 `default:"1616"` + Uint32Value uint32 `default:"3232"` + Uint64Value uint64 `default:"6464"` + + Float32Value float32 `default:"32.32"` + Float64Value float64 `default:"64.64"` + + BoolValue bool `default:"true"` + + ignoreInt int `default:"11"` + + TimeValue time.Time `default:"2018-02-03 12:13:14.000"` +} diff --git a/pkg/bean/time_type_adapter.go b/pkg/bean/time_type_adapter.go new file mode 100644 index 00000000..846eb694 --- /dev/null +++ b/pkg/bean/time_type_adapter.go @@ -0,0 +1,36 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "time" + +) + +// TimeTypeAdapter process the time.Time +type TimeTypeAdapter struct { + Layout string +} + +// DefaultValue parse the DftValue to time.Time +// and if the DftValue == now +// time.Now() is returned +func (t *TimeTypeAdapter) DefaultValue(ctx context.Context, dftValue string) (interface{}, error) { + if dftValue == "now"{ + return time.Now(), nil + } + return time.Parse(t.Layout, dftValue) +} diff --git a/pkg/bean/time_type_adapter_test.go b/pkg/bean/time_type_adapter_test.go new file mode 100644 index 00000000..9c097048 --- /dev/null +++ b/pkg/bean/time_type_adapter_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTimeTypeAdapter_DefaultValue(t *testing.T) { + typeAdapter := &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"} + tm, err := typeAdapter.DefaultValue(context.Background(), "2018-02-03 12:34:11") + assert.Nil(t, err) + assert.NotNil(t, tm) +} diff --git a/pkg/bean/type_adapter.go b/pkg/bean/type_adapter.go new file mode 100644 index 00000000..ba675b64 --- /dev/null +++ b/pkg/bean/type_adapter.go @@ -0,0 +1,26 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" +) + +// TypeAdapter is an abstraction that define some behavior of target type +// usually, we don't use this to support basic type since golang has many restriction for basic types +// This is an important extension point +type TypeAdapter interface { + DefaultValue(ctx context.Context, dftValue string) (interface{}, error) +} diff --git a/pkg/orm/filter/bean/default_value_filter.go b/pkg/orm/filter/bean/default_value_filter.go new file mode 100644 index 00000000..80aef43d --- /dev/null +++ b/pkg/orm/filter/bean/default_value_filter.go @@ -0,0 +1,136 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "reflect" + "strings" + + "github.com/astaxie/beego/pkg/bean" + "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/orm" +) + +// DefaultValueFilterChainBuilder only works for InsertXXX method, +// But InsertOrUpdate and InsertOrUpdateWithCtx is more dangerous than other methods. +// so we won't handle those two methods unless you set includeInsertOrUpdate to true +// And if the element is not pointer, this filter doesn't work +type DefaultValueFilterChainBuilder struct { + factory bean.AutoWireBeanFactory + compatibleWithOldStyle bool + + // only the includeInsertOrUpdate is true, this filter will handle those two methods + includeInsertOrUpdate bool +} + +// NewDefaultValueFilterChainBuilder will create an instance of DefaultValueFilterChainBuilder +// In beego v1.x, the default value config looks like orm:default(xxxx) +// But the default value in 2.x is default:xxx +// so if you want to be compatible with v1.x, please pass true as compatibleWithOldStyle +func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter, + includeInsertOrUpdate bool, + compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder { + factory := bean.NewTagAutoWireBeanFactory() + + if compatibleWithOldStyle { + newParser := factory.FieldTagParser + factory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata { + if newParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" { + return newParser(field) + } else { + res := &bean.FieldMetadata{} + ormMeta := field.Tag.Get("orm") + ormMetaParts := strings.Split(ormMeta, ";") + for _, p := range ormMetaParts { + if strings.HasPrefix(p, "default(") && strings.HasSuffix(p, ")") { + res.DftValue = p[8 : len(p)-1] + } + } + return res + } + } + } + + for k, v := range typeAdapters { + factory.Adapters[k] = v + } + + return &DefaultValueFilterChainBuilder{ + factory: factory, + compatibleWithOldStyle: compatibleWithOldStyle, + includeInsertOrUpdate: includeInsertOrUpdate, + } +} + +func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) { + switch inv.Method { + case "Insert", "InsertWithCtx": + d.handleInsert(ctx, inv) + break + case "InsertOrUpdate", "InsertOrUpdateWithCtx": + d.handleInsertOrUpdate(ctx, inv) + break + case "InsertMulti", "InsertMultiWithCtx": + d.handleInsertMulti(ctx, inv) + break + } + next(ctx, inv) + } +} + +func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) { + d.setDefaultValue(ctx, inv.Args[0]) +} + +func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) { + if d.includeInsertOrUpdate { + ins := inv.Args[0] + if ins == nil { + return + } + + pkName := inv.GetPkFieldName() + pkField := reflect.Indirect(reflect.ValueOf(ins)).FieldByName(pkName) + + if pkField.IsZero() { + d.setDefaultValue(ctx, ins) + } + } +} + +func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) { + mds := inv.Args[1] + + if t := reflect.TypeOf(mds).Kind(); t != reflect.Array && t != reflect.Slice { + // do nothing + return + } + + mdsArr := reflect.Indirect(reflect.ValueOf(mds)) + for i := 0; i < mdsArr.Len(); i++ { + d.setDefaultValue(ctx, mdsArr.Index(i).Interface()) + } + logs.Warn("%v", mdsArr.Index(0).Interface()) +} + +func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) { + err := d.factory.AutoWire(ctx, nil, ins) + if err != nil { + logs.Error("try to wire the bean for orm.Insert failed. "+ + "the default value is not set: %v, ", err) + } +} diff --git a/pkg/orm/filter/bean/default_value_filter_test.go b/pkg/orm/filter/bean/default_value_filter_test.go new file mode 100644 index 00000000..6b038f27 --- /dev/null +++ b/pkg/orm/filter/bean/default_value_filter_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/orm" +) + +func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) { + builder := NewDefaultValueFilterChainBuilder(nil, true, true) + o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain) + + // test insert + entity := &DefaultValueTestEntity{} + _, _ = o.Insert(entity) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + assert.Equal(t, 0, entity.AgeIgnore) + + // test InsertOrUpdate + entity = &DefaultValueTestEntity{} + orm.RegisterModel(entity) + + + _, _ = o.InsertOrUpdate(entity) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + + // we won't set the default value because we find the pk is not Zero value + entity.Id = 3 + entity.AgeInOldStyle = 0 + _, _ = o.InsertOrUpdate(entity) + assert.Equal(t, 0, entity.AgeInOldStyle) + + entity = &DefaultValueTestEntity{} + + // the entity is not array, it will be ignored + _, _ = o.InsertMulti(3, entity) + assert.Equal(t, 0, entity.Age) + assert.Equal(t, 0, entity.AgeInOldStyle) + + _, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity}) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + +} + +type defaultValueTestOrm struct { + orm.DoNothingOrm +} + +type DefaultValueTestEntity struct { + Id int`orm:pk` + Age int `default:"12"` + AgeInOldStyle int `orm:"default(13);bee()"` + AgeIgnore int +} diff --git a/pkg/orm/invocation.go b/pkg/orm/invocation.go index 1c9fee09..586ec573 100644 --- a/pkg/orm/invocation.go +++ b/pkg/orm/invocation.go @@ -46,3 +46,12 @@ func (inv *Invocation) GetTableName() string { func (inv *Invocation) execute() { inv.f() } + +// GetPkFieldName return the primary key of this table +// if not found, "" is returned +func (inv *Invocation) GetPkFieldName() string { + if inv.mi.fields.pk != nil { + return inv.mi.fields.pk.name + } + return "" +}