diff --git a/orm/db.go b/orm/db.go index 8c3b82c2..efc90e6e 100644 --- a/orm/db.go +++ b/orm/db.go @@ -71,7 +71,7 @@ type dbBase struct { var _ dbBaser = new(dbBase) // get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { if names == nil { ns := make([]string, 0, len(cols)) names = &ns @@ -90,11 +90,11 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) if err != nil { - return nil, err + return nil, nil, err } // ignore empty value auto field - if fi.auto { + if insert && fi.auto { if fi.fieldType&IsPositiveIntegerField > 0 { if vu, ok := value.(uint64); !ok || vu == 0 { continue @@ -104,6 +104,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, continue } } + autoFields = append(autoFields, fi.column) } *names, values = append(*names, column), append(values, value) @@ -278,7 +279,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, // insert struct with prepared statement and given struct reflect value. func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err } @@ -305,7 +306,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo if len(cols) > 0 { var err error whereCols = make([]string, 0, len(cols)) - args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) if err != nil { return err } @@ -355,12 +356,20 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo // execute insert sql dbQuerier with given struct reflect.Value. func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)) - values, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - return d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(q, mi, false, names, values) + if err != nil { + return 0, err + } + + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + return id, err } // multi-insert sql with given slice struct reflect.Value. @@ -374,7 +383,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // typ := reflect.Indirect(mi.addrField).Type() - length := sind.Len() + length, autoFields := sind.Len(), make([]string, 0, 1) for i := 1; i <= length; i++ { @@ -386,14 +395,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // } if i == 1 { - vus, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + var ( + vus []interface{} + err error + ) + vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return cnt, err } values = make([]interface{}, bulk*len(vus)) nums += copy(values, vus) } else { - vus, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) + vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) if err != nil { return cnt, err } @@ -415,7 +428,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } } - return cnt, nil + var err error + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + + return cnt, err } // execute insert sql with given struct and given values. @@ -475,7 +493,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. setNames = make([]string, 0, len(cols)) } - setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) + setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) if err != nil { return 0, err } @@ -1565,6 +1583,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } +// sync auto key +func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + return nil +} + // convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) diff --git a/orm/db_postgres.go b/orm/db_postgres.go index be4cd0bc..8fbcb88d 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -135,6 +135,25 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { return true } +// sync auto key +func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + if len(autoFields) == 0 { + return nil + } + + Q := d.ins.TableQuote() + for _, name := range autoFields { + query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", + mi.table, name, + Q, name, Q, + Q, mi.table, Q) + if _, err := db.Exec(query); err != nil { + return err + } + } + return nil +} + // show table sql for postgresql. func (d *dbBasePostgres) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" diff --git a/orm/types.go b/orm/types.go index 41933dd1..cb55e71a 100644 --- a/orm/types.go +++ b/orm/types.go @@ -420,4 +420,5 @@ type dbBaser interface { ShowColumnsQuery(string) string IndexExists(dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(dbQuerier, *modelInfo, []string) error }