From 5e4241fc875385899a24641a258e9a278810f339 Mon Sep 17 00:00:00 2001 From: zav8 Date: Thu, 6 Dec 2018 16:07:07 +0800 Subject: [PATCH 1/3] add support for field of type sql.NullXxx in rawSet.setFieldValue() --- orm/orm_raw.go | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/orm/orm_raw.go b/orm/orm_raw.go index c8ef4398..08efa4e5 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -150,8 +150,10 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Struct: if value == nil { ind.Set(reflect.Zero(ind.Type())) - - } else if _, ok := ind.Interface().(time.Time); ok { + return + } + switch indi := ind.Interface().(type) { + case time.Time: var str string switch d := value.(type) { case time.Time: @@ -178,6 +180,26 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } } + case sql.NullString: + err := indi.Scan(value) + if err == nil { + ind.Set(reflect.ValueOf(indi)) + } + case sql.NullInt64: + err := indi.Scan(value) + if err == nil { + ind.Set(reflect.ValueOf(indi)) + } + case sql.NullFloat64: + err := indi.Scan(value) + if err == nil { + ind.Set(reflect.ValueOf(indi)) + } + case sql.NullBool: + err := indi.Scan(value) + if err == nil { + ind.Set(reflect.ValueOf(indi)) + } } } } From 6da4a66c20b284e78d4508013310f0872ae81eb8 Mon Sep 17 00:00:00 2001 From: zav8 Date: Thu, 6 Dec 2018 16:09:39 +0800 Subject: [PATCH 2/3] merge switch cases --- orm/orm_raw.go | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 08efa4e5..27651fe4 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -152,7 +152,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { ind.Set(reflect.Zero(ind.Type())) return } - switch indi := ind.Interface().(type) { + switch ind.Interface().(type) { case time.Time: var str string switch d := value.(type) { @@ -180,25 +180,15 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } } - case sql.NullString: - err := indi.Scan(value) - if err == nil { - ind.Set(reflect.ValueOf(indi)) + case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: + indi := reflect.New(ind.Type()).Interface() + sc, ok := indi.(sql.Scanner) + if !ok { + return } - case sql.NullInt64: - err := indi.Scan(value) + err := sc.Scan(value) if err == nil { - ind.Set(reflect.ValueOf(indi)) - } - case sql.NullFloat64: - err := indi.Scan(value) - if err == nil { - ind.Set(reflect.ValueOf(indi)) - } - case sql.NullBool: - err := indi.Scan(value) - if err == nil { - ind.Set(reflect.ValueOf(indi)) + ind.Set(reflect.Indirect(reflect.ValueOf(sc))) } } } From d2c289193ab2dfb5c981f3da66f6adb2eefa10b6 Mon Sep 17 00:00:00 2001 From: zav8 Date: Tue, 18 Dec 2018 10:37:41 +0800 Subject: [PATCH 3/3] add test case for QueryRow and QueryRows --- orm/orm_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/orm/orm_test.go b/orm/orm_test.go index 89a714b6..4f499a7c 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1679,6 +1679,31 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(uid, 4)) throwFail(t, AssertIs(*status, 3)) throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) } // user_profile table @@ -1771,6 +1796,32 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].Age, 30)) + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) } func TestRawValues(t *testing.T) {