This commit is contained in:
slene 2013-12-30 23:04:13 +08:00
parent e0e8fa6e2a
commit 412a4a04de
2 changed files with 228 additions and 254 deletions

View File

@ -4,7 +4,6 @@ import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)
@ -164,65 +163,11 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
}
}
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) {
sIdxes := *sIdxesPtr
refs := *refsPtr
if typ.Kind() == reflect.Struct {
if typ.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
} else {
idxs := []int{}
outFor:
for idx := 0; idx < typ.NumField(); idx++ {
ctyp := typ.Field(idx)
tag := ctyp.Tag.Get(defaultStructTagName)
for _, v := range strings.Split(tag, defaultStructTagDelim) {
if v == "-" {
continue outFor
}
}
tp := ctyp.Type
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
if tp.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
} else if tp.Kind() != reflect.Struct {
var ref interface{}
refs = append(refs, &ref)
} else {
// skip other type
continue
}
idxs = append(idxs, idx)
}
sIdxes = append(sIdxes, idxs)
}
} else {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
}
*sIdxesPtr = sIdxes
*refsPtr = refs
}
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr
cur := 0
for i, idxs := range sIdxes {
for i := 0; i < len(sInds); i++ {
sInd := sInds[i]
eTyp := eTyps[i]
@ -258,32 +203,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
o.setFieldValue(ind, value)
}
cur++
} else {
hasValue := false
for _, idx := range idxs {
tind := ind.Field(idx)
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if value != nil {
hasValue = true
}
if tind.Kind() == reflect.Ptr {
if value == nil {
tindV := reflect.New(tind.Type()).Elem()
tind.Set(tindV)
} else {
tindV := reflect.New(tind.Type().Elem())
o.setFieldValue(tindV.Elem(), value)
tind.Set(tindV)
}
} else {
o.setFieldValue(tind, value)
}
cur++
}
if hasValue == false && isPtr {
val = reflect.New(val.Type()).Elem()
}
}
} else {
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil {
@ -313,15 +234,12 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
}
func (o *rawSet) QueryRow(containers ...interface{}) error {
if len(containers) == 0 {
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
}
refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -335,44 +253,120 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind)
eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
row := o.orm.db.QueryRow(query, args...)
if err := row.Scan(refs...); err == sql.ErrNoRows {
return ErrNoRows
} else if err != nil {
rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
}
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
if rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
} else {
return ErrNoRows
}
return nil
}
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0)
sIdxes := make([][]int, 0)
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
@ -389,7 +383,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
@ -403,21 +410,97 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
nInds := make([]reflect.Value, len(sInds))
sInd := sInds[0]
var cnt int64
for rows.Next() {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0)
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++
}
if cnt > 0 {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
if structMode {
sInds[0].Set(sInd)
} else {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
}

View File

@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
}
}
type Tmp struct {
Skip0 string
Id int
Char *string
Skip1 int `orm:"-"`
Date time.Time
DateTime time.Time
}
Boolean = false
Text = ""
Int64 = 0
Uint = 0
tmp := new(Tmp)
cols = []string{
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
}
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
values = []interface{}{
tmp, &Boolean, &Text, &Int64, &Uint,
}
err = dORM.Raw(query, 1).QueryRow(values...)
throwFailNow(t, err)
for _, col := range cols {
switch col {
case "id":
throwFail(t, AssertIs(tmp.Id, data_values[col]))
case "char":
c := tmp.Char
throwFail(t, AssertIs(*c, data_values[col]))
case "date":
v := tmp.Date.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
case "datetime":
v := tmp.DateTime.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
case "boolean":
throwFail(t, AssertIs(Boolean, data_values[col]))
case "text":
throwFail(t, AssertIs(Text, data_values[col]))
case "int64":
throwFail(t, AssertIs(Int64, data_values[col]))
case "uint":
throwFail(t, AssertIs(Uint, data_values[col]))
}
}
var (
uid int
status *int
@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote()
cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
var datas []*Data
var dids []int
sep := fmt.Sprintf("%s, %s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q)
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
num, err := dORM.Raw(query).QueryRows(&datas)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas), 1))
throwFailNow(t, AssertIs(len(dids), 1))
throwFailNow(t, AssertIs(dids[0], 1))
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
@ -1427,90 +1366,42 @@ func TestQueryRows(t *testing.T) {
throwFail(t, AssertIs(vu == value, true), value, vu)
}
type Tmp struct {
Id int
Name string
Skiped0 string `orm:"-"`
Pid *int
Skiped1 Data
Skiped2 *Data
var datas2 []Data
query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
num, err = dORM.Raw(query).QueryRows(&datas2)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas2), 1))
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
var (
ids []int
userNames []string
profileIds1 []int
profileIds2 []*int
createds []time.Time
updateds []time.Time
tmps1 []*Tmp
tmps2 []Tmp
)
cols = []string{
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
var ids []int
var usernames []string
num, err = dORM.Raw("SELECT id, user_name FROM user ORDER BY id asc").QueryRows(&ids, &usernames)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
var users []User
dORM.QueryTable("user").OrderBy("Id").All(&users)
for i := 0; i < 3; i++ {
id := ids[i]
name := userNames[i]
pid1 := profileIds1[i]
pid2 := profileIds2[i]
created := createds[i]
updated := updateds[i]
user := users[i]
throwFailNow(t, AssertIs(id, user.Id))
throwFailNow(t, AssertIs(name, user.UserName))
if user.Profile != nil {
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(pid1, 0))
throwFailNow(t, AssertIs(pid2, nil))
}
throwFailNow(t, AssertIs(created, user.Created, test_Date))
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
tmp := tmps1[i]
tmp1 := *tmp
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
if user.Profile != nil {
pid := tmp1.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp1.Pid, nil))
}
tmp2 := tmps2[i]
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
if user.Profile != nil {
pid := tmp2.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp2.Pid, nil))
}
}
type Sec struct {
Id int
Name string
}
var tmp []*Sec
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
num, err = dORM.Raw(query).QueryRows(&tmp)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(tmp[0], nil))
throwFailNow(t, AssertIs(len(ids), 3))
throwFailNow(t, AssertIs(ids[0], 2))
throwFailNow(t, AssertIs(usernames[0], "slene"))
throwFailNow(t, AssertIs(ids[1], 3))
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody"))
}
func TestRawValues(t *testing.T) {