diff --git a/orm/db.go b/orm/db.go index 00945e43..decc8fc2 100644 --- a/orm/db.go +++ b/orm/db.go @@ -103,15 +103,36 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { switch fi.fieldType { case TypeBooleanField: - value = field.Bool() - case TypeCharField, TypeTextField: - value = field.String() - case TypeFloatField, TypeDecimalField: - vu := field.Interface() - if _, ok := vu.(float32); ok { - value, _ = StrTo(ToStr(vu)).Float64() + if nb, ok := field.Interface().(sql.NullBool); ok { + value = nil + if nb.Valid { + value = nb.Bool + } } else { - value = field.Float() + value = field.Bool() + } + case TypeCharField, TypeTextField: + if ns, ok := field.Interface().(sql.NullString); ok { + value = nil + if ns.Valid { + value = ns.String + } + } else { + value = field.String() + } + case TypeFloatField, TypeDecimalField: + if nf, ok := field.Interface().(sql.NullFloat64); ok { + value = nil + if nf.Valid { + value = nf.Float64 + } + } else { + vu := field.Interface() + if _, ok := vu.(float32); ok { + value, _ = StrTo(ToStr(vu)).Float64() + } else { + value = field.Float() + } } case TypeDateField, TypeDateTimeField: value = field.Interface() @@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val case fi.fieldType&IsPostiveIntegerField > 0: value = field.Uint() case fi.fieldType&IsIntegerField > 0: - value = field.Int() + if ni, ok := field.Interface().(sql.NullInt64); ok { + value = nil + if ni.Valid { + value = ni.Int64 + } + } else { + value = field.Int() + } case fi.fieldType&IsRelField > 0: if field.IsNil() { value = nil @@ -1122,17 +1150,37 @@ setValue: switch { case fieldType == TypeBooleanField: if isNative { - if value == nil { - value = false + if nb, ok := field.Interface().(sql.NullBool); ok { + if value == nil { + nb.Valid = false + } else { + nb.Bool = value.(bool) + nb.Valid = true + } + field.Set(reflect.ValueOf(nb)) + } else { + if value == nil { + value = false + } + field.SetBool(value.(bool)) } - field.SetBool(value.(bool)) } case fieldType == TypeCharField || fieldType == TypeTextField: if isNative { - if value == nil { - value = "" + if ns, ok := field.Interface().(sql.NullString); ok { + if value == nil { + ns.Valid = false + } else { + ns.String = value.(string) + ns.Valid = true + } + field.Set(reflect.ValueOf(ns)) + } else { + if value == nil { + value = "" + } + field.SetString(value.(string)) } - field.SetString(value.(string)) } case fieldType == TypeDateField || fieldType == TypeDateTimeField: if isNative { @@ -1151,18 +1199,39 @@ setValue: } } else { if isNative { - if value == nil { - value = int64(0) + if ni, ok := field.Interface().(sql.NullInt64); ok { + if value == nil { + ni.Valid = false + } else { + ni.Int64 = value.(int64) + ni.Valid = true + } + field.Set(reflect.ValueOf(ni)) + } else { + if value == nil { + value = int64(0) + } + field.SetInt(value.(int64)) } - field.SetInt(value.(int64)) } } case fieldType == TypeFloatField || fieldType == TypeDecimalField: if isNative { - if value == nil { - value = float64(0) + if nf, ok := field.Interface().(sql.NullFloat64); ok { + if value == nil { + nf.Valid = false + } else { + nf.Float64 = value.(float64) + nf.Valid = true + } + field.Set(reflect.ValueOf(nf)) + } else { + + if value == nil { + value = float64(0) + } + field.SetFloat(value.(float64)) } - field.SetFloat(value.(float64)) } case fieldType&IsRelField > 0: if value != nil { diff --git a/orm/models_test.go b/orm/models_test.go index 706f04dc..168c091a 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -1,6 +1,7 @@ package orm import ( + "database/sql" "encoding/json" "fmt" "os" @@ -116,27 +117,31 @@ type Data struct { } type DataNull struct { - Id int - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)""` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` + Id int + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)""` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` } // only for mysql @@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm #### Sqlite3 -touch /path/to/orm_test.db export ORM_DRIVER=sqlite3 -export ORM_SOURCE=/path/to/orm_test.db +export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/astaxie/beego/orm diff --git a/orm/models_utils.go b/orm/models_utils.go index 1466a724..759093ef 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -1,6 +1,7 @@ package orm import ( + "database/sql" "fmt" "reflect" "strings" @@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col // return field type as type constant from reflect.Value func getFieldType(val reflect.Value) (ft int, err error) { elm := reflect.Indirect(val) - switch elm.Kind() { - case reflect.Int8: + switch elm.Interface().(type) { + case int8: ft = TypeBitField - case reflect.Int16: + case int16: ft = TypeSmallIntegerField - case reflect.Int32, reflect.Int: + case int32, int: ft = TypeIntegerField - case reflect.Int64: + case int64, sql.NullInt64: ft = TypeBigIntegerField - case reflect.Uint8: + case uint8: ft = TypePositiveBitField - case reflect.Uint16: + case uint16: ft = TypePositiveSmallIntegerField - case reflect.Uint32, reflect.Uint: + case uint32, uint: ft = TypePositiveIntegerField - case reflect.Uint64: + case uint64: ft = TypePositiveBigIntegerField - case reflect.Float32, reflect.Float64: + case float32, float64, sql.NullFloat64: ft = TypeFloatField - case reflect.Bool: + case bool, sql.NullBool: ft = TypeBooleanField - case reflect.String: + case string, sql.NullString: ft = TypeCharField - case reflect.Invalid: default: if elm.CanInterface() { if _, ok := elm.Interface().(time.Time); ok { diff --git a/orm/orm_test.go b/orm/orm_test.go index c951d5ca..060cd65c 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -2,6 +2,7 @@ package orm import ( "bytes" + "database/sql" "fmt" "io/ioutil" "os" @@ -258,12 +259,45 @@ func TestNullDataTypes(t *testing.T) { err = dORM.Read(&d) throwFail(t, err) + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() throwFail(t, err) d = DataNull{Id: 2} err = dORM.Read(&d) throwFail(t, err) + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{"test", true}, + NullBool: sql.NullBool{true, true}, + NullInt64: sql.NullInt64{42, true}, + NullFloat64: sql.NullFloat64{42.42, true}, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{Id: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) } func TestCRUD(t *testing.T) { @@ -1646,10 +1680,10 @@ func TestTransaction(t *testing.T) { func TestReadOrCreate(t *testing.T) { u := &User{ UserName: "Kyle", - Email: "kylemcc@gmail.com", + Email: "kylemcc@gmail.com", Password: "other_pass", - Status: 7, - IsStaff: false, + Status: 7, + IsStaff: false, IsActive: true, }