mirror of
https://github.com/astaxie/beego.git
synced 2024-11-29 16:21:28 +00:00
Merge pull request #1826 from miraclesu/feature/orm_auto
orm: support insert a specified value to auto field
This commit is contained in:
commit
699de2ae75
@ -41,7 +41,7 @@ func (ec *errorTestController) Get() {
|
|||||||
|
|
||||||
func TestErrorCode_01(t *testing.T) {
|
func TestErrorCode_01(t *testing.T) {
|
||||||
registerDefaultErrorHandler()
|
registerDefaultErrorHandler()
|
||||||
for k, _ := range ErrorMaps {
|
for k := range ErrorMaps {
|
||||||
r, _ := http.NewRequest("GET", "/error?code="+k, nil)
|
r, _ := http.NewRequest("GET", "/error?code="+k, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
76
orm/db.go
76
orm/db.go
@ -71,12 +71,12 @@ type dbBase struct {
|
|||||||
var _ dbBaser = new(dbBase)
|
var _ dbBaser = new(dbBase)
|
||||||
|
|
||||||
// get struct columns values as interface slice.
|
// get struct columns values as interface slice.
|
||||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
|
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
|
||||||
var columns []string
|
if names == nil {
|
||||||
|
ns := make([]string, 0, len(cols))
|
||||||
if names != nil {
|
names = &ns
|
||||||
columns = *names
|
|
||||||
}
|
}
|
||||||
|
values = make([]interface{}, 0, len(cols))
|
||||||
|
|
||||||
for _, column := range cols {
|
for _, column := range cols {
|
||||||
var fi *fieldInfo
|
var fi *fieldInfo
|
||||||
@ -90,18 +90,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
|||||||
}
|
}
|
||||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if names != nil {
|
// ignore empty value auto field
|
||||||
columns = append(columns, column)
|
if insert && fi.auto {
|
||||||
|
if fi.fieldType&IsPositiveIntegerField > 0 {
|
||||||
|
if vu, ok := value.(uint64); !ok || vu == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if vu, ok := value.(int64); !ok || vu == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
autoFields = append(autoFields, fi.column)
|
||||||
}
|
}
|
||||||
|
|
||||||
values = append(values, value)
|
*names, values = append(*names, column), append(values, value)
|
||||||
}
|
|
||||||
|
|
||||||
if names != nil {
|
|
||||||
*names = columns
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -273,7 +279,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
|
|
||||||
// insert struct with prepared statement and given struct reflect value.
|
// insert struct with prepared statement and given struct reflect value.
|
||||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -300,7 +306,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
if len(cols) > 0 {
|
if len(cols) > 0 {
|
||||||
var err error
|
var err error
|
||||||
whereCols = make([]string, 0, len(cols))
|
whereCols = make([]string, 0, len(cols))
|
||||||
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -349,13 +355,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
|
|
||||||
// execute insert sql dbQuerier with given struct reflect.Value.
|
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
names := make([]string, 0, len(mi.fields.dbcols))
|
||||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.InsertValue(q, mi, false, names, values)
|
id, err := d.InsertValue(q, mi, false, names, values)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(autoFields) > 0 {
|
||||||
|
err = d.ins.setval(q, mi, autoFields)
|
||||||
|
}
|
||||||
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// multi-insert sql with given slice struct reflect.Value.
|
// multi-insert sql with given slice struct reflect.Value.
|
||||||
@ -369,7 +383,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
|
|
||||||
// typ := reflect.Indirect(mi.addrField).Type()
|
// typ := reflect.Indirect(mi.addrField).Type()
|
||||||
|
|
||||||
length := sind.Len()
|
length, autoFields := sind.Len(), make([]string, 0, 1)
|
||||||
|
|
||||||
for i := 1; i <= length; i++ {
|
for i := 1; i <= length; i++ {
|
||||||
|
|
||||||
@ -381,16 +395,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
if i == 1 {
|
if i == 1 {
|
||||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
var (
|
||||||
|
vus []interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cnt, err
|
return cnt, err
|
||||||
}
|
}
|
||||||
values = make([]interface{}, bulk*len(vus))
|
values = make([]interface{}, bulk*len(vus))
|
||||||
nums += copy(values, vus)
|
nums += copy(values, vus)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz)
|
||||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cnt, err
|
return cnt, err
|
||||||
}
|
}
|
||||||
@ -412,7 +428,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cnt, nil
|
var err error
|
||||||
|
if len(autoFields) > 0 {
|
||||||
|
err = d.ins.setval(q, mi, autoFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cnt, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute insert sql with given struct and given values.
|
// execute insert sql with given struct and given values.
|
||||||
@ -472,7 +493,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
setNames = make([]string, 0, len(cols))
|
setNames = make([]string, 0, len(cols))
|
||||||
}
|
}
|
||||||
|
|
||||||
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -1562,6 +1583,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sync auto key
|
||||||
|
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// convert time from db.
|
// convert time from db.
|
||||||
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
||||||
*t = t.In(tz)
|
*t = t.In(tz)
|
||||||
|
@ -135,6 +135,25 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sync auto key
|
||||||
|
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||||
|
if len(autoFields) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
for _, name := range autoFields {
|
||||||
|
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
|
||||||
|
mi.table, name,
|
||||||
|
Q, name, Q,
|
||||||
|
Q, mi.table, Q)
|
||||||
|
if _, err := db.Exec(query); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// show table sql for postgresql.
|
// show table sql for postgresql.
|
||||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||||
|
@ -2016,6 +2016,44 @@ func TestIntegerPk(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInsertAuto(t *testing.T) {
|
||||||
|
u := &User{
|
||||||
|
UserName: "autoPre",
|
||||||
|
Email: "autoPre@gmail.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := dORM.Insert(u)
|
||||||
|
throwFail(t, err)
|
||||||
|
|
||||||
|
id += 100
|
||||||
|
su := &User{
|
||||||
|
ID: int(id),
|
||||||
|
UserName: "auto",
|
||||||
|
Email: "auto@gmail.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
nid, err := dORM.Insert(su)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(nid, id))
|
||||||
|
|
||||||
|
users := []User{
|
||||||
|
{ID: int(id + 100), UserName: "auto_100"},
|
||||||
|
{ID: int(id + 110), UserName: "auto_110"},
|
||||||
|
{ID: int(id + 120), UserName: "auto_120"},
|
||||||
|
}
|
||||||
|
num, err := dORM.InsertMulti(100, users)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 3))
|
||||||
|
|
||||||
|
u = &User{
|
||||||
|
UserName: "auto_121",
|
||||||
|
}
|
||||||
|
|
||||||
|
nid, err = dORM.Insert(u)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(nid, id+120+1))
|
||||||
|
}
|
||||||
|
|
||||||
func TestUintPk(t *testing.T) {
|
func TestUintPk(t *testing.T) {
|
||||||
name := "go"
|
name := "go"
|
||||||
u := &UintPk{
|
u := &UintPk{
|
||||||
|
@ -420,4 +420,5 @@ type dbBaser interface {
|
|||||||
ShowColumnsQuery(string) string
|
ShowColumnsQuery(string) string
|
||||||
IndexExists(dbQuerier, string, string) bool
|
IndexExists(dbQuerier, string, string) bool
|
||||||
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
||||||
|
setval(dbQuerier, *modelInfo, []string) error
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user