diff --git a/orm/db.go b/orm/db.go index 613fc8a9..92b6cfe8 100644 --- a/orm/db.go +++ b/orm/db.go @@ -492,30 +492,28 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s // If your primary key or unique column conflict will update // If no will insert func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - iouStr := "" - mysql := DRMySQL - postgres := DRPostgres - driver := a.Driver - argsMap := map[string]string{} args0 := "" - if driver == mysql { + iouStr := "" + argsMap := map[string]string{} + switch a.Driver { + case DRMySQL: iouStr = "ON DUPLICATE KEY UPDATE" - } else if driver == postgres { - if len(args) == 0 || (len(strings.Split(args0, "=")) != 1) { - return 0, fmt.Errorf("`%s` use insert or update must have a conflict column arg in first", a.DriverName) + case DRPostgres: + if len(args) == 0 { + return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) } else { args0 = strings.ToLower(args[0]) iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) } - } else { - return 0, fmt.Errorf("`%s` nonsupport insert or update in beego", a.DriverName) + default: + return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) } + //Get on the key-value pairs for _, v := range args { kv := strings.Split(v, "=") if len(kv) == 2 { - k := strings.ToLower(kv[0]) - argsMap[k] = kv[1] + argsMap[strings.ToLower(kv[0])] = kv[1] } } @@ -534,17 +532,15 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a var conflitValue interface{} for i, v := range names { marks[i] = "?" - vtl := strings.ToLower(v) - valueStr := argsMap[vtl] - if vtl == args0 { + valueStr := argsMap[v] + if v == args0 { conflitValue = values[i] } if valueStr != "" { - switch driver { - case mysql: + switch a.Driver { + case DRMySQL: updates[i] = v + "=" + valueStr - break - case postgres: + case DRPostgres: if conflitValue != nil { //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) @@ -552,7 +548,6 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } else { return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) } - break } } else { updates[i] = v + "=?" diff --git a/orm/orm_test.go b/orm/orm_test.go index 5b44f286..c0e7dacd 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -2188,8 +2188,9 @@ func TestInsertOrUpdate(t *testing.T) { } //test1 _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { fmt.Println(err) + return } else { throwFailNow(t, err) dORM.Read(&test, "user_name") @@ -2197,7 +2198,7 @@ func TestInsertOrUpdate(t *testing.T) { } //test2 _, err = dORM.InsertOrUpdate(&user2, "user_name") - if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { fmt.Println(err) } else { throwFailNow(t, err) @@ -2207,7 +2208,7 @@ func TestInsertOrUpdate(t *testing.T) { } //test3 + _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status+1") - if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { fmt.Println(err) } else { throwFailNow(t, err) @@ -2216,7 +2217,7 @@ func TestInsertOrUpdate(t *testing.T) { } //test4 - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status-1") - if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { fmt.Println(err) } else { throwFailNow(t, err) @@ -2225,7 +2226,7 @@ func TestInsertOrUpdate(t *testing.T) { } //test5 * _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status*3") - if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { fmt.Println(err) } else { throwFailNow(t, err) @@ -2234,7 +2235,7 @@ func TestInsertOrUpdate(t *testing.T) { } //test6 / _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { + if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { fmt.Println(err) } else { throwFailNow(t, err)