1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 15:10:55 +00:00

orm now support custom builtin types as model struct field or query args fix #489

This commit is contained in:
slene 2014-03-13 23:31:47 +08:00
parent 769f7c751b
commit 95e67ba2c2
4 changed files with 165 additions and 70 deletions

View File

@ -51,9 +51,16 @@ outFor:
continue continue
} }
switch v := arg.(type) { kind := val.Kind()
case []byte: if kind == reflect.Ptr {
case string: val = val.Elem()
kind = val.Kind()
arg = val.Interface()
}
switch kind {
case reflect.String:
v := val.String()
if fi != nil { if fi != nil {
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time var t time.Time
@ -78,61 +85,66 @@ outFor:
} }
} }
arg = v arg = v
case time.Time: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if fi != nil && fi.fieldType == TypeDateField { arg = val.Int()
arg = v.In(tz).Format(format_Date) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
} else { arg = val.Uint()
arg = v.In(tz).Format(format_DateTime) case reflect.Float32:
} arg, _ = StrTo(ToStr(arg)).Float64()
default: case reflect.Float64:
kind := val.Kind() arg = val.Float()
switch kind { case reflect.Bool:
case reflect.Slice, reflect.Array: arg = val.Bool()
case reflect.Slice, reflect.Array:
var args []interface{} if _, ok := arg.([]byte); ok {
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor continue outFor
}
case reflect.Ptr, reflect.Struct: var args []interface{}
ind := reflect.Indirect(val) for i := 0; i < val.Len(); i++ {
v := val.Index(i)
if ind.Kind() == reflect.Struct { var vu interface{}
typ := ind.Type() if v.CanInterface() {
name := getFullName(typ) vu = v.Interface()
var value interface{} }
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, ind); exist {
value = vu
}
}
arg = value
if arg == nil { if vu == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) continue
} }
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
} else { } else {
arg = ind.Interface() arg = v.In(tz).Format(format_DateTime)
}
} else {
typ := val.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
} }
} }
} }
params = append(params, arg) params = append(params, arg)
} }
return return

View File

@ -144,6 +144,45 @@ type DataNull struct {
NullInt64 sql.NullInt64 `orm:"null"` NullInt64 sql.NullInt64 `orm:"null"`
} }
type String string
type Boolean bool
type Byte byte
type Rune rune
type Int int
type Int8 int8
type Int16 int16
type Int32 int32
type Int64 int64
type Uint uint
type Uint8 uint8
type Uint16 uint16
type Uint32 uint32
type Uint64 uint64
type Float32 float64
type Float64 float64
type DataCustom struct {
Id int
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
Byte Byte
Rune Rune
Int Int
Int8 Int8
Int16 Int16
Int32 Int32
Int64 Int64
Uint Uint
Uint8 Uint8
Uint16 Uint16
Uint32 Uint32
Uint64 Uint64
Float32 Float32
Float64 Float64
Decimal Float64 `orm:"digits(8);decimals(4)"`
}
// only for mysql // only for mysql
type UserBig struct { type UserBig struct {
Id uint64 Id uint64

View File

@ -99,34 +99,41 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
// return field type as type constant from reflect.Value // return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Interface().(type) { switch elm.Kind() {
case int8: case reflect.Int8:
ft = TypeBitField ft = TypeBitField
case int16: case reflect.Int16:
ft = TypeSmallIntegerField ft = TypeSmallIntegerField
case int32, int: case reflect.Int32, reflect.Int:
ft = TypeIntegerField ft = TypeIntegerField
case int64, sql.NullInt64: case reflect.Int64:
ft = TypeBigIntegerField ft = TypeBigIntegerField
case uint8: case reflect.Uint8:
ft = TypePositiveBitField ft = TypePositiveBitField
case uint16: case reflect.Uint16:
ft = TypePositiveSmallIntegerField ft = TypePositiveSmallIntegerField
case uint32, uint: case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField ft = TypePositiveIntegerField
case uint64: case reflect.Uint64:
ft = TypePositiveBigIntegerField ft = TypePositiveBigIntegerField
case float32, float64, sql.NullFloat64: case reflect.Float32, reflect.Float64:
ft = TypeFloatField ft = TypeFloatField
case bool, sql.NullBool: case reflect.Bool:
ft = TypeBooleanField ft = TypeBooleanField
case string, sql.NullString: case reflect.String:
ft = TypeCharField ft = TypeCharField
default: default:
if elm.CanInterface() { switch elm.Interface().(type) {
if _, ok := elm.Interface().(time.Time); ok { case sql.NullInt64:
ft = TypeDateTimeField ft = TypeBigIntegerField
} case sql.NullFloat64:
ft = TypeFloatField
case sql.NullBool:
ft = TypeBooleanField
case sql.NullString:
ft = TypeCharField
case time.Time:
ft = TypeDateTimeField
} }
} }
if ft&IsFieldType == 0 { if ft&IsFieldType == 0 {

View File

@ -149,7 +149,7 @@ func TestGetDB(t *testing.T) {
} }
func TestSyncDb(t *testing.T) { func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))
@ -165,7 +165,7 @@ func TestSyncDb(t *testing.T) {
} }
func TestRegisterModels(t *testing.T) { func TestRegisterModels(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))
@ -309,6 +309,39 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
} }
func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
e.Set(reflect.ValueOf(value).Convert(e.Type()))
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
vu := e.Interface()
value = reflect.ValueOf(value).Convert(e.Type()).Interface()
throwFail(t, AssertIs(vu == value, true), value, vu)
}
}
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
profile := NewProfile() profile := NewProfile()
profile.Age = 30 profile.Age = 30
@ -562,6 +595,10 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", String("slene")).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", "slene").Count() num, err = qs.Filter("user_name__exact", "slene").Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
@ -602,11 +639,11 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))
num, err = qs.Filter("status__lt", 3).Count() num, err = qs.Filter("status__lt", Uint(3)).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("status__lte", 3).Count() num, err = qs.Filter("status__lte", Int(3)).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))