1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-29 14:01:29 +00:00

Merge pull request #762 from pdf/pointer_field_support

Add support for basic type pointer fields
This commit is contained in:
astaxie 2014-08-22 13:41:58 +08:00
commit bf429a3a20
4 changed files with 247 additions and 24 deletions

102
orm/db.go
View File

@ -122,6 +122,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if nb.Valid { if nb.Valid {
value = nb.Bool value = nb.Bool
} }
} else if field.Kind() == reflect.Ptr {
if field.IsNil() {
value = nil
} else {
value = field.Elem().Bool()
}
} else { } else {
value = field.Bool() value = field.Bool()
} }
@ -131,6 +137,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if ns.Valid { if ns.Valid {
value = ns.String value = ns.String
} }
} else if field.Kind() == reflect.Ptr {
if field.IsNil() {
value = nil
} else {
value = field.Elem().String()
}
} else { } else {
value = field.String() value = field.String()
} }
@ -140,6 +152,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if nf.Valid { if nf.Valid {
value = nf.Float64 value = nf.Float64
} }
} else if field.Kind() == reflect.Ptr {
if field.IsNil() {
value = nil
} else {
value = field.Elem().Float()
}
} else { } else {
vu := field.Interface() vu := field.Interface()
if _, ok := vu.(float32); ok { if _, ok := vu.(float32); ok {
@ -161,13 +179,27 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
default: default:
switch { switch {
case fi.fieldType&IsPostiveIntegerField > 0: case fi.fieldType&IsPostiveIntegerField > 0:
if field.Kind() == reflect.Ptr {
if field.IsNil() {
value = nil
} else {
value = field.Elem().Uint()
}
} else {
value = field.Uint() value = field.Uint()
}
case fi.fieldType&IsIntegerField > 0: case fi.fieldType&IsIntegerField > 0:
if ni, ok := field.Interface().(sql.NullInt64); ok { if ni, ok := field.Interface().(sql.NullInt64); ok {
value = nil value = nil
if ni.Valid { if ni.Valid {
value = ni.Int64 value = ni.Int64
} }
} else if field.Kind() == reflect.Ptr {
if field.IsNil() {
value = nil
} else {
value = field.Elem().Int()
}
} else { } else {
value = field.Int() value = field.Int()
} }
@ -1177,6 +1209,11 @@ setValue:
nb.Valid = true nb.Valid = true
} }
field.Set(reflect.ValueOf(nb)) field.Set(reflect.ValueOf(nb))
} else if field.Kind() == reflect.Ptr {
if value != nil {
v := value.(bool)
field.Set(reflect.ValueOf(&v))
}
} else { } else {
if value == nil { if value == nil {
value = false value = false
@ -1194,6 +1231,11 @@ setValue:
ns.Valid = true ns.Valid = true
} }
field.Set(reflect.ValueOf(ns)) field.Set(reflect.ValueOf(ns))
} else if field.Kind() == reflect.Ptr {
if value != nil {
v := value.(string)
field.Set(reflect.ValueOf(&v))
}
} else { } else {
if value == nil { if value == nil {
value = "" value = ""
@ -1208,6 +1250,56 @@ setValue:
} }
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
} }
case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr:
if value != nil {
v := uint8(value.(uint64))
field.Set(reflect.ValueOf(&v))
}
case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr:
if value != nil {
v := uint16(value.(uint64))
field.Set(reflect.ValueOf(&v))
}
case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr:
if value != nil {
if field.Type() == reflect.TypeOf(new(uint)) {
v := uint(value.(uint64))
field.Set(reflect.ValueOf(&v))
} else {
v := uint32(value.(uint64))
field.Set(reflect.ValueOf(&v))
}
}
case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr:
if value != nil {
v := value.(uint64)
field.Set(reflect.ValueOf(&v))
}
case fieldType == TypeBitField && field.Kind() == reflect.Ptr:
if value != nil {
v := int8(value.(int64))
field.Set(reflect.ValueOf(&v))
}
case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr:
if value != nil {
v := int16(value.(int64))
field.Set(reflect.ValueOf(&v))
}
case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr:
if value != nil {
if field.Type() == reflect.TypeOf(new(int)) {
v := int(value.(int64))
field.Set(reflect.ValueOf(&v))
} else {
v := int32(value.(int64))
field.Set(reflect.ValueOf(&v))
}
}
case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr:
if value != nil {
v := value.(int64)
field.Set(reflect.ValueOf(&v))
}
case fieldType&IsIntegerField > 0: case fieldType&IsIntegerField > 0:
if fieldType&IsPostiveIntegerField > 0 { if fieldType&IsPostiveIntegerField > 0 {
if isNative { if isNative {
@ -1244,6 +1336,16 @@ setValue:
nf.Valid = true nf.Valid = true
} }
field.Set(reflect.ValueOf(nf)) field.Set(reflect.ValueOf(nf))
} else if field.Kind() == reflect.Ptr {
if value != nil {
if field.Type() == reflect.TypeOf(new(float32)) {
v := float32(value.(float64))
field.Set(reflect.ValueOf(&v))
} else {
v := value.(float64)
field.Set(reflect.ValueOf(&v))
}
}
} else { } else {
if value == nil { if value == nil {

View File

@ -155,6 +155,24 @@ type DataNull struct {
NullBool sql.NullBool `orm:"null"` NullBool sql.NullBool `orm:"null"`
NullFloat64 sql.NullFloat64 `orm:"null"` NullFloat64 sql.NullFloat64 `orm:"null"`
NullInt64 sql.NullInt64 `orm:"null"` NullInt64 sql.NullInt64 `orm:"null"`
BooleanPtr *bool `orm:"null"`
CharPtr *string `orm:"null;size(50)"`
TextPtr *string `orm:"null;type(text)"`
BytePtr *byte `orm:"null"`
RunePtr *rune `orm:"null"`
IntPtr *int `orm:"null"`
Int8Ptr *int8 `orm:"null"`
Int16Ptr *int16 `orm:"null"`
Int32Ptr *int32 `orm:"null"`
Int64Ptr *int64 `orm:"null"`
UintPtr *uint `orm:"null"`
Uint8Ptr *uint8 `orm:"null"`
Uint16Ptr *uint16 `orm:"null"`
Uint32Ptr *uint32 `orm:"null"`
Uint64Ptr *uint64 `orm:"null"`
Float32Ptr *float32 `orm:"null"`
Float64Ptr *float64 `orm:"null"`
DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
} }
type String string type String string

View File

@ -111,6 +111,33 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
// return field type as type constant from reflect.Value // return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
switch val.Type() {
case reflect.TypeOf(new(int8)):
ft = TypeBitField
case reflect.TypeOf(new(int16)):
ft = TypeSmallIntegerField
case reflect.TypeOf(new(int32)),
reflect.TypeOf(new(int)):
ft = TypeIntegerField
case reflect.TypeOf(new(int64)):
ft = TypeBigIntegerField
case reflect.TypeOf(new(uint8)):
ft = TypePositiveBitField
case reflect.TypeOf(new(uint16)):
ft = TypePositiveSmallIntegerField
case reflect.TypeOf(new(uint32)),
reflect.TypeOf(new(uint)):
ft = TypePositiveIntegerField
case reflect.TypeOf(new(uint64)):
ft = TypePositiveBigIntegerField
case reflect.TypeOf(new(float32)),
reflect.TypeOf(new(float64)):
ft = TypeFloatField
case reflect.TypeOf(new(bool)):
ft = TypeBooleanField
case reflect.TypeOf(new(string)):
ft = TypeCharField
default:
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Kind() {
case reflect.Int8: case reflect.Int8:
@ -152,6 +179,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
ft = TypeDateTimeField ft = TypeDateTimeField
} }
} }
}
if ft&IsFieldType == 0 { if ft&IsFieldType == 0 {
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
} }

View File

@ -287,6 +287,25 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.NullInt64.Valid, false)) throwFail(t, AssertIs(d.NullInt64.Valid, false))
throwFail(t, AssertIs(d.NullFloat64.Valid, false)) throwFail(t, AssertIs(d.NullFloat64.Valid, false))
throwFail(t, AssertIs(d.BooleanPtr, nil))
throwFail(t, AssertIs(d.CharPtr, nil))
throwFail(t, AssertIs(d.TextPtr, nil))
throwFail(t, AssertIs(d.BytePtr, nil))
throwFail(t, AssertIs(d.RunePtr, nil))
throwFail(t, AssertIs(d.IntPtr, nil))
throwFail(t, AssertIs(d.Int8Ptr, nil))
throwFail(t, AssertIs(d.Int16Ptr, nil))
throwFail(t, AssertIs(d.Int32Ptr, nil))
throwFail(t, AssertIs(d.Int64Ptr, nil))
throwFail(t, AssertIs(d.UintPtr, nil))
throwFail(t, AssertIs(d.Uint8Ptr, nil))
throwFail(t, AssertIs(d.Uint16Ptr, nil))
throwFail(t, AssertIs(d.Uint32Ptr, nil))
throwFail(t, AssertIs(d.Uint64Ptr, nil))
throwFail(t, AssertIs(d.Float32Ptr, nil))
throwFail(t, AssertIs(d.Float64Ptr, nil))
throwFail(t, AssertIs(d.DecimalPtr, nil))
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err) throwFail(t, err)
@ -294,12 +313,49 @@ func TestNullDataTypes(t *testing.T) {
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
booleanPtr := true
charPtr := string("test")
textPtr := string("test")
bytePtr := byte('t')
runePtr := rune('t')
intPtr := int(42)
int8Ptr := int8(42)
int16Ptr := int16(42)
int32Ptr := int32(42)
int64Ptr := int64(42)
uintPtr := uint(42)
uint8Ptr := uint8(42)
uint16Ptr := uint16(42)
uint32Ptr := uint32(42)
uint64Ptr := uint64(42)
float32Ptr := float32(42.0)
float64Ptr := float64(42.0)
decimalPtr := float64(42.0)
d = DataNull{ d = DataNull{
DateTime: time.Now(), DateTime: time.Now(),
NullString: sql.NullString{String: "test", Valid: true}, NullString: sql.NullString{String: "test", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true},
NullInt64: sql.NullInt64{Int64: 42, Valid: true}, NullInt64: sql.NullInt64{Int64: 42, Valid: true},
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
BooleanPtr: &booleanPtr,
CharPtr: &charPtr,
TextPtr: &textPtr,
BytePtr: &bytePtr,
RunePtr: &runePtr,
IntPtr: &intPtr,
Int8Ptr: &int8Ptr,
Int16Ptr: &int16Ptr,
Int32Ptr: &int32Ptr,
Int64Ptr: &int64Ptr,
UintPtr: &uintPtr,
Uint8Ptr: &uint8Ptr,
Uint16Ptr: &uint16Ptr,
Uint32Ptr: &uint32Ptr,
Uint64Ptr: &uint64Ptr,
Float32Ptr: &float32Ptr,
Float64Ptr: &float64Ptr,
DecimalPtr: &decimalPtr,
} }
id, err = dORM.Insert(&d) id, err = dORM.Insert(&d)
@ -321,6 +377,25 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.NullFloat64.Valid, true)) throwFail(t, AssertIs(d.NullFloat64.Valid, true))
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr))
throwFail(t, AssertIs(*d.CharPtr, charPtr))
throwFail(t, AssertIs(*d.TextPtr, textPtr))
throwFail(t, AssertIs(*d.BytePtr, bytePtr))
throwFail(t, AssertIs(*d.RunePtr, runePtr))
throwFail(t, AssertIs(*d.IntPtr, intPtr))
throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr))
throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr))
throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr))
throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr))
throwFail(t, AssertIs(*d.UintPtr, uintPtr))
throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr))
throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr))
throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr))
throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr))
throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
} }
func TestDataCustomTypes(t *testing.T) { func TestDataCustomTypes(t *testing.T) {