diff --git a/orm/db.go b/orm/db.go index 4fa2b467..2958e345 100644 --- a/orm/db.go +++ b/orm/db.go @@ -49,10 +49,17 @@ type dbBase struct { ins dbBaser } -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { +var _ dbBaser = new(dbBase) + +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { _, pkValue, _ := getExistPk(mi, ind) - for _, column := range mi.fields.orders { - fi := mi.fields.columns[column] + for _, column := range cols { + var fi *fieldInfo + if fi, _ = mi.fields.GetByAny(column); fi != nil { + column = fi.column + } else { + panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) + } if fi.dbcol == false || fi.auto && skipAuto { continue } @@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, } func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - _, values, err := d.collectValues(mi, ind, true, true, tz) + _, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) if err != nil { return 0, err } @@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } } -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error { - pkColumn, pkValue, ok := getExistPk(mi, ind) - if ok == false { - return ErrMissPK +func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { + var whereCols []string + var args []interface{} + + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) + if err != nil { + return err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if ok == false { + return ErrMissPK + } + whereCols = append(whereCols, pkColumn) + args = append(args, pkValue) } Q := d.ins.TableQuote() @@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo sels := strings.Join(mi.fields.dbcols, sep) colsNum := len(mi.fields.dbcols) - query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q) + sep = fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q) refs := make([]interface{}, colsNum) for i, _ := range refs { @@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, pkValue) + row := q.QueryRow(query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows @@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names, values, err := d.collectValues(mi, ind, true, true, tz) + names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) if err != nil { return 0, err } @@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. } } -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { return 0, ErrMissPK } - setNames, setValues, err := d.collectValues(mi, ind, true, false, tz) + + // if specify cols length is zero, then commit all columns. + if len(cols) == 0 { + cols = mi.fields.dbcols + } + + setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) if err != nil { return 0, err } @@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) { +func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi Q := d.ins.TableQuote() + var tCols []string + if len(cols) > 0 { + hasRel := len(qs.related) > 0 || qs.relDepth > 0 + tCols = make([]string, 0, len(cols)) + var maps map[string]bool + if hasRel { + maps = make(map[string]bool) + } + for _, col := range cols { + if fi, ok := mi.fields.GetByAny(col); ok { + tCols = append(tCols, fi.column) + if hasRel { + maps[fi.column] = true + } + } else { + panic(fmt.Sprintf("wrong field/column name `%s`", col)) + } + } + if hasRel { + for _, fi := range mi.fields.fieldsDB { + if fi.fieldType&IsRelField > 0 { + if maps[fi.column] == false { + tCols = append(tCols, fi.column) + } + } + } + } + } else { + tCols = mi.fields.dbcols + } + + colsNum := len(tCols) + sep := fmt.Sprintf("%s, T0.%s", Q, Q) + sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) + tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi limit := tables.getLimitSql(mi, offset, rlimit) join := tables.getJoinSql() - colsNum := len(mi.fields.dbcols) - sep := fmt.Sprintf("%s, T0.%s", Q, Q) - cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q) for _, tbl := range tables.tables { if tbl.sel { colsNum += len(tbl.mi.fields.dbcols) sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) - cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) + sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) } } - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) d.ins.ReplaceMarks(&query) @@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi cacheM := make(map[string]*modelInfo) trefs := refs - d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz) - trefs = refs[len(mi.fields.dbcols):] + d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) + trefs = refs[len(tCols):] for _, tbl := range tables.tables { if tbl.sel { diff --git a/orm/orm.go b/orm/orm.go index c4a52631..51cdc495 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -53,9 +53,9 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { panic(fmt.Sprintf(" table: `%s` not found, maybe not RegisterModel", name)) } -func (o *orm) Read(md interface{}) error { +func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return err } @@ -80,9 +80,9 @@ func (o *orm) Insert(md interface{}) (int64, error) { return id, nil } -func (o *orm) Update(md interface{}) (int64, error) { +func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md) - num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ) + num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err } diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 4e046bbf..61adaff8 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -105,12 +105,12 @@ func (o *querySet) PrepareInsert() (Inserter, error) { return newInsertSet(o.orm, o.mi) } -func (o *querySet) All(container interface{}) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ) +func (o *querySet) All(container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } -func (o *querySet) One(container interface{}) error { - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ) +func (o *querySet) One(container interface{}, cols ...string) error { + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { return err } diff --git a/orm/orm_test.go b/orm/orm_test.go index 92952fbf..d734b9fb 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -20,91 +20,40 @@ var ( test_DateTime = format_DateTime + " -0700" ) -type T_Code int - -const ( - // = - T_Equal T_Code = iota - // < - T_Less - // > - T_Large - // elment in slice/array - // T_In - // key exists in map - // T_KeyExist - // index != -1 - // T_Contain - // index == 0 - // T_StartWith - // index == len(x) - 1 - // T_EndWith -) - -func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err error, ok bool) { +func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) { if len(args) == 0 { return fmt.Errorf("miss args"), false } b := args[0] arg := argAny(args) - switch o { - case T_Equal: - switch v := a.(type) { - case reflect.Kind: - ok = reflect.ValueOf(b).Kind() == v - case time.Time: - if v2, vo := b.(time.Time); vo { - if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) - a = v.Format(format) - b = v2.Format(format) - ok = a == b - } else { - err = fmt.Errorf("compare datetime miss format") - goto wrongArg - } - } - default: - ok = ToStr(a) == ToStr(b) - } - ok = is && ok || !is && !ok - if !ok { - if is { - err = fmt.Errorf("expected: a == `%v`, get `%v`", b, a) + + switch v := a.(type) { + case reflect.Kind: + ok = reflect.ValueOf(b).Kind() == v + case time.Time: + if v2, vo := b.(time.Time); vo { + if arg.Get(1) != nil { + format := ToStr(arg.Get(1)) + a = v.Format(format) + b = v2.Format(format) + ok = a == b } else { - err = fmt.Errorf("expected: a != `%v`, get `%v`", b, a) + err = fmt.Errorf("compare datetime miss format") + goto wrongArg } } - case T_Less, T_Large: - as := ToStr(a) - bs := ToStr(b) - f1, er := StrTo(as).Float64() - if er != nil { - err = fmt.Errorf("wrong type need numeric: `%v`", a) - goto wrongArg - } - f2, er := StrTo(bs).Float64() - if er != nil { - err = fmt.Errorf("wrong type need numeric: `%v`", b) - goto wrongArg - } - var opts []string - if o == T_Less { - opts = []string{"<", ">="} - ok = f1 < f2 + default: + ok = ToStr(a) == ToStr(b) + } + ok = is && ok || !is && !ok + if !ok { + if is { + err = fmt.Errorf("expected: a == `%v`, get `%v`", b, a) } else { - opts = []string{">", "<="} - ok = f1 > f2 - } - ok = is && ok || !is && !ok - if !ok { - if is { - err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[0], f1, f2) - } else { - err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[1], f1, f2) - } + err = fmt.Errorf("expected: a != `%v`, get `%v`", b, a) } } + wrongArg: if err != nil { return err, false @@ -113,15 +62,15 @@ wrongArg: return nil, true } -func AssertIs(a interface{}, o T_Code, args ...interface{}) error { - if err, ok := ValuesCompare(true, a, o, args...); ok == false { +func AssertIs(a interface{}, args ...interface{}) error { + if err, ok := ValuesCompare(true, a, args...); ok == false { return err } return nil } -func AssertNot(a interface{}, o T_Code, args ...interface{}) error { - if err, ok := ValuesCompare(false, a, o, args...); ok == false { +func AssertNot(a interface{}, args ...interface{}) error { + if err, ok := ValuesCompare(false, a, args...); ok == false { return err } return nil @@ -224,12 +173,12 @@ func TestModelSyntax(t *testing.T) { ind := reflect.ValueOf(user).Elem() fn := getFullName(ind.Type()) mi, ok := modelCache.getByFN(fn) - throwFail(t, AssertIs(ok, T_Equal, true)) + throwFail(t, AssertIs(ok, true)) mi, ok = modelCache.get("user") - throwFail(t, AssertIs(ok, T_Equal, true)) + throwFail(t, AssertIs(ok, true)) if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true)) + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) } } @@ -267,7 +216,7 @@ func TestDataTypes(t *testing.T) { id, err := dORM.Insert(&d) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 1)) + throwFail(t, AssertIs(id, 1)) d = Data{Id: 1} err = dORM.Read(&d) @@ -286,7 +235,7 @@ func TestDataTypes(t *testing.T) { vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) } - throwFail(t, AssertIs(vu == value, T_Equal, true), value, vu) + throwFail(t, AssertIs(vu == value, true), value, vu) } } @@ -301,7 +250,7 @@ func TestNullDataTypes(t *testing.T) { id, err := dORM.Insert(&d) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 1)) + throwFail(t, AssertIs(id, 1)) d = DataNull{Id: 1} err = dORM.Read(&d) @@ -321,7 +270,7 @@ func TestCRUD(t *testing.T) { profile.Money = 1234.12 id, err := dORM.Insert(profile) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 1)) + throwFail(t, AssertIs(id, 1)) user := NewUser() user.UserName = "slene" @@ -333,63 +282,77 @@ func TestCRUD(t *testing.T) { id, err = dORM.Insert(user) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 1)) + throwFail(t, AssertIs(id, 1)) u := &User{Id: user.Id} err = dORM.Read(u) throwFail(t, err) - throwFail(t, AssertIs(u.UserName, T_Equal, "slene")) - throwFail(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com")) - throwFail(t, AssertIs(u.Password, T_Equal, "pass")) - throwFail(t, AssertIs(u.Status, T_Equal, 3)) - throwFail(t, AssertIs(u.IsStaff, T_Equal, true)) - throwFail(t, AssertIs(u.IsActive, T_Equal, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), T_Equal, user.Created.In(DefaultTimeLoc), test_Date)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), T_Equal, user.Updated.In(DefaultTimeLoc), test_DateTime)) + throwFail(t, AssertIs(u.UserName, "slene")) + throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, "pass")) + throwFail(t, AssertIs(u.Status, 3)) + throwFail(t, AssertIs(u.IsStaff, true)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime)) user.UserName = "astaxie" user.Profile = profile num, err := dORM.Update(user) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) u = &User{Id: user.Id} err = dORM.Read(u) - throwFail(t, err) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "astaxie")) + throwFail(t, AssertIs(u.Profile.Id, profile.Id)) - if err == nil { - throwFail(t, AssertIs(u.UserName, T_Equal, "astaxie")) - throwFail(t, AssertIs(u.Profile.Id, T_Equal, profile.Id)) - } + u = &User{UserName: "astaxie", Password: "pass"} + err = dORM.Read(u, "UserName") + throwFailNow(t, err) + throwFailNow(t, AssertIs(id, 1)) + + u.UserName = "QQ" + u.Password = "111" + num, err = dORM.Update(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{Id: user.Id} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "QQ")) + throwFail(t, AssertIs(u.Password, "pass")) num, err = dORM.Delete(profile) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) u = &User{Id: user.Id} err = dORM.Read(u) throwFail(t, err) - throwFail(t, AssertIs(true, T_Equal, u.Profile == nil)) + throwFail(t, AssertIs(true, u.Profile == nil)) num, err = dORM.Delete(user) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) u = &User{Id: 100} err = dORM.Read(u) - throwFail(t, AssertIs(err, T_Equal, ErrNoRows)) + throwFail(t, AssertIs(err, ErrNoRows)) ub := UserBig{} ub.Name = "name" id, err = dORM.Insert(&ub) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 1)) + throwFail(t, AssertIs(id, 1)) ub = UserBig{Id: 1} err = dORM.Read(&ub) throwFail(t, err) - throwFail(t, AssertIs(ub.Name, T_Equal, "name")) + throwFail(t, AssertIs(ub.Name, "name")) } func TestInsertTestData(t *testing.T) { @@ -401,7 +364,7 @@ func TestInsertTestData(t *testing.T) { id, err := dORM.Insert(profile) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 2)) + throwFail(t, AssertIs(id, 2)) user := NewUser() user.UserName = "slene" @@ -416,7 +379,7 @@ func TestInsertTestData(t *testing.T) { id, err = dORM.Insert(user) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 2)) + throwFail(t, AssertIs(id, 2)) profile = NewProfile() profile.Age = 30 @@ -424,7 +387,7 @@ func TestInsertTestData(t *testing.T) { id, err = dORM.Insert(profile) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 3)) + throwFail(t, AssertIs(id, 3)) user = NewUser() user.UserName = "astaxie" @@ -439,7 +402,7 @@ func TestInsertTestData(t *testing.T) { id, err = dORM.Insert(user) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 3)) + throwFail(t, AssertIs(id, 3)) user = NewUser() user.UserName = "nobody" @@ -453,7 +416,7 @@ func TestInsertTestData(t *testing.T) { id, err = dORM.Insert(user) throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, 4)) + throwFail(t, AssertIs(id, 4)) tags := []*Tag{ &Tag{Name: "golang"}, @@ -484,20 +447,20 @@ The program—and web server—godoc processes Go source files to extract docume for _, tag := range tags { id, err := dORM.Insert(tag) throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + throwFail(t, AssertIs(id > 0, true)) } for _, post := range posts { id, err := dORM.Insert(post) throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + throwFail(t, AssertIs(id > 0, true)) // dORM.M2mAdd(post, "tags", post.Tags) } for _, comment := range comments { id, err := dORM.Insert(comment) throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + throwFail(t, AssertIs(id > 0, true)) } } @@ -508,34 +471,34 @@ func TestExpr(t *testing.T) { qs = dORM.QueryTable("user") num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("created", time.Now()).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) } func TestOperators(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.Filter("user_name", "slene").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name__exact", "slene").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name__iexact", "Slene").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name__contains", "e").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) var shouldNum int @@ -547,35 +510,35 @@ func TestOperators(t *testing.T) { num, err = qs.Filter("user_name__contains", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, shouldNum)) + throwFail(t, AssertIs(num, shouldNum)) num, err = qs.Filter("user_name__icontains", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Filter("user_name__icontains", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Filter("status__gt", 1).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Filter("status__gte", 1).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) num, err = qs.Filter("status__lt", 3).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Filter("status__lte", 3).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) num, err = qs.Filter("user_name__startswith", "s").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) if IsSqlite { shouldNum = 1 @@ -585,15 +548,15 @@ func TestOperators(t *testing.T) { num, err = qs.Filter("user_name__startswith", "S").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, shouldNum)) + throwFail(t, AssertIs(num, shouldNum)) num, err = qs.Filter("user_name__istartswith", "S").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name__endswith", "e").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) if IsSqlite { shouldNum = 2 @@ -603,28 +566,28 @@ func TestOperators(t *testing.T) { num, err = qs.Filter("user_name__endswith", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, shouldNum)) + throwFail(t, AssertIs(num, shouldNum)) num, err = qs.Filter("user_name__iendswith", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Filter("profile__isnull", true).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("status__in", 1, 2).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Filter("status__in", []int{1, 2}).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) n1, n2 := 1, 2 num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) } func TestAll(t *testing.T) { @@ -632,41 +595,56 @@ func TestAll(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.OrderBy("Id").All(&users) throwFail(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 3)) + throwFailNow(t, AssertIs(num, 3)) - throwFail(t, AssertIs(users[0].UserName, T_Equal, "slene")) - throwFail(t, AssertIs(users[1].UserName, T_Equal, "astaxie")) - throwFail(t, AssertIs(users[2].UserName, T_Equal, "nobody")) + throwFail(t, AssertIs(users[0].UserName, "slene")) + throwFail(t, AssertIs(users[1].UserName, "astaxie")) + throwFail(t, AssertIs(users[2].UserName, "nobody")) var users2 []User qs = dORM.QueryTable("user") num, err = qs.OrderBy("Id").All(&users2) throwFail(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 3)) + throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(users2[0].UserName, T_Equal, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, T_Equal, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, T_Equal, "nobody")) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(users2), 3)) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + throwFailNow(t, AssertIs(users2[0].Id, 0)) + throwFailNow(t, AssertIs(users2[1].Id, 0)) + throwFailNow(t, AssertIs(users2[2].Id, 0)) + throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "nothing").All(&users) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 0)) + throwFail(t, AssertIs(num, 0)) } func TestOne(t *testing.T) { var user User qs := dORM.QueryTable("user") err := qs.One(&user) - throwFail(t, AssertIs(err, T_Equal, ErrMultiRows)) + throwFail(t, AssertIs(err, ErrMultiRows)) user = User{} err = qs.OrderBy("Id").Limit(1).One(&user) throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, T_Equal, "slene")) + throwFail(t, AssertIs(user.UserName, "slene")) err = qs.Filter("user_name", "nothing").One(&user) - throwFail(t, AssertIs(err, T_Equal, ErrNoRows)) + throwFail(t, AssertIs(err, ErrNoRows)) } @@ -676,19 +654,19 @@ func TestValues(t *testing.T) { num, err := qs.Values(&maps) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], T_Equal, "slene")) - throwFail(t, AssertIs(maps[2]["Profile"], T_Equal, nil)) + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) } num, err = qs.Values(&maps, "UserName", "Profile__Age") throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], T_Equal, "slene")) - throwFail(t, AssertIs(maps[0]["Profile__Age"], T_Equal, 28)) - throwFail(t, AssertIs(maps[2]["Profile__Age"], T_Equal, nil)) + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) + throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) } } @@ -698,19 +676,19 @@ func TestValuesList(t *testing.T) { num, err := qs.ValuesList(&list) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0][1], T_Equal, "slene")) - throwFail(t, AssertIs(list[2][9], T_Equal, nil)) + throwFail(t, AssertIs(list[0][1], "slene")) + throwFail(t, AssertIs(list[2][9], nil)) } num, err = qs.ValuesList(&list, "UserName", "Profile__Age") throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0][0], T_Equal, "slene")) - throwFail(t, AssertIs(list[0][1], T_Equal, 28)) - throwFail(t, AssertIs(list[2][1], T_Equal, nil)) + throwFail(t, AssertIs(list[0][0], "slene")) + throwFail(t, AssertIs(list[0][1], 28)) + throwFail(t, AssertIs(list[2][1], nil)) } } @@ -720,11 +698,11 @@ func TestValuesFlat(t *testing.T) { num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0], T_Equal, "slene")) - throwFail(t, AssertIs(list[1], T_Equal, "astaxie")) - throwFail(t, AssertIs(list[2], T_Equal, "nobody")) + throwFail(t, AssertIs(list[0], "slene")) + throwFail(t, AssertIs(list[1], "astaxie")) + throwFail(t, AssertIs(list[2], "nobody")) } } @@ -732,41 +710,41 @@ func TestRelatedSel(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.Filter("profile__age", 28).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("profile__age__gt", 28).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) var user User err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) - throwFail(t, AssertNot(user.Profile, T_Equal, nil)) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28)) + throwFail(t, AssertIs(user.Profile.Age, 28)) } err = qs.Filter("user_name", "slene").RelatedSel().One(&user) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) - throwFail(t, AssertNot(user.Profile, T_Equal, nil)) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28)) + throwFail(t, AssertIs(user.Profile.Age, 28)) } err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) - throwFail(t, AssertIs(num, T_Equal, 1)) - throwFail(t, AssertIs(user.Profile, T_Equal, nil)) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(user.Profile, nil)) qs = dORM.QueryTable("user_profile") num, err = qs.Filter("user__username", "slene").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) } func TestSetCond(t *testing.T) { @@ -776,12 +754,12 @@ func TestSetCond(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.SetCond(cond1).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) num, err = qs.SetCond(cond2).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) } func TestLimit(t *testing.T) { @@ -789,19 +767,19 @@ func TestLimit(t *testing.T) { qs := dORM.QueryTable("post") num, err := qs.Limit(1).All(&posts) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Limit(-1).All(&posts) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 4)) + throwFail(t, AssertIs(num, 4)) num, err = qs.Limit(-1, 2).All(&posts) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) num, err = qs.Limit(0, 2).All(&posts) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) } func TestOffset(t *testing.T) { @@ -809,26 +787,26 @@ func TestOffset(t *testing.T) { qs := dORM.QueryTable("post") num, err := qs.Limit(1).Offset(2).All(&posts) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.Offset(2).All(&posts) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) } func TestOrderBy(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) } func TestPrepareInsert(t *testing.T) { @@ -840,21 +818,21 @@ func TestPrepareInsert(t *testing.T) { user.UserName = "testing1" num, err := i.Insert(&user) throwFail(t, err) - throwFail(t, AssertIs(num, T_Large, 0)) + throwFail(t, AssertIs(num > 0, true)) user.UserName = "testing2" num, err = i.Insert(&user) throwFail(t, err) - throwFail(t, AssertIs(num, T_Large, 0)) + throwFail(t, AssertIs(num > 0, true)) num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 2)) + throwFail(t, AssertIs(num, 2)) err = i.Close() throwFail(t, err) err = i.Close() - throwFail(t, AssertIs(err, T_Equal, ErrStmtClosed)) + throwFail(t, AssertIs(err, ErrStmtClosed)) } func TestRawExec(t *testing.T) { @@ -864,12 +842,12 @@ func TestRawExec(t *testing.T) { res, err := dORM.Raw(query, "testing", "slene").Exec() throwFail(t, err) num, err := res.RowsAffected() - throwFail(t, AssertIs(num, T_Equal, 1), err) + throwFail(t, AssertIs(num, 1), err) res, err = dORM.Raw(query, "slene", "testing").Exec() throwFail(t, err) num, err = res.RowsAffected() - throwFail(t, AssertIs(num, T_Equal, 1), err) + throwFail(t, AssertIs(num, 1), err) } func TestRawQueryRow(t *testing.T) { @@ -922,17 +900,17 @@ func TestRawQueryRow(t *testing.T) { v := reflect.ValueOf(vu).Elem().Interface() switch col { case "id": - throwFail(t, AssertIs(id, T_Equal, 1)) + throwFail(t, AssertIs(id, 1)) case "date": v = v.(time.Time).In(DefaultTimeLoc) value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, T_Equal, value, test_Date)) + throwFail(t, AssertIs(v, value, test_Date)) case "datetime": v = v.(time.Time).In(DefaultTimeLoc) value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, T_Equal, value, test_DateTime)) + throwFail(t, AssertIs(v, value, test_DateTime)) default: - throwFail(t, AssertIs(v, T_Equal, data_values[col])) + throwFail(t, AssertIs(v, data_values[col])) } } @@ -965,26 +943,26 @@ func TestRawQueryRow(t *testing.T) { for _, col := range cols { switch col { case "id": - throwFail(t, AssertIs(tmp.Id, T_Equal, data_values[col])) + throwFail(t, AssertIs(tmp.Id, data_values[col])) case "char": c := tmp.Char - throwFail(t, AssertIs(*c, T_Equal, data_values[col])) + throwFail(t, AssertIs(*c, data_values[col])) case "date": v := tmp.Date.In(DefaultTimeLoc) value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, T_Equal, value, test_Date)) + throwFail(t, AssertIs(v, value, test_Date)) case "datetime": v := tmp.DateTime.In(DefaultTimeLoc) value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, T_Equal, value, test_DateTime)) + throwFail(t, AssertIs(v, value, test_DateTime)) case "boolean": - throwFail(t, AssertIs(Boolean, T_Equal, data_values[col])) + throwFail(t, AssertIs(Boolean, data_values[col])) case "text": - throwFail(t, AssertIs(Text, T_Equal, data_values[col])) + throwFail(t, AssertIs(Text, data_values[col])) case "int64": - throwFail(t, AssertIs(Int64, T_Equal, data_values[col])) + throwFail(t, AssertIs(Int64, data_values[col])) case "uint": - throwFail(t, AssertIs(Uint, T_Equal, data_values[col])) + throwFail(t, AssertIs(Uint, data_values[col])) } } @@ -1000,9 +978,9 @@ func TestRawQueryRow(t *testing.T) { query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) throwFail(t, err) - throwFail(t, AssertIs(uid, T_Equal, 4)) - throwFail(t, AssertIs(*status, T_Equal, 3)) - throwFail(t, AssertIs(pid, T_Equal, nil)) + throwFail(t, AssertIs(uid, 4)) + throwFail(t, AssertIs(*status, 3)) + throwFail(t, AssertIs(pid, nil)) } func TestQueryRows(t *testing.T) { @@ -1020,10 +998,10 @@ func TestQueryRows(t *testing.T) { query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q) num, err := dORM.Raw(query).QueryRows(&datas, &dids) throwFailNow(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 1)) - throwFailNow(t, AssertIs(len(datas), T_Equal, 1)) - throwFailNow(t, AssertIs(len(dids), T_Equal, 1)) - throwFailNow(t, AssertIs(dids[0], T_Equal, 1)) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas), 1)) + throwFailNow(t, AssertIs(len(dids), 1)) + throwFailNow(t, AssertIs(dids[0], 1)) ind := reflect.Indirect(reflect.ValueOf(datas[0])) @@ -1038,7 +1016,7 @@ func TestQueryRows(t *testing.T) { vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) } - throwFail(t, AssertIs(vu == value, T_Equal, true), value, vu) + throwFail(t, AssertIs(vu == value, true), value, vu) } type Tmp struct { @@ -1066,7 +1044,7 @@ func TestQueryRows(t *testing.T) { query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q) num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds) throwFailNow(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 3)) + throwFailNow(t, AssertIs(num, 3)) var users []User dORM.QueryTable("user").OrderBy("Id").All(&users) @@ -1080,37 +1058,37 @@ func TestQueryRows(t *testing.T) { updated := updateds[i] user := users[i] - throwFailNow(t, AssertIs(id, T_Equal, user.Id)) - throwFailNow(t, AssertIs(name, T_Equal, user.UserName)) + throwFailNow(t, AssertIs(id, user.Id)) + throwFailNow(t, AssertIs(name, user.UserName)) if user.Profile != nil { - throwFailNow(t, AssertIs(pid1, T_Equal, user.Profile.Id)) - throwFailNow(t, AssertIs(*pid2, T_Equal, user.Profile.Id)) + throwFailNow(t, AssertIs(pid1, user.Profile.Id)) + throwFailNow(t, AssertIs(*pid2, user.Profile.Id)) } else { - throwFailNow(t, AssertIs(pid1, T_Equal, 0)) - throwFailNow(t, AssertIs(pid2, T_Equal, nil)) + throwFailNow(t, AssertIs(pid1, 0)) + throwFailNow(t, AssertIs(pid2, nil)) } - throwFailNow(t, AssertIs(created, T_Equal, user.Created, test_Date)) - throwFailNow(t, AssertIs(updated, T_Equal, user.Updated, test_DateTime)) + throwFailNow(t, AssertIs(created, user.Created, test_Date)) + throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime)) tmp := tmps1[i] tmp1 := *tmp - throwFailNow(t, AssertIs(tmp1.Id, T_Equal, user.Id)) - throwFailNow(t, AssertIs(tmp1.Name, T_Equal, user.UserName)) + throwFailNow(t, AssertIs(tmp1.Id, user.Id)) + throwFailNow(t, AssertIs(tmp1.Name, user.UserName)) if user.Profile != nil { pid := tmp1.Pid - throwFailNow(t, AssertIs(*pid, T_Equal, user.Profile.Id)) + throwFailNow(t, AssertIs(*pid, user.Profile.Id)) } else { - throwFailNow(t, AssertIs(tmp1.Pid, T_Equal, nil)) + throwFailNow(t, AssertIs(tmp1.Pid, nil)) } tmp2 := tmps2[i] - throwFailNow(t, AssertIs(tmp2.Id, T_Equal, user.Id)) - throwFailNow(t, AssertIs(tmp2.Name, T_Equal, user.UserName)) + throwFailNow(t, AssertIs(tmp2.Id, user.Id)) + throwFailNow(t, AssertIs(tmp2.Name, user.UserName)) if user.Profile != nil { pid := tmp2.Pid - throwFailNow(t, AssertIs(*pid, T_Equal, user.Profile.Id)) + throwFailNow(t, AssertIs(*pid, user.Profile.Id)) } else { - throwFailNow(t, AssertIs(tmp2.Pid, T_Equal, nil)) + throwFailNow(t, AssertIs(tmp2.Pid, nil)) } } @@ -1123,8 +1101,8 @@ func TestQueryRows(t *testing.T) { query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q) num, err = dORM.Raw(query).QueryRows(&tmp) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) - throwFail(t, AssertIs(tmp[0], T_Equal, nil)) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(tmp[0], nil)) } func TestRawValues(t *testing.T) { @@ -1134,28 +1112,28 @@ func TestRawValues(t *testing.T) { query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sstatus%s = ?", Q, Q, Q, Q, Q, Q) num, err := dORM.Raw(query, 1).Values(&maps) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) if num == 1 { - throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene")) + throwFail(t, AssertIs(maps[0]["user_name"], "slene")) } var lists []ParamsList num, err = dORM.Raw(query, 1).ValuesList(&lists) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) if num == 1 { - throwFail(t, AssertIs(lists[0][0], T_Equal, "slene")) + throwFail(t, AssertIs(lists[0][0], "slene")) } query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) var list ParamsList num, err = dORM.Raw(query).ValuesFlat(&list) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0], T_Equal, "2")) - throwFail(t, AssertIs(list[1], T_Equal, "3")) - throwFail(t, AssertIs(list[2], T_Equal, nil)) + throwFail(t, AssertIs(list[0], "2")) + throwFail(t, AssertIs(list[1], "3")) + throwFail(t, AssertIs(list[2], nil)) } } @@ -1171,21 +1149,21 @@ func TestRawPrepare(t *testing.T) { tid, err := r.LastInsertId() throwFail(t, err) - throwFail(t, AssertIs(tid, T_Large, 0)) + throwFail(t, AssertIs(tid > 0, true)) r, err = pre.Exec("name2") throwFail(t, err) id, err := r.LastInsertId() throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, tid+1)) + throwFail(t, AssertIs(id, tid+1)) r, err = pre.Exec("name3") throwFail(t, err) id, err = r.LastInsertId() throwFail(t, err) - throwFail(t, AssertIs(id, T_Equal, tid+2)) + throwFail(t, AssertIs(id, tid+2)) err = pre.Close() throwFail(t, err) @@ -1195,7 +1173,7 @@ func TestRawPrepare(t *testing.T) { num, err := res.RowsAffected() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) } case IsPostgres: @@ -1221,7 +1199,7 @@ func TestRawPrepare(t *testing.T) { if err == nil { num, err := res.RowsAffected() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + throwFail(t, AssertIs(num, 3)) } } } @@ -1233,26 +1211,26 @@ func TestUpdate(t *testing.T) { "is_staff": true, }) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) // with join num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ "is_staff": false, }) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) } func TestDelete(t *testing.T) { qs := dORM.QueryTable("user_profile") num, err := qs.Filter("user__user_name", "slene").Delete() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) } func TestTransaction(t *testing.T) { @@ -1268,11 +1246,11 @@ func TestTransaction(t *testing.T) { tag.Name = names[0] id, err := o.Insert(&tag) throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + throwFail(t, AssertIs(id > 0, true)) num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) switch { case IsMysql || IsSqlite: @@ -1281,7 +1259,7 @@ func TestTransaction(t *testing.T) { if err == nil { id, err = res.LastInsertId() throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + throwFail(t, AssertIs(id > 0, true)) } } @@ -1290,7 +1268,7 @@ func TestTransaction(t *testing.T) { num, err = o.QueryTable("tag").Filter("name__in", names).Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 0)) + throwFail(t, AssertIs(num, 0)) err = o.Begin() throwFail(t, err) @@ -1298,13 +1276,13 @@ func TestTransaction(t *testing.T) { tag.Name = "commit" id, err = o.Insert(&tag) throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + throwFail(t, AssertIs(id > 0, true)) o.Commit() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + throwFail(t, AssertIs(num, 1)) } diff --git a/orm/types.go b/orm/types.go index ce25d037..6a503511 100644 --- a/orm/types.go +++ b/orm/types.go @@ -20,9 +20,9 @@ type Fielder interface { } type Ormer interface { - Read(interface{}) error + Read(interface{}, ...string) error Insert(interface{}) (int64, error) - Update(interface{}) (int64, error) + Update(interface{}, ...string) (int64, error) Delete(interface{}) (int64, error) M2mAdd(interface{}, string, ...interface{}) (int64, error) M2mDel(interface{}, string, ...interface{}) (int64, error) @@ -53,8 +53,8 @@ type QuerySeter interface { Update(Params) (int64, error) Delete() (int64, error) PrepareInsert() (Inserter, error) - All(interface{}) (int64, error) - One(interface{}) error + All(interface{}, ...string) (int64, error) + One(interface{}, ...string) error Values(*[]Params, ...string) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error) @@ -111,12 +111,12 @@ type txEnder interface { } type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error + Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error) + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) SupportUpdateJoin() bool UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) diff --git a/orm/utils.go b/orm/utils.go index 8bc093d1..e73d50f6 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) { return strconv.ParseFloat(f.String(), 64) } +func (f StrTo) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + func (f StrTo) Int8() (int8, error) { v, err := strconv.ParseInt(f.String(), 10, 8) return int8(v), err @@ -58,6 +63,11 @@ func (f StrTo) Int64() (int64, error) { return int64(v), err } +func (f StrTo) Uint() (uint, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint(v), err +} + func (f StrTo) Uint8() (uint8, error) { v, err := strconv.ParseUint(f.String(), 10, 8) return uint8(v), err