diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 864515ac..7d204876 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "reflect" - "strings" "time" ) @@ -164,65 +163,11 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } -func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) { - sIdxes := *sIdxesPtr - refs := *refsPtr - - if typ.Kind() == reflect.Struct { - if typ.String() == "time.Time" { - var ref interface{} - refs = append(refs, &ref) - sIdxes = append(sIdxes, []int{0}) - } else { - idxs := []int{} - outFor: - for idx := 0; idx < typ.NumField(); idx++ { - ctyp := typ.Field(idx) - - tag := ctyp.Tag.Get(defaultStructTagName) - for _, v := range strings.Split(tag, defaultStructTagDelim) { - if v == "-" { - continue outFor - } - } - - tp := ctyp.Type - if tp.Kind() == reflect.Ptr { - tp = tp.Elem() - } - - if tp.String() == "time.Time" { - var ref interface{} - refs = append(refs, &ref) - - } else if tp.Kind() != reflect.Struct { - var ref interface{} - refs = append(refs, &ref) - - } else { - // skip other type - continue - } - - idxs = append(idxs, idx) - } - sIdxes = append(sIdxes, idxs) - } - } else { - var ref interface{} - refs = append(refs, &ref) - sIdxes = append(sIdxes, []int{0}) - } - - *sIdxesPtr = sIdxes - *refsPtr = refs -} - -func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { +func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { nInds := *nIndsPtr cur := 0 - for i, idxs := range sIdxes { + for i := 0; i < len(sInds); i++ { sInd := sInds[i] eTyp := eTyps[i] @@ -258,32 +203,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect o.setFieldValue(ind, value) } cur++ - } else { - hasValue := false - for _, idx := range idxs { - tind := ind.Field(idx) - value := reflect.ValueOf(refs[cur]).Elem().Interface() - if value != nil { - hasValue = true - } - if tind.Kind() == reflect.Ptr { - if value == nil { - tindV := reflect.New(tind.Type()).Elem() - tind.Set(tindV) - } else { - tindV := reflect.New(tind.Type().Elem()) - o.setFieldValue(tindV.Elem(), value) - tind.Set(tindV) - } - } else { - o.setFieldValue(tind, value) - } - cur++ - } - if hasValue == false && isPtr { - val = reflect.New(val.Type()).Elem() - } } + } else { value := reflect.ValueOf(refs[cur]).Elem().Interface() if isPtr && value == nil { @@ -313,15 +234,12 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect } func (o *rawSet) QueryRow(containers ...interface{}) error { - if len(containers) == 0 { - panic(fmt.Errorf(" need at least one arg")) - } - refs := make([]interface{}, 0, len(containers)) - sIdxes := make([][]int, 0) sInds := make([]reflect.Value, 0) eTyps := make([]reflect.Type, 0) + structMode := false + var sMi *modelInfo for _, container := range containers { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -335,44 +253,120 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { if typ.Kind() == reflect.Ptr { typ = typ.Elem() } - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } sInds = append(sInds, ind) eTyps = append(eTyps, etyp) - o.loopInitRefs(typ, &refs, &sIdxes) + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFN(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } } query := o.query o.orm.alias.DbBaser.ReplaceMarks(&query) args := getFlatParams(nil, o.args, o.orm.alias.TZ) - row := o.orm.db.QueryRow(query, args...) - - if err := row.Scan(refs...); err == sql.ErrNoRows { - return ErrNoRows - } else if err != nil { + rows, err := o.orm.db.Query(query, args...) + if err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } return err } - nInds := make([]reflect.Value, len(sInds)) - o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true) - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) + if rows.Next() { + if structMode { + columns, err := rows.Columns() + if err != nil { + return err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return err + } + + ind := sInds[0] + + if ind.Kind() == reflect.Ptr { + if ind.IsNil() || !ind.IsValid() { + ind.Set(reflect.New(eTyps[0].Elem())) + } + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + + var attrs map[string]bool + var tags map[string]string + parseStructTag(fe.Tag.Get("orm"), &attrs, &tags) + var col string + if col = tags["column"]; len(col) == 0 { + col = snakeString(fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + } else { + if err := rows.Scan(refs...); err != nil { + return err + } + + nInds := make([]reflect.Value, len(sInds)) + o.loopSetRefs(refs, sInds, &nInds, eTyps, true) + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + + } else { + return ErrNoRows } return nil } func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { - refs := make([]interface{}, 0) - sIdxes := make([][]int, 0) + refs := make([]interface{}, 0, len(containers)) sInds := make([]reflect.Value, 0) eTyps := make([]reflect.Type, 0) + structMode := false + var sMi *modelInfo for _, container := range containers { val := reflect.ValueOf(container) sInd := reflect.Indirect(val) @@ -389,7 +383,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { sInds = append(sInds, sInd) eTyps = append(eTyps, etyp) - o.loopInitRefs(typ, &refs, &sIdxes) + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFN(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } } query := o.query @@ -403,21 +410,97 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { nInds := make([]reflect.Value, len(sInds)) + sInd := sInds[0] + var cnt int64 for rows.Next() { - if err := rows.Scan(refs...); err != nil { - return 0, err - } - o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0) + if structMode { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 && !sInd.IsNil() { + sInd.Set(reflect.New(sInd.Type()).Elem()) + } + + var ind reflect.Value + if eTyps[0].Kind() == reflect.Ptr { + ind = reflect.New(eTyps[0].Elem()) + } else { + ind = reflect.New(eTyps[0]) + } + + if ind.Kind() == reflect.Ptr { + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + + var attrs map[string]bool + var tags map[string]string + parseStructTag(fe.Tag.Get("orm"), &attrs, &tags) + var col string + if col = tags["column"]; len(col) == 0 { + col = snakeString(fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + if eTyps[0].Kind() == reflect.Ptr { + ind = ind.Addr() + } + + sInd = reflect.Append(sInd, ind) + + } else { + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) + } cnt++ } if cnt > 0 { - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) + + if structMode { + sInds[0].Set(sInd) + } else { + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } } } diff --git a/orm/orm_test.go b/orm/orm_test.go index d92e3fab..410aa484 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) { } } - type Tmp struct { - Skip0 string - Id int - Char *string - Skip1 int `orm:"-"` - Date time.Time - DateTime time.Time - } - - Boolean = false - Text = "" - Int64 = 0 - Uint = 0 - - tmp := new(Tmp) - - cols = []string{ - "int", "char", "date", "datetime", "boolean", "text", "int64", "uint", - } - query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) - values = []interface{}{ - tmp, &Boolean, &Text, &Int64, &Uint, - } - err = dORM.Raw(query, 1).QueryRow(values...) - throwFailNow(t, err) - - for _, col := range cols { - switch col { - case "id": - throwFail(t, AssertIs(tmp.Id, data_values[col])) - case "char": - c := tmp.Char - throwFail(t, AssertIs(*c, data_values[col])) - case "date": - v := tmp.Date.In(DefaultTimeLoc) - value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, test_Date)) - case "datetime": - v := tmp.DateTime.In(DefaultTimeLoc) - value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, test_DateTime)) - case "boolean": - throwFail(t, AssertIs(Boolean, data_values[col])) - case "text": - throwFail(t, AssertIs(Text, data_values[col])) - case "int64": - throwFail(t, AssertIs(Int64, data_values[col])) - case "uint": - throwFail(t, AssertIs(Uint, data_values[col])) - } - } - var ( uid int status *int @@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) { func TestQueryRows(t *testing.T) { Q := dDbBaser.TableQuote() - cols := []string{ - "id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", - "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", - } - var datas []*Data - var dids []int - sep := fmt.Sprintf("%s, %s", Q, Q) - query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q) - num, err := dORM.Raw(query).QueryRows(&datas, &dids) + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(datas), 1)) - throwFailNow(t, AssertIs(len(dids), 1)) - throwFailNow(t, AssertIs(dids[0], 1)) ind := reflect.Indirect(reflect.ValueOf(datas[0])) @@ -1427,90 +1366,42 @@ func TestQueryRows(t *testing.T) { throwFail(t, AssertIs(vu == value, true), value, vu) } - type Tmp struct { - Id int - Name string - Skiped0 string `orm:"-"` - Pid *int - Skiped1 Data - Skiped2 *Data + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range Data_Values { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) + value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) } - var ( - ids []int - userNames []string - profileIds1 []int - profileIds2 []*int - createds []time.Time - updateds []time.Time - tmps1 []*Tmp - tmps2 []Tmp - ) - cols = []string{ - "id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated", - } - query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds) + var ids []int + var usernames []string + num, err = dORM.Raw("SELECT id, user_name FROM user ORDER BY id asc").QueryRows(&ids, &usernames) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) - - var users []User - dORM.QueryTable("user").OrderBy("Id").All(&users) - - for i := 0; i < 3; i++ { - id := ids[i] - name := userNames[i] - pid1 := profileIds1[i] - pid2 := profileIds2[i] - created := createds[i] - updated := updateds[i] - - user := users[i] - throwFailNow(t, AssertIs(id, user.Id)) - throwFailNow(t, AssertIs(name, user.UserName)) - if user.Profile != nil { - throwFailNow(t, AssertIs(pid1, user.Profile.Id)) - throwFailNow(t, AssertIs(*pid2, user.Profile.Id)) - } else { - throwFailNow(t, AssertIs(pid1, 0)) - throwFailNow(t, AssertIs(pid2, nil)) - } - throwFailNow(t, AssertIs(created, user.Created, test_Date)) - throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime)) - - tmp := tmps1[i] - tmp1 := *tmp - throwFailNow(t, AssertIs(tmp1.Id, user.Id)) - throwFailNow(t, AssertIs(tmp1.Name, user.UserName)) - if user.Profile != nil { - pid := tmp1.Pid - throwFailNow(t, AssertIs(*pid, user.Profile.Id)) - } else { - throwFailNow(t, AssertIs(tmp1.Pid, nil)) - } - - tmp2 := tmps2[i] - throwFailNow(t, AssertIs(tmp2.Id, user.Id)) - throwFailNow(t, AssertIs(tmp2.Name, user.UserName)) - if user.Profile != nil { - pid := tmp2.Pid - throwFailNow(t, AssertIs(*pid, user.Profile.Id)) - } else { - throwFailNow(t, AssertIs(tmp2.Pid, nil)) - } - } - - type Sec struct { - Id int - Name string - } - - var tmp []*Sec - query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q) - num, err = dORM.Raw(query).QueryRows(&tmp) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(tmp[0], nil)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) } func TestRawValues(t *testing.T) {