1
0
mirror of https://github.com/astaxie/beego.git synced 2024-06-02 13:23:27 +00:00

Optimize the code logic

This commit is contained in:
“fudali113” 2016-07-26 11:15:59 +08:00
parent bf17558d06
commit 182a21172f
2 changed files with 23 additions and 27 deletions

View File

@ -492,30 +492,28 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
// If your primary key or unique column conflict will update // If your primary key or unique column conflict will update
// If no will insert // If no will insert
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
iouStr := ""
mysql := DRMySQL
postgres := DRPostgres
driver := a.Driver
argsMap := map[string]string{}
args0 := "" args0 := ""
if driver == mysql { iouStr := ""
argsMap := map[string]string{}
switch a.Driver {
case DRMySQL:
iouStr = "ON DUPLICATE KEY UPDATE" iouStr = "ON DUPLICATE KEY UPDATE"
} else if driver == postgres { case DRPostgres:
if len(args) == 0 || (len(strings.Split(args0, "=")) != 1) { if len(args) == 0 {
return 0, fmt.Errorf("`%s` use insert or update must have a conflict column arg in first", a.DriverName) return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
} else { } else {
args0 = strings.ToLower(args[0]) args0 = strings.ToLower(args[0])
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
} }
} else { default:
return 0, fmt.Errorf("`%s` nonsupport insert or update in beego", a.DriverName) return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
} }
//Get on the key-value pairs //Get on the key-value pairs
for _, v := range args { for _, v := range args {
kv := strings.Split(v, "=") kv := strings.Split(v, "=")
if len(kv) == 2 { if len(kv) == 2 {
k := strings.ToLower(kv[0]) argsMap[strings.ToLower(kv[0])] = kv[1]
argsMap[k] = kv[1]
} }
} }
@ -534,17 +532,15 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
var conflitValue interface{} var conflitValue interface{}
for i, v := range names { for i, v := range names {
marks[i] = "?" marks[i] = "?"
vtl := strings.ToLower(v) valueStr := argsMap[v]
valueStr := argsMap[vtl] if v == args0 {
if vtl == args0 {
conflitValue = values[i] conflitValue = values[i]
} }
if valueStr != "" { if valueStr != "" {
switch driver { switch a.Driver {
case mysql: case DRMySQL:
updates[i] = v + "=" + valueStr updates[i] = v + "=" + valueStr
break case DRPostgres:
case postgres:
if conflitValue != nil { if conflitValue != nil {
//postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0)
@ -552,7 +548,6 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
} else { } else {
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
} }
break
} }
} else { } else {
updates[i] = v + "=?" updates[i] = v + "=?"

View File

@ -2188,8 +2188,9 @@ func TestInsertOrUpdate(t *testing.T) {
} }
//test1 //test1
_, err := dORM.InsertOrUpdate(&user1, "user_name") _, err := dORM.InsertOrUpdate(&user1, "user_name")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
fmt.Println(err) fmt.Println(err)
return
} else { } else {
throwFailNow(t, err) throwFailNow(t, err)
dORM.Read(&test, "user_name") dORM.Read(&test, "user_name")
@ -2197,7 +2198,7 @@ func TestInsertOrUpdate(t *testing.T) {
} }
//test2 //test2
_, err = dORM.InsertOrUpdate(&user2, "user_name") _, err = dORM.InsertOrUpdate(&user2, "user_name")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
fmt.Println(err) fmt.Println(err)
} else { } else {
throwFailNow(t, err) throwFailNow(t, err)
@ -2207,7 +2208,7 @@ func TestInsertOrUpdate(t *testing.T) {
} }
//test3 + //test3 +
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status+1") _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status+1")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
fmt.Println(err) fmt.Println(err)
} else { } else {
throwFailNow(t, err) throwFailNow(t, err)
@ -2216,7 +2217,7 @@ func TestInsertOrUpdate(t *testing.T) {
} }
//test4 - //test4 -
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status-1") _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status-1")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
fmt.Println(err) fmt.Println(err)
} else { } else {
throwFailNow(t, err) throwFailNow(t, err)
@ -2225,7 +2226,7 @@ func TestInsertOrUpdate(t *testing.T) {
} }
//test5 * //test5 *
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status*3") _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status*3")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
fmt.Println(err) fmt.Println(err)
} else { } else {
throwFailNow(t, err) throwFailNow(t, err)
@ -2234,7 +2235,7 @@ func TestInsertOrUpdate(t *testing.T) {
} }
//test6 / //test6 /
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3")
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") { if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
fmt.Println(err) fmt.Println(err)
} else { } else {
throwFailNow(t, err) throwFailNow(t, err)