1
0
mirror of https://github.com/astaxie/beego.git synced 2024-12-23 08:00:50 +00:00

orm insert or update

This commit is contained in:
“fudali113” 2016-07-20 14:37:05 +08:00
parent d11823548b
commit 4b8ecced83
4 changed files with 198 additions and 0 deletions

106
orm/db.go
View File

@ -488,6 +488,112 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
return id, err
}
//insert or update a row
//If your primary key or unique column conflict will update
//if no will insert
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, dn string, args ...string) (int64, error) {
iouStr := ""
mysql := "mysql"
postgres := "postgres"
argsMap := map[string]string{}
args0 := ""
if dn == mysql {
iouStr = "ON DUPLICATE KEY UPDATE"
} else if dn == postgres {
if len(args) == 0 || (len(strings.Split(args0, "=")) != 1) {
return 0, fmt.Errorf("`%s` use insert or update must have a conflict column arg in first", dn)
} else {
args0 = args[0]
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
}
} else {
return 0, fmt.Errorf("`%s` nonsupport insert or update in beego", dn)
}
//Get on the key-value pairs
for _, v := range args {
kv := strings.Split(v, "=")
if len(kv) == 2 {
argsMap[kv[0]] = kv[1]
}
}
isMulti := false
names := make([]string, 0, len(mi.fields.dbcols)-1)
Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return 0, err
}
marks := make([]string, len(names))
updateValues := make([]interface{}, 0)
updates := make([]string, len(names))
var conflitValue interface{}
for i, v := range names {
marks[i] = "?"
valueStr := argsMap[v]
if v == args0 {
conflitValue = values[i]
}
if valueStr != "" {
switch dn {
case mysql:
updates[i] = v + "=" + valueStr
break
case postgres:
if conflitValue != nil {
//postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args[0])
updateValues = append(updateValues, conflitValue)
} else {
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args[0], v)
}
break
}
} else {
updates[i] = v + "=?"
updateValues = append(updateValues, values[i])
}
}
values = append(values, updateValues...)
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
qupdates := strings.Join(updates, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
//conflitValue maybe is a int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
}
return 0, err
}
row := q.QueryRow(query, values...)
var id int64
err = row.Scan(&id)
if err.Error() == `pq: syntax error at or near "ON"` {
err = fmt.Errorf("postgres version must 9.5 or higher")
}
return id, err
}
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)

View File

@ -209,6 +209,19 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
return cnt, nil
}
//insert or update data to database
func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias.TZ, o.alias.DriverName, colConflitAndArgs...)
if err != nil {
return id, err
}
o.setPk(mi, ind, id)
return id, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {

View File

@ -2174,3 +2174,76 @@ func TestIgnoreCaseTag(t *testing.T) {
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
}
func TestInsertOrUpdate(t *testing.T) {
user := User{UserName: "unique_username133", Status: 1, Password: "o"}
user1 := User{UserName: "unique_username133", Status: 2, Password: "o"}
user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"}
dORM.Insert(&user)
fmt.Println(dORM.Driver().Name())
if dORM.Driver().Name() == "sqlite3" {
fmt.Println("sqlite3 is nonsupport")
return
}
//test1 普通操作
_, err := dORM.InsertOrUpdate(&user1, "UserName")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
fmt.Println(err)
} else {
throwFailNow(t, err)
}
test := User{UserName: "unique_username133"}
time.Sleep(time.Second * 1)
dORM.Read(&test, "UserName")
throwFailNow(t, AssertIs(user1.Status, test.Status))
//test2 普通操作
_, err = dORM.InsertOrUpdate(&user2, "UserName")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
fmt.Println(err)
} else {
throwFailNow(t, err)
time.Sleep(time.Second * 1)
dORM.Read(&test, "UserName")
throwFailNow(t, AssertIs(user2.Status, test.Status))
throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password)))
}
//test3 数字 + 操作
_, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status+1")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
fmt.Println(err)
} else {
throwFailNow(t, err)
time.Sleep(time.Second * 1)
dORM.Read(&test, "UserName")
throwFailNow(t, AssertIs(user2.Status+1, test.Status))
}
//test4 数字 - 操作
_, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status-1")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
fmt.Println(err)
} else {
throwFailNow(t, err)
time.Sleep(time.Second * 1)
dORM.Read(&test, "UserName")
throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status))
}
//test5 数字 * 操作
_, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status*3")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
fmt.Println(err)
} else {
throwFailNow(t, err)
time.Sleep(time.Second * 1)
dORM.Read(&test, "UserName")
throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status))
}
//test6 数字 / 操作
_, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status/3")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
fmt.Println(err)
} else {
throwFailNow(t, err)
time.Sleep(time.Second * 1)
dORM.Read(&test, "UserName")
throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status))
}
}

View File

@ -53,6 +53,11 @@ type Ormer interface {
// id, err = Ormer.Insert(user)
// user must a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
//mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
//if colu type is integer : can use(+-*/), string : convert(colu,"value")
//postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
//if colu type is integer : can use(+-*/), string : colu || "value"
InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
// insert some models to database
InsertMulti(bulk int, mds interface{}) (int64, error)
// update model to database.
@ -391,6 +396,7 @@ type txEnder interface {
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *time.Location, string, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)