1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 18:20:55 +00:00

orm: fix postgres sequence value

This commit is contained in:
miraclesu 2016-03-27 15:06:57 +08:00
parent 3ca44071e6
commit 1794c52d65
3 changed files with 55 additions and 12 deletions

View File

@ -71,7 +71,7 @@ type dbBase struct {
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
// get struct columns values as interface slice. // get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
if names == nil { if names == nil {
ns := make([]string, 0, len(cols)) ns := make([]string, 0, len(cols))
names = &ns names = &ns
@ -90,11 +90,11 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// ignore empty value auto field // ignore empty value auto field
if fi.auto { if insert && fi.auto {
if fi.fieldType&IsPositiveIntegerField > 0 { if fi.fieldType&IsPositiveIntegerField > 0 {
if vu, ok := value.(uint64); !ok || vu == 0 { if vu, ok := value.(uint64); !ok || vu == 0 {
continue continue
@ -104,6 +104,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
continue continue
} }
} }
autoFields = append(autoFields, fi.column)
} }
*names, values = append(*names, column), append(values, value) *names, values = append(*names, column), append(values, value)
@ -278,7 +279,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
// insert struct with prepared statement and given struct reflect value. // insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -305,7 +306,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols = make([]string, 0, len(cols)) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -355,12 +356,20 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// execute insert sql dbQuerier with given struct reflect.Value. // execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols)) names := make([]string, 0, len(mi.fields.dbcols))
values, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, false, names, values) id, err := d.InsertValue(q, mi, false, names, values)
if err != nil {
return 0, err
}
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
}
return id, err
} }
// multi-insert sql with given slice struct reflect.Value. // multi-insert sql with given slice struct reflect.Value.
@ -374,7 +383,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// typ := reflect.Indirect(mi.addrField).Type() // typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len() length, autoFields := sind.Len(), make([]string, 0, 1)
for i := 1; i <= length; i++ { for i := 1; i <= length; i++ {
@ -386,14 +395,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// } // }
if i == 1 { if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) var (
vus []interface{}
err error
)
vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
values = make([]interface{}, bulk*len(vus)) values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus) nums += copy(values, vus)
} else { } else {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
@ -415,7 +428,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
} }
} }
return cnt, nil var err error
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
}
return cnt, err
} }
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
@ -475,7 +493,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
setNames = make([]string, 0, len(cols)) setNames = make([]string, 0, len(cols))
} }
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -1565,6 +1583,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
// sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil
}
// convert time from db. // convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)

View File

@ -135,6 +135,25 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
return true return true
} }
// sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
Q := d.ins.TableQuote()
for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.Exec(query); err != nil {
return err
}
}
return nil
}
// show table sql for postgresql. // show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string { func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"

View File

@ -420,4 +420,5 @@ type dbBaser interface {
ShowColumnsQuery(string) string ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error
} }