mirror of
https://github.com/astaxie/beego.git
synced 2024-11-26 04:01:29 +00:00
commit
c2aeab78aa
104
orm/db.go
104
orm/db.go
@ -488,6 +488,110 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertOrUpdate a row
|
||||||
|
// If your primary key or unique column conflict will update
|
||||||
|
// If no will insert
|
||||||
|
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
|
args0 := ""
|
||||||
|
iouStr := ""
|
||||||
|
argsMap := map[string]string{}
|
||||||
|
switch a.Driver {
|
||||||
|
case DRMySQL:
|
||||||
|
iouStr = "ON DUPLICATE KEY UPDATE"
|
||||||
|
case DRPostgres:
|
||||||
|
if len(args) == 0 {
|
||||||
|
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
||||||
|
} else {
|
||||||
|
args0 = strings.ToLower(args[0])
|
||||||
|
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Get on the key-value pairs
|
||||||
|
for _, v := range args {
|
||||||
|
kv := strings.Split(v, "=")
|
||||||
|
if len(kv) == 2 {
|
||||||
|
argsMap[strings.ToLower(kv[0])] = kv[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isMulti := false
|
||||||
|
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
marks := make([]string, len(names))
|
||||||
|
updateValues := make([]interface{}, 0)
|
||||||
|
updates := make([]string, len(names))
|
||||||
|
var conflitValue interface{}
|
||||||
|
for i, v := range names {
|
||||||
|
marks[i] = "?"
|
||||||
|
valueStr := argsMap[strings.ToLower(v)]
|
||||||
|
if v == args0 {
|
||||||
|
conflitValue = values[i]
|
||||||
|
}
|
||||||
|
if valueStr != "" {
|
||||||
|
switch a.Driver {
|
||||||
|
case DRMySQL:
|
||||||
|
updates[i] = v + "=" + valueStr
|
||||||
|
case DRPostgres:
|
||||||
|
if conflitValue != nil {
|
||||||
|
//postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
||||||
|
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0)
|
||||||
|
updateValues = append(updateValues, conflitValue)
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
updates[i] = v + "=?"
|
||||||
|
updateValues = append(updateValues, values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
values = append(values, updateValues...)
|
||||||
|
|
||||||
|
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||||
|
qmarks := strings.Join(marks, ", ")
|
||||||
|
qupdates := strings.Join(updates, ", ")
|
||||||
|
columns := strings.Join(names, sep)
|
||||||
|
|
||||||
|
multi := len(values) / len(names)
|
||||||
|
|
||||||
|
if isMulti {
|
||||||
|
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||||
|
}
|
||||||
|
//conflitValue maybe is a int,can`t use fmt.Sprintf
|
||||||
|
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
|
res, err := q.Exec(query, values...)
|
||||||
|
if err == nil {
|
||||||
|
if isMulti {
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
return res.LastInsertId()
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
row := q.QueryRow(query, values...)
|
||||||
|
var id int64
|
||||||
|
err = row.Scan(&id)
|
||||||
|
if err.Error() == `pq: syntax error at or near "ON"` {
|
||||||
|
err = fmt.Errorf("postgres version must 9.5 or higher")
|
||||||
|
}
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
// execute update sql dbQuerier with given struct reflect.Value.
|
// execute update sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
|
13
orm/orm.go
13
orm/orm.go
@ -209,6 +209,19 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
|||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertOrUpdate data to database
|
||||||
|
func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||||
|
mi, ind := o.getMiInd(md, true)
|
||||||
|
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
|
||||||
|
if err != nil {
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
o.setPk(mi, ind, id)
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
// update model to database.
|
// update model to database.
|
||||||
// cols set the columns those want to update.
|
// cols set the columns those want to update.
|
||||||
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||||
|
@ -2174,3 +2174,89 @@ func TestIgnoreCaseTag(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
|
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
|
||||||
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
|
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
|
||||||
}
|
}
|
||||||
|
func TestInsertOrUpdate(t *testing.T) {
|
||||||
|
RegisterModel(new(User))
|
||||||
|
user := User{UserName: "unique_username133", Status: 1, Password: "o"}
|
||||||
|
user1 := User{UserName: "unique_username133", Status: 2, Password: "o"}
|
||||||
|
user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"}
|
||||||
|
dORM.Insert(&user)
|
||||||
|
test := User{UserName: "unique_username133"}
|
||||||
|
fmt.Println(dORM.Driver().Name())
|
||||||
|
if dORM.Driver().Name() == "sqlite3" {
|
||||||
|
fmt.Println("sqlite3 is nonsupport")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//test1
|
||||||
|
_, err := dORM.InsertOrUpdate(&user1, "user_name")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
|
||||||
|
} else {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dORM.Read(&test, "user_name")
|
||||||
|
throwFailNow(t, AssertIs(user1.Status, test.Status))
|
||||||
|
}
|
||||||
|
//test2
|
||||||
|
_, err = dORM.InsertOrUpdate(&user2, "user_name")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
|
||||||
|
} else {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dORM.Read(&test, "user_name")
|
||||||
|
throwFailNow(t, AssertIs(user2.Status, test.Status))
|
||||||
|
throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password)))
|
||||||
|
}
|
||||||
|
//test3 +
|
||||||
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
|
||||||
|
} else {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dORM.Read(&test, "user_name")
|
||||||
|
throwFailNow(t, AssertIs(user2.Status+1, test.Status))
|
||||||
|
}
|
||||||
|
//test4 -
|
||||||
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
|
||||||
|
} else {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dORM.Read(&test, "user_name")
|
||||||
|
throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status))
|
||||||
|
}
|
||||||
|
//test5 *
|
||||||
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
|
||||||
|
} else {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dORM.Read(&test, "user_name")
|
||||||
|
throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status))
|
||||||
|
}
|
||||||
|
//test6 /
|
||||||
|
_, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
|
||||||
|
} else {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dORM.Read(&test, "user_name")
|
||||||
|
throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -53,6 +53,11 @@ type Ormer interface {
|
|||||||
// id, err = Ormer.Insert(user)
|
// id, err = Ormer.Insert(user)
|
||||||
// user must a pointer and Insert will set user's pk field
|
// user must a pointer and Insert will set user's pk field
|
||||||
Insert(interface{}) (int64, error)
|
Insert(interface{}) (int64, error)
|
||||||
|
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
|
||||||
|
// if colu type is integer : can use(+-*/), string : convert(colu,"value")
|
||||||
|
// postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
|
||||||
|
// if colu type is integer : can use(+-*/), string : colu || "value"
|
||||||
|
InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
|
||||||
// insert some models to database
|
// insert some models to database
|
||||||
InsertMulti(bulk int, mds interface{}) (int64, error)
|
InsertMulti(bulk int, mds interface{}) (int64, error)
|
||||||
// update model to database.
|
// update model to database.
|
||||||
@ -391,6 +396,7 @@ type txEnder interface {
|
|||||||
type dbBaser interface {
|
type dbBaser interface {
|
||||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
||||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
|
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
|
||||||
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||||
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
|
Loading…
Reference in New Issue
Block a user