mirror of
https://github.com/astaxie/beego.git
synced 2024-11-10 18:20:55 +00:00
Merge pull request #1826 from miraclesu/feature/orm_auto
orm: support insert a specified value to auto field
This commit is contained in:
commit
699de2ae75
@ -41,7 +41,7 @@ func (ec *errorTestController) Get() {
|
||||
|
||||
func TestErrorCode_01(t *testing.T) {
|
||||
registerDefaultErrorHandler()
|
||||
for k, _ := range ErrorMaps {
|
||||
for k := range ErrorMaps {
|
||||
r, _ := http.NewRequest("GET", "/error?code="+k, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
76
orm/db.go
76
orm/db.go
@ -71,12 +71,12 @@ type dbBase struct {
|
||||
var _ dbBaser = new(dbBase)
|
||||
|
||||
// get struct columns values as interface slice.
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
|
||||
var columns []string
|
||||
|
||||
if names != nil {
|
||||
columns = *names
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
|
||||
if names == nil {
|
||||
ns := make([]string, 0, len(cols))
|
||||
names = &ns
|
||||
}
|
||||
values = make([]interface{}, 0, len(cols))
|
||||
|
||||
for _, column := range cols {
|
||||
var fi *fieldInfo
|
||||
@ -90,18 +90,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
||||
}
|
||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if names != nil {
|
||||
columns = append(columns, column)
|
||||
// ignore empty value auto field
|
||||
if insert && fi.auto {
|
||||
if fi.fieldType&IsPositiveIntegerField > 0 {
|
||||
if vu, ok := value.(uint64); !ok || vu == 0 {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if vu, ok := value.(int64); !ok || vu == 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
autoFields = append(autoFields, fi.column)
|
||||
}
|
||||
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if names != nil {
|
||||
*names = columns
|
||||
*names, values = append(*names, column), append(values, value)
|
||||
}
|
||||
|
||||
return
|
||||
@ -273,7 +279,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
||||
|
||||
// insert struct with prepared statement and given struct reflect value.
|
||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -300,7 +306,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
if len(cols) > 0 {
|
||||
var err error
|
||||
whereCols = make([]string, 0, len(cols))
|
||||
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||
args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -349,13 +355,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
|
||||
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
names := make([]string, 0, len(mi.fields.dbcols))
|
||||
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return d.InsertValue(q, mi, false, names, values)
|
||||
id, err := d.InsertValue(q, mi, false, names, values)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(autoFields) > 0 {
|
||||
err = d.ins.setval(q, mi, autoFields)
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// multi-insert sql with given slice struct reflect.Value.
|
||||
@ -369,7 +383,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
||||
|
||||
// typ := reflect.Indirect(mi.addrField).Type()
|
||||
|
||||
length := sind.Len()
|
||||
length, autoFields := sind.Len(), make([]string, 0, 1)
|
||||
|
||||
for i := 1; i <= length; i++ {
|
||||
|
||||
@ -381,16 +395,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
||||
// }
|
||||
|
||||
if i == 1 {
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
var (
|
||||
vus []interface{}
|
||||
err error
|
||||
)
|
||||
vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
values = make([]interface{}, bulk*len(vus))
|
||||
nums += copy(values, vus)
|
||||
|
||||
} else {
|
||||
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
@ -412,7 +428,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
var err error
|
||||
if len(autoFields) > 0 {
|
||||
err = d.ins.setval(q, mi, autoFields)
|
||||
}
|
||||
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
// execute insert sql with given struct and given values.
|
||||
@ -472,7 +493,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
setNames = make([]string, 0, len(cols))
|
||||
}
|
||||
|
||||
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||
setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -1562,6 +1583,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// sync auto key
|
||||
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// convert time from db.
|
||||
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
||||
*t = t.In(tz)
|
||||
|
@ -135,6 +135,25 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// sync auto key
|
||||
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||
if len(autoFields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
for _, name := range autoFields {
|
||||
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
|
||||
mi.table, name,
|
||||
Q, name, Q,
|
||||
Q, mi.table, Q)
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// show table sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||
|
@ -2016,6 +2016,44 @@ func TestIntegerPk(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertAuto(t *testing.T) {
|
||||
u := &User{
|
||||
UserName: "autoPre",
|
||||
Email: "autoPre@gmail.com",
|
||||
}
|
||||
|
||||
id, err := dORM.Insert(u)
|
||||
throwFail(t, err)
|
||||
|
||||
id += 100
|
||||
su := &User{
|
||||
ID: int(id),
|
||||
UserName: "auto",
|
||||
Email: "auto@gmail.com",
|
||||
}
|
||||
|
||||
nid, err := dORM.Insert(su)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(nid, id))
|
||||
|
||||
users := []User{
|
||||
{ID: int(id + 100), UserName: "auto_100"},
|
||||
{ID: int(id + 110), UserName: "auto_110"},
|
||||
{ID: int(id + 120), UserName: "auto_120"},
|
||||
}
|
||||
num, err := dORM.InsertMulti(100, users)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 3))
|
||||
|
||||
u = &User{
|
||||
UserName: "auto_121",
|
||||
}
|
||||
|
||||
nid, err = dORM.Insert(u)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(nid, id+120+1))
|
||||
}
|
||||
|
||||
func TestUintPk(t *testing.T) {
|
||||
name := "go"
|
||||
u := &UintPk{
|
||||
|
@ -420,4 +420,5 @@ type dbBaser interface {
|
||||
ShowColumnsQuery(string) string
|
||||
IndexExists(dbQuerier, string, string) bool
|
||||
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
||||
setval(dbQuerier, *modelInfo, []string) error
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user