From 95e67ba2c2a701b7be3c48c8e9076a4ce14e12fa Mon Sep 17 00:00:00 2001 From: slene Date: Thu, 13 Mar 2014 23:31:47 +0800 Subject: [PATCH] orm now support custom builtin types as model struct field or query args fix #489 --- orm/db_utils.go | 112 ++++++++++++++++++++++++-------------------- orm/models_test.go | 39 +++++++++++++++ orm/models_utils.go | 39 ++++++++------- orm/orm_test.go | 45 ++++++++++++++++-- 4 files changed, 165 insertions(+), 70 deletions(-) diff --git a/orm/db_utils.go b/orm/db_utils.go index 34de8186..3c8d2d23 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -51,9 +51,16 @@ outFor: continue } - switch v := arg.(type) { - case []byte: - case string: + kind := val.Kind() + if kind == reflect.Ptr { + val = val.Elem() + kind = val.Kind() + arg = val.Interface() + } + + switch kind { + case reflect.String: + v := val.String() if fi != nil { if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { var t time.Time @@ -78,61 +85,66 @@ outFor: } } arg = v - case time.Time: - if fi != nil && fi.fieldType == TypeDateField { - arg = v.In(tz).Format(format_Date) - } else { - arg = v.In(tz).Format(format_DateTime) - } - default: - kind := val.Kind() - switch kind { - case reflect.Slice, reflect.Array: - - var args []interface{} - for i := 0; i < val.Len(); i++ { - v := val.Index(i) - - var vu interface{} - if v.CanInterface() { - vu = v.Interface() - } - - if vu == nil { - continue - } - - args = append(args, vu) - } - - if len(args) > 0 { - p := getFlatParams(fi, args, tz) - params = append(params, p...) - } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg = val.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + arg = val.Uint() + case reflect.Float32: + arg, _ = StrTo(ToStr(arg)).Float64() + case reflect.Float64: + arg = val.Float() + case reflect.Bool: + arg = val.Bool() + case reflect.Slice, reflect.Array: + if _, ok := arg.([]byte); ok { continue outFor + } - case reflect.Ptr, reflect.Struct: - ind := reflect.Indirect(val) + var args []interface{} + for i := 0; i < val.Len(); i++ { + v := val.Index(i) - if ind.Kind() == reflect.Struct { - typ := ind.Type() - name := getFullName(typ) - var value interface{} - if mmi, ok := modelCache.getByFN(name); ok { - if _, vu, exist := getExistPk(mmi, ind); exist { - value = vu - } - } - arg = value + var vu interface{} + if v.CanInterface() { + vu = v.Interface() + } - if arg == nil { - panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) - } + if vu == nil { + continue + } + + args = append(args, vu) + } + + if len(args) > 0 { + p := getFlatParams(fi, args, tz) + params = append(params, p...) + } + continue outFor + case reflect.Struct: + if v, ok := arg.(time.Time); ok { + if fi != nil && fi.fieldType == TypeDateField { + arg = v.In(tz).Format(format_Date) } else { - arg = ind.Interface() + arg = v.In(tz).Format(format_DateTime) + } + } else { + typ := val.Type() + name := getFullName(typ) + var value interface{} + if mmi, ok := modelCache.getByFN(name); ok { + if _, vu, exist := getExistPk(mmi, val); exist { + value = vu + } + } + arg = value + + if arg == nil { + panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) } } } + params = append(params, arg) } return diff --git a/orm/models_test.go b/orm/models_test.go index 168c091a..8564dcb8 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -144,6 +144,45 @@ type DataNull struct { NullInt64 sql.NullInt64 `orm:"null"` } +type String string +type Boolean bool +type Byte byte +type Rune rune +type Int int +type Int8 int8 +type Int16 int16 +type Int32 int32 +type Int64 int64 +type Uint uint +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 +type Float32 float64 +type Float64 float64 + +type DataCustom struct { + Id int + Boolean Boolean + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + Byte Byte + Rune Rune + Int Int + Int8 Int8 + Int16 Int16 + Int32 Int32 + Int64 Int64 + Uint Uint + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 + Float32 Float32 + Float64 Float64 + Decimal Float64 `orm:"digits(8);decimals(4)"` +} + // only for mysql type UserBig struct { Id uint64 diff --git a/orm/models_utils.go b/orm/models_utils.go index 759093ef..f6cb14ec 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -99,34 +99,41 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col // return field type as type constant from reflect.Value func getFieldType(val reflect.Value) (ft int, err error) { elm := reflect.Indirect(val) - switch elm.Interface().(type) { - case int8: + switch elm.Kind() { + case reflect.Int8: ft = TypeBitField - case int16: + case reflect.Int16: ft = TypeSmallIntegerField - case int32, int: + case reflect.Int32, reflect.Int: ft = TypeIntegerField - case int64, sql.NullInt64: + case reflect.Int64: ft = TypeBigIntegerField - case uint8: + case reflect.Uint8: ft = TypePositiveBitField - case uint16: + case reflect.Uint16: ft = TypePositiveSmallIntegerField - case uint32, uint: + case reflect.Uint32, reflect.Uint: ft = TypePositiveIntegerField - case uint64: + case reflect.Uint64: ft = TypePositiveBigIntegerField - case float32, float64, sql.NullFloat64: + case reflect.Float32, reflect.Float64: ft = TypeFloatField - case bool, sql.NullBool: + case reflect.Bool: ft = TypeBooleanField - case string, sql.NullString: + case reflect.String: ft = TypeCharField default: - if elm.CanInterface() { - if _, ok := elm.Interface().(time.Time); ok { - ft = TypeDateTimeField - } + switch elm.Interface().(type) { + case sql.NullInt64: + ft = TypeBigIntegerField + case sql.NullFloat64: + ft = TypeFloatField + case sql.NullBool: + ft = TypeBooleanField + case sql.NullString: + ft = TypeCharField + case time.Time: + ft = TypeDateTimeField } } if ft&IsFieldType == 0 { diff --git a/orm/orm_test.go b/orm/orm_test.go index 69f2fc86..0f8ad81a 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -149,7 +149,7 @@ func TestGetDB(t *testing.T) { } func TestSyncDb(t *testing.T) { - RegisterModel(new(Data), new(DataNull)) + RegisterModel(new(Data), new(DataNull), new(DataCustom)) RegisterModel(new(User)) RegisterModel(new(Profile)) RegisterModel(new(Post)) @@ -165,7 +165,7 @@ func TestSyncDb(t *testing.T) { } func TestRegisterModels(t *testing.T) { - RegisterModel(new(Data), new(DataNull)) + RegisterModel(new(Data), new(DataNull), new(DataCustom)) RegisterModel(new(User)) RegisterModel(new(Profile)) RegisterModel(new(Post)) @@ -309,6 +309,39 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) } +func TestDataCustomTypes(t *testing.T) { + d := DataCustom{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range Data_Values { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + e.Set(reflect.ValueOf(value).Convert(e.Type())) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = DataCustom{Id: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range Data_Values { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + vu := e.Interface() + value = reflect.ValueOf(value).Convert(e.Type()).Interface() + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + func TestCRUD(t *testing.T) { profile := NewProfile() profile.Age = 30 @@ -562,6 +595,10 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) + num, err = qs.Filter("user_name__exact", String("slene")).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + num, err = qs.Filter("user_name__exact", "slene").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -602,11 +639,11 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 3)) - num, err = qs.Filter("status__lt", 3).Count() + num, err = qs.Filter("status__lt", Uint(3)).Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) - num, err = qs.Filter("status__lte", 3).Count() + num, err = qs.Filter("status__lte", Int(3)).Count() throwFail(t, err) throwFail(t, AssertIs(num, 3))