diff --git a/orm/db.go b/orm/db.go index 9964e263..c4b0e046 100644 --- a/orm/db.go +++ b/orm/db.go @@ -488,6 +488,112 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s return id, err } +//insert or update a row +//If your primary key or unique column conflict will update +//if no will insert +func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, dn string, args ...string) (int64, error) { + iouStr := "" + mysql := "mysql" + postgres := "postgres" + argsMap := map[string]string{} + args0 := "" + if dn == mysql { + iouStr = "ON DUPLICATE KEY UPDATE" + } else if dn == postgres { + if len(args) == 0 || (len(strings.Split(args0, "=")) != 1) { + return 0, fmt.Errorf("`%s` use insert or update must have a conflict column arg in first", dn) + } else { + args0 = args[0] + iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) + } + } else { + return 0, fmt.Errorf("`%s` nonsupport insert or update in beego", dn) + } + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[kv[0]] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + var conflitValue interface{} + for i, v := range names { + marks[i] = "?" + valueStr := argsMap[v] + if v == args0 { + conflitValue = values[i] + } + if valueStr != "" { + switch dn { + case mysql: + updates[i] = v + "=" + valueStr + break + case postgres: + if conflitValue != nil { + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args[0]) + updateValues = append(updateValues, conflitValue) + } else { + return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args[0], v) + } + break + } + } else { + updates[i] = v + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + if err.Error() == `pq: syntax error at or near "ON"` { + err = fmt.Errorf("postgres version must 9.5 or higher") + } + return id, err +} + // execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) diff --git a/orm/orm.go b/orm/orm.go index 5e43ae59..fe189037 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -209,6 +209,19 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { return cnt, nil } +//insert or update data to database +func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias.TZ, o.alias.DriverName, colConflitAndArgs...) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + // update model to database. // cols set the columns those want to update. func (o *orm) Update(md interface{}, cols ...string) (int64, error) { diff --git a/orm/orm_test.go b/orm/orm_test.go index b5973448..aac9fef8 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -2174,3 +2174,76 @@ func TestIgnoreCaseTag(t *testing.T) { throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) } +func TestInsertOrUpdate(t *testing.T) { + user := User{UserName: "unique_username133", Status: 1, Password: "o"} + user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} + user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + dORM.Insert(&user) + fmt.Println(dORM.Driver().Name()) + if dORM.Driver().Name() == "sqlite3" { + fmt.Println("sqlite3 is nonsupport") + return + } + //test1 普通操作 + _, err := dORM.InsertOrUpdate(&user1, "UserName") + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + fmt.Println(err) + } else { + throwFailNow(t, err) + } + test := User{UserName: "unique_username133"} + time.Sleep(time.Second * 1) + dORM.Read(&test, "UserName") + throwFailNow(t, AssertIs(user1.Status, test.Status)) + //test2 普通操作 + _, err = dORM.InsertOrUpdate(&user2, "UserName") + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + fmt.Println(err) + } else { + throwFailNow(t, err) + time.Sleep(time.Second * 1) + dORM.Read(&test, "UserName") + throwFailNow(t, AssertIs(user2.Status, test.Status)) + throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + } + //test3 数字 + 操作 + _, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status+1") + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + fmt.Println(err) + } else { + throwFailNow(t, err) + time.Sleep(time.Second * 1) + dORM.Read(&test, "UserName") + throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + } + //test4 数字 - 操作 + _, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status-1") + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + fmt.Println(err) + } else { + throwFailNow(t, err) + time.Sleep(time.Second * 1) + dORM.Read(&test, "UserName") + throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + } + //test5 数字 * 操作 + _, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status*3") + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + fmt.Println(err) + } else { + throwFailNow(t, err) + time.Sleep(time.Second * 1) + dORM.Read(&test, "UserName") + throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + } + //test6 数字 / 操作 + _, err = dORM.InsertOrUpdate(&user2, "UserName", "Status=Status/3") + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + fmt.Println(err) + } else { + throwFailNow(t, err) + time.Sleep(time.Second * 1) + dORM.Read(&test, "UserName") + throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + } +} diff --git a/orm/types.go b/orm/types.go index cb55e71a..7864e315 100644 --- a/orm/types.go +++ b/orm/types.go @@ -53,6 +53,11 @@ type Ormer interface { // id, err = Ormer.Insert(user) // user must a pointer and Insert will set user's pk field Insert(interface{}) (int64, error) + //mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + //if colu type is integer : can use(+-*/), string : convert(colu,"value") + //postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + //if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) // insert some models to database InsertMulti(bulk int, mds interface{}) (int64, error) // update model to database. @@ -391,6 +396,7 @@ type txEnder interface { type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *time.Location, string, ...string) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)