mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 22:50:55 +00:00
Optimize the code logic
This commit is contained in:
parent
bf17558d06
commit
182a21172f
37
orm/db.go
37
orm/db.go
@ -492,30 +492,28 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
// If your primary key or unique column conflict will update
|
// If your primary key or unique column conflict will update
|
||||||
// If no will insert
|
// If no will insert
|
||||||
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
iouStr := ""
|
|
||||||
mysql := DRMySQL
|
|
||||||
postgres := DRPostgres
|
|
||||||
driver := a.Driver
|
|
||||||
argsMap := map[string]string{}
|
|
||||||
args0 := ""
|
args0 := ""
|
||||||
if driver == mysql {
|
iouStr := ""
|
||||||
|
argsMap := map[string]string{}
|
||||||
|
switch a.Driver {
|
||||||
|
case DRMySQL:
|
||||||
iouStr = "ON DUPLICATE KEY UPDATE"
|
iouStr = "ON DUPLICATE KEY UPDATE"
|
||||||
} else if driver == postgres {
|
case DRPostgres:
|
||||||
if len(args) == 0 || (len(strings.Split(args0, "=")) != 1) {
|
if len(args) == 0 {
|
||||||
return 0, fmt.Errorf("`%s` use insert or update must have a conflict column arg in first", a.DriverName)
|
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
||||||
} else {
|
} else {
|
||||||
args0 = strings.ToLower(args[0])
|
args0 = strings.ToLower(args[0])
|
||||||
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
||||||
}
|
}
|
||||||
} else {
|
default:
|
||||||
return 0, fmt.Errorf("`%s` nonsupport insert or update in beego", a.DriverName)
|
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
//Get on the key-value pairs
|
//Get on the key-value pairs
|
||||||
for _, v := range args {
|
for _, v := range args {
|
||||||
kv := strings.Split(v, "=")
|
kv := strings.Split(v, "=")
|
||||||
if len(kv) == 2 {
|
if len(kv) == 2 {
|
||||||
k := strings.ToLower(kv[0])
|
argsMap[strings.ToLower(kv[0])] = kv[1]
|
||||||
argsMap[k] = kv[1]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -534,17 +532,15 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
var conflitValue interface{}
|
var conflitValue interface{}
|
||||||
for i, v := range names {
|
for i, v := range names {
|
||||||
marks[i] = "?"
|
marks[i] = "?"
|
||||||
vtl := strings.ToLower(v)
|
valueStr := argsMap[v]
|
||||||
valueStr := argsMap[vtl]
|
if v == args0 {
|
||||||
if vtl == args0 {
|
|
||||||
conflitValue = values[i]
|
conflitValue = values[i]
|
||||||
}
|
}
|
||||||
if valueStr != "" {
|
if valueStr != "" {
|
||||||
switch driver {
|
switch a.Driver {
|
||||||
case mysql:
|
case DRMySQL:
|
||||||
updates[i] = v + "=" + valueStr
|
updates[i] = v + "=" + valueStr
|
||||||
break
|
case DRPostgres:
|
||||||
case postgres:
|
|
||||||
if conflitValue != nil {
|
if conflitValue != nil {
|
||||||
//postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
//postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
||||||
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0)
|
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0)
|
||||||
@ -552,7 +548,6 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
} else {
|
} else {
|
||||||
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
||||||
}
|
}
|
||||||
break
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
updates[i] = v + "=?"
|
updates[i] = v + "=?"
|
||||||
|
@ -2188,8 +2188,9 @@ func TestInsertOrUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
//test1
|
//test1
|
||||||
_, err := dORM.InsertOrUpdate(&user1, "user_name")
|
_, err := dORM.InsertOrUpdate(&user1, "user_name")
|
||||||
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
|
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
|
return
|
||||||
} else {
|
} else {
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
dORM.Read(&test, "user_name")
|
dORM.Read(&test, "user_name")
|
||||||
@ -2197,7 +2198,7 @@ func TestInsertOrUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
//test2
|
//test2
|
||||||
_, err = dORM.InsertOrUpdate(&user2, "user_name")
|
_, err = dORM.InsertOrUpdate(&user2, "user_name")
|
||||||
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
|
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
} else {
|
} else {
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
@ -2207,7 +2208,7 @@ func TestInsertOrUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
//test3 +
|
//test3 +
|
||||||
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status+1")
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status+1")
|
||||||
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
|
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
} else {
|
} else {
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
@ -2216,7 +2217,7 @@ func TestInsertOrUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
//test4 -
|
//test4 -
|
||||||
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status-1")
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status-1")
|
||||||
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
|
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
} else {
|
} else {
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
@ -2225,7 +2226,7 @@ func TestInsertOrUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
//test5 *
|
//test5 *
|
||||||
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status*3")
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status*3")
|
||||||
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
|
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
} else {
|
} else {
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
@ -2234,7 +2235,7 @@ func TestInsertOrUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
//test6 /
|
//test6 /
|
||||||
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3")
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3")
|
||||||
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport insert or update in beego") {
|
if err != nil && (err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
} else {
|
} else {
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
|
Loading…
Reference in New Issue
Block a user