From 8563000235ca7e8a96ef9ea7bf2faa2da481db2f Mon Sep 17 00:00:00 2001 From: slene Date: Thu, 8 Aug 2013 22:34:18 +0800 Subject: [PATCH] orm operator args now support multi types eg: []int []*int *int, Model *Model --- orm/db.go | 81 ++++++++++++++++++++++++++++++++++++------- orm/docs/zh/Object.md | 2 +- orm/models.go | 16 ++++++--- orm/orm_test.go | 53 ++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 18 deletions(-) diff --git a/orm/db.go b/orm/db.go index 8e23e386..9f1493ed 100644 --- a/orm/db.go +++ b/orm/db.go @@ -949,21 +949,76 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition return } -func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { - params := make([]interface{}, len(args)) - copy(params, args) - sql := "" - for i, arg := range args { - if md, ok := arg.(Modeler); ok { - ind := reflect.Indirect(reflect.ValueOf(md)) - if _, vu, exist := d.existPk(mi, ind); exist { - arg = vu - } else { - panic(fmt.Sprintf("`%s` need a valid args value", operator)) - } +func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) { + for _, arg := range args { + val := reflect.ValueOf(arg) + + if arg == nil { + params = append(params, arg) + continue } - params[i] = arg + + kind := val.Kind() + + switch kind { + case reflect.Slice, reflect.Array: + var args []interface{} + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + + var vu interface{} + if v.CanInterface() { + vu = v.Interface() + } + + if vu == nil { + continue + } + + args = append(args, vu) + } + + if len(args) > 0 { + p := d.getOperatorParams(operator, args) + params = append(params, p...) + } + + case reflect.Ptr, reflect.Struct: + ind := reflect.Indirect(val) + + if ind.Kind() == reflect.Struct { + typ := ind.Type() + fullName := typ.PkgPath() + "." + typ.Name() + var value interface{} + if mmi, ok := modelCache.get(fullName); ok { + if _, vu, exist := d.existPk(mmi, ind); exist { + value = vu + } + } + arg = value + + if arg == nil { + panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%v`", operator, val.Type())) + } + } else { + arg = ind.Interface() + } + + params = append(params, arg) + + default: + params = append(params, arg) + } + } + + return +} + +func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { + sql := "" + params := d.getOperatorParams(operator, args) + if operator == "in" { marks := make([]string, len(params)) for i, _ := range marks { diff --git a/orm/docs/zh/Object.md b/orm/docs/zh/Object.md index 3cb39d29..0d17c71b 100644 --- a/orm/docs/zh/Object.md +++ b/orm/docs/zh/Object.md @@ -18,7 +18,7 @@ fmt.Println(o.Delete(user)) o := orm.NewOrm() user := User{Id: 1} -o.Read(&user) +err = o.Read(&user) if err == sql.ErrNoRows { fmt.Println("查询不到") diff --git a/orm/models.go b/orm/models.go index f0ed936a..d2fa9a99 100644 --- a/orm/models.go +++ b/orm/models.go @@ -16,7 +16,10 @@ const ( var ( errLog *log.Logger - modelCache = &_modelCache{cache: make(map[string]*modelInfo)} + modelCache = &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFN: make(map[string]*modelInfo), + } supportTag = map[string]int{ "null": 1, "blank": 1, @@ -47,9 +50,10 @@ func init() { type _modelCache struct { sync.RWMutex - orders []string - cache map[string]*modelInfo - done bool + orders []string + cache map[string]*modelInfo + cacheByFN map[string]*modelInfo + done bool } func (mc *_modelCache) all() map[string]*modelInfo { @@ -70,12 +74,16 @@ func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { mi, ok = mc.cache[table] + if ok == false { + mi, ok = mc.cacheByFN[table] + } return } func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { mii := mc.cache[table] mc.cache[table] = mi + mc.cacheByFN[mi.fullName] = mi if mii == nil { mc.orders = append(mc.orders, table) } diff --git a/orm/orm_test.go b/orm/orm_test.go index 8c3768f6..9e29ad3a 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -410,6 +410,15 @@ func TestOperators(t *testing.T) { num, err = qs.Filter("status__in", 1, 2).Count() throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 2)) + + num, err = qs.Filter("status__in", []int{1, 2}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 2)) + + n1, n2 := 1, 2 + num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 2)) } func TestAll(t *testing.T) { @@ -684,5 +693,49 @@ func TestDelete(t *testing.T) { } func TestTransaction(t *testing.T) { + o := NewOrm() + err := o.Begin() + throwFail(t, err) + + var names = []string{"1", "2", "3"} + + var user User + user.UserName = names[0] + id, err := o.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) + + num, err := o.QueryTable("user").Filter("user_name", "slene").Update(Params{"user_name": names[1]}) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Large, 0)) + + switch o.Driver().Type() { + case DR_MySQL: + id, err := o.Raw("INSERT INTO user (user_name) VALUES (?)", names[2]).Exec() + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) + } + + err = o.Rollback() + throwFail(t, err) + + num, err = o.QueryTable("user").Filter("user_name__in", &user).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 0)) + + err = o.Begin() + throwFail(t, err) + + user.UserName = "commit" + id, err = o.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) + + o.Commit() + throwFail(t, err) + + num, err = o.QueryTable("user").Filter("user_name", "commit").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) }