From bf3830b6f0bef45eb20bef9b656af6bce62f4156 Mon Sep 17 00:00:00 2001 From: slene Date: Wed, 9 Oct 2013 20:28:54 +0800 Subject: [PATCH] orm add atomic set value --- orm/db.go | 31 ++++++++++++++++++++++++++----- orm/models_test.go | 1 + orm/orm_queryset.go | 30 ++++++++++++++++++++++++++++++ orm/orm_test.go | 29 +++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 5 deletions(-) diff --git a/orm/db.go b/orm/db.go index f639a27b..5d2ad50e 100644 --- a/orm/db.go +++ b/orm/db.go @@ -382,17 +382,38 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con join := tables.getJoinSql() - var query string + var query, T, cols string Q := d.ins.TableQuote() if d.ins.SupportUpdateJoin() { - cols := strings.Join(columns, fmt.Sprintf("%s = ?, T0.%s", Q, Q)) - query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET T0.%s%s%s = ? %s", Q, mi.table, Q, join, Q, cols, Q, where) + T = "T0." + } + + for i, v := range columns { + col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) + if c, ok := values[i].(colValue); ok { + switch c.opt { + case Col_Add: + cols += col + " = " + col + " + ? " + case Col_Minus: + cols += col + " = " + col + " - ? " + case Col_Multiply: + cols += col + " = " + col + " * ? " + case Col_Except: + cols += col + " = " + col + " / ? " + } + values[i] = c.value + } else { + cols += col + " = ? " + } + } + + if d.ins.SupportUpdateJoin() { + query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, cols, where) } else { - cols := strings.Join(columns, fmt.Sprintf("%s = ?, %s", Q, Q)) supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) - query = fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s IN ( %s )", Q, mi.table, Q, Q, cols, Q, Q, mi.fields.pk.column, Q, supQuery) + query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, cols, Q, mi.fields.pk.column, Q, supQuery) } d.ins.ReplaceMarks(&query) diff --git a/orm/models_test.go b/orm/models_test.go index 1e1420a9..1891393a 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -77,6 +77,7 @@ type User struct { Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` Posts []*Post `orm:"reverse(many)" json:"-"` ShouldSkip string `orm:"-"` + Nums int } func (u *User) TableIndex() [][]string { diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 41c63d8c..c1788f5b 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -4,6 +4,36 @@ import ( "fmt" ) +type colValue struct { + value int64 + opt operator +} + +type operator int + +const ( + Col_Add operator = iota + Col_Minus + Col_Multiply + Col_Except +) + +func ColValue(opt operator, value interface{}) interface{} { + switch opt { + case Col_Add, Col_Minus, Col_Multiply, Col_Except: + default: + panic(fmt.Errorf("orm.ColValue wrong operator")) + } + v, err := StrTo(ToStr(value)).Int64() + if err != nil { + panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) + } + var val colValue + val.value = v + val.opt = opt + return val +} + type querySet struct { mi *modelInfo cond *Condition diff --git a/orm/orm_test.go b/orm/orm_test.go index 8dc430af..37baff9d 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1230,6 +1230,35 @@ func TestUpdate(t *testing.T) { }) throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(Col_Add, 100), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(Col_Minus, 50), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(Col_Multiply, 3), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(Col_Except, 5), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + user := User{UserName: "slene"} + err = dORM.Read(&user, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(user.Nums, 30)) } func TestDelete(t *testing.T) {