diff --git a/orm/db.go b/orm/db.go index 2addff22..10f65fee 100644 --- a/orm/db.go +++ b/orm/db.go @@ -122,6 +122,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val if nb.Valid { value = nb.Bool } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Bool() + } } else { value = field.Bool() } @@ -131,6 +137,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val if ns.Valid { value = ns.String } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().String() + } } else { value = field.String() } @@ -140,6 +152,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val if nf.Valid { value = nf.Float64 } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Float() + } } else { vu := field.Interface() if _, ok := vu.(float32); ok { @@ -161,13 +179,27 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val default: switch { case fi.fieldType&IsPostiveIntegerField > 0: - value = field.Uint() + if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Uint() + } + } else { + value = field.Uint() + } case fi.fieldType&IsIntegerField > 0: if ni, ok := field.Interface().(sql.NullInt64); ok { value = nil if ni.Valid { value = ni.Int64 } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Int() + } } else { value = field.Int() } @@ -1177,6 +1209,11 @@ setValue: nb.Valid = true } field.Set(reflect.ValueOf(nb)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(bool) + field.Set(reflect.ValueOf(&v)) + } } else { if value == nil { value = false @@ -1194,6 +1231,11 @@ setValue: ns.Valid = true } field.Set(reflect.ValueOf(ns)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(string) + field.Set(reflect.ValueOf(&v)) + } } else { if value == nil { value = "" @@ -1208,6 +1250,56 @@ setValue: } field.Set(reflect.ValueOf(value)) } + case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint8(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint16(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(uint)) { + v := uint(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := uint32(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(uint64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := int8(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := int16(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(int)) { + v := int(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := int32(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(int64) + field.Set(reflect.ValueOf(&v)) + } case fieldType&IsIntegerField > 0: if fieldType&IsPostiveIntegerField > 0 { if isNative { @@ -1244,6 +1336,16 @@ setValue: nf.Valid = true } field.Set(reflect.ValueOf(nf)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + if field.Type() == reflect.TypeOf(new(float32)) { + v := float32(value.(float64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := value.(float64) + field.Set(reflect.ValueOf(&v)) + } + } } else { if value == nil { diff --git a/orm/models_test.go b/orm/models_test.go index fe78c45f..1a92ef5d 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -155,6 +155,24 @@ type DataNull struct { NullBool sql.NullBool `orm:"null"` NullFloat64 sql.NullFloat64 `orm:"null"` NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` } type String string diff --git a/orm/models_utils.go b/orm/models_utils.go index ffde40ea..ec11d516 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -111,45 +111,73 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col // return field type as type constant from reflect.Value func getFieldType(val reflect.Value) (ft int, err error) { - elm := reflect.Indirect(val) - switch elm.Kind() { - case reflect.Int8: + switch val.Type() { + case reflect.TypeOf(new(int8)): ft = TypeBitField - case reflect.Int16: + case reflect.TypeOf(new(int16)): ft = TypeSmallIntegerField - case reflect.Int32, reflect.Int: + case reflect.TypeOf(new(int32)), + reflect.TypeOf(new(int)): ft = TypeIntegerField - case reflect.Int64: + case reflect.TypeOf(new(int64)): ft = TypeBigIntegerField - case reflect.Uint8: + case reflect.TypeOf(new(uint8)): ft = TypePositiveBitField - case reflect.Uint16: + case reflect.TypeOf(new(uint16)): ft = TypePositiveSmallIntegerField - case reflect.Uint32, reflect.Uint: + case reflect.TypeOf(new(uint32)), + reflect.TypeOf(new(uint)): ft = TypePositiveIntegerField - case reflect.Uint64: + case reflect.TypeOf(new(uint64)): ft = TypePositiveBigIntegerField - case reflect.Float32, reflect.Float64: + case reflect.TypeOf(new(float32)), + reflect.TypeOf(new(float64)): ft = TypeFloatField - case reflect.Bool: + case reflect.TypeOf(new(bool)): ft = TypeBooleanField - case reflect.String: + case reflect.TypeOf(new(string)): ft = TypeCharField default: - if elm.Interface() == nil { - panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) - } - switch elm.Interface().(type) { - case sql.NullInt64: + elm := reflect.Indirect(val) + switch elm.Kind() { + case reflect.Int8: + ft = TypeBitField + case reflect.Int16: + ft = TypeSmallIntegerField + case reflect.Int32, reflect.Int: + ft = TypeIntegerField + case reflect.Int64: ft = TypeBigIntegerField - case sql.NullFloat64: + case reflect.Uint8: + ft = TypePositiveBitField + case reflect.Uint16: + ft = TypePositiveSmallIntegerField + case reflect.Uint32, reflect.Uint: + ft = TypePositiveIntegerField + case reflect.Uint64: + ft = TypePositiveBigIntegerField + case reflect.Float32, reflect.Float64: ft = TypeFloatField - case sql.NullBool: + case reflect.Bool: ft = TypeBooleanField - case sql.NullString: + case reflect.String: ft = TypeCharField - case time.Time: - ft = TypeDateTimeField + default: + if elm.Interface() == nil { + panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) + } + switch elm.Interface().(type) { + case sql.NullInt64: + ft = TypeBigIntegerField + case sql.NullFloat64: + ft = TypeFloatField + case sql.NullBool: + ft = TypeBooleanField + case sql.NullString: + ft = TypeCharField + case time.Time: + ft = TypeDateTimeField + } } } if ft&IsFieldType == 0 { diff --git a/orm/orm_test.go b/orm/orm_test.go index e7420bcb..e1c8e0f0 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -287,6 +287,25 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(d.NullInt64.Valid, false)) throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + throwFail(t, AssertIs(d.BooleanPtr, nil)) + throwFail(t, AssertIs(d.CharPtr, nil)) + throwFail(t, AssertIs(d.TextPtr, nil)) + throwFail(t, AssertIs(d.BytePtr, nil)) + throwFail(t, AssertIs(d.RunePtr, nil)) + throwFail(t, AssertIs(d.IntPtr, nil)) + throwFail(t, AssertIs(d.Int8Ptr, nil)) + throwFail(t, AssertIs(d.Int16Ptr, nil)) + throwFail(t, AssertIs(d.Int32Ptr, nil)) + throwFail(t, AssertIs(d.Int64Ptr, nil)) + throwFail(t, AssertIs(d.UintPtr, nil)) + throwFail(t, AssertIs(d.Uint8Ptr, nil)) + throwFail(t, AssertIs(d.Uint16Ptr, nil)) + throwFail(t, AssertIs(d.Uint32Ptr, nil)) + throwFail(t, AssertIs(d.Uint64Ptr, nil)) + throwFail(t, AssertIs(d.Float32Ptr, nil)) + throwFail(t, AssertIs(d.Float64Ptr, nil)) + throwFail(t, AssertIs(d.DecimalPtr, nil)) + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() throwFail(t, err) @@ -294,12 +313,49 @@ func TestNullDataTypes(t *testing.T) { err = dORM.Read(&d) throwFail(t, err) + booleanPtr := true + charPtr := string("test") + textPtr := string("test") + bytePtr := byte('t') + runePtr := rune('t') + intPtr := int(42) + int8Ptr := int8(42) + int16Ptr := int16(42) + int32Ptr := int32(42) + int64Ptr := int64(42) + uintPtr := uint(42) + uint8Ptr := uint8(42) + uint16Ptr := uint16(42) + uint32Ptr := uint32(42) + uint64Ptr := uint64(42) + float32Ptr := float32(42.0) + float64Ptr := float64(42.0) + decimalPtr := float64(42.0) + d = DataNull{ DateTime: time.Now(), NullString: sql.NullString{String: "test", Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true}, NullInt64: sql.NullInt64{Int64: 42, Valid: true}, NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + BooleanPtr: &booleanPtr, + CharPtr: &charPtr, + TextPtr: &textPtr, + BytePtr: &bytePtr, + RunePtr: &runePtr, + IntPtr: &intPtr, + Int8Ptr: &int8Ptr, + Int16Ptr: &int16Ptr, + Int32Ptr: &int32Ptr, + Int64Ptr: &int64Ptr, + UintPtr: &uintPtr, + Uint8Ptr: &uint8Ptr, + Uint16Ptr: &uint16Ptr, + Uint32Ptr: &uint32Ptr, + Uint64Ptr: &uint64Ptr, + Float32Ptr: &float32Ptr, + Float64Ptr: &float64Ptr, + DecimalPtr: &decimalPtr, } id, err = dORM.Insert(&d) @@ -321,6 +377,25 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(d.NullFloat64.Valid, true)) throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) + + throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) + throwFail(t, AssertIs(*d.CharPtr, charPtr)) + throwFail(t, AssertIs(*d.TextPtr, textPtr)) + throwFail(t, AssertIs(*d.BytePtr, bytePtr)) + throwFail(t, AssertIs(*d.RunePtr, runePtr)) + throwFail(t, AssertIs(*d.IntPtr, intPtr)) + throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) + throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) + throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) + throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) + throwFail(t, AssertIs(*d.UintPtr, uintPtr)) + throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) + throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) + throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) + throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) + throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) + throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) + throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) } func TestDataCustomTypes(t *testing.T) {