mirror of
https://github.com/astaxie/beego.git
synced 2025-07-02 22:40:18 +00:00
#436 support insert multi
This commit is contained in:
119
orm/db.go
119
orm/db.go
@ -51,7 +51,13 @@ type dbBase struct {
|
||||
|
||||
var _ dbBaser = new(dbBase)
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
|
||||
var columns []string
|
||||
|
||||
if names != nil {
|
||||
columns = *names
|
||||
}
|
||||
|
||||
for _, column := range cols {
|
||||
var fi *fieldInfo
|
||||
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||
@ -64,11 +70,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
||||
}
|
||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
columns = append(columns, column)
|
||||
|
||||
if names != nil {
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if names != nil {
|
||||
*names = columns
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -166,7 +181,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -192,7 +207,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
// if specify cols length > 0, then use it for where condition.
|
||||
if len(cols) > 0 {
|
||||
var err error
|
||||
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
||||
whereCols = make([]string, 0, len(cols))
|
||||
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -202,7 +218,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
}
|
||||
whereCols = append(whereCols, pkColumn)
|
||||
whereCols = []string{pkColumn}
|
||||
args = append(args, pkValue)
|
||||
}
|
||||
|
||||
@ -244,15 +260,72 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
}
|
||||
|
||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return d.InsertValue(q, mi, names, values)
|
||||
return d.InsertValue(q, mi, false, names, values)
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
|
||||
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
||||
var (
|
||||
cnt int64
|
||||
nums int
|
||||
values []interface{}
|
||||
names []string
|
||||
)
|
||||
|
||||
// typ := reflect.Indirect(mi.addrField).Type()
|
||||
|
||||
length := sind.Len()
|
||||
|
||||
for i := 1; i <= length; i++ {
|
||||
|
||||
ind := reflect.Indirect(sind.Index(i - 1))
|
||||
|
||||
// Is this needed ?
|
||||
// if !ind.Type().AssignableTo(typ) {
|
||||
// return cnt, ErrArgs
|
||||
// }
|
||||
|
||||
if i == 1 {
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
values = make([]interface{}, bulk*len(vus))
|
||||
nums += copy(values, vus)
|
||||
|
||||
} else {
|
||||
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
if len(vus) != len(names) {
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
|
||||
nums += copy(values[nums:], vus)
|
||||
}
|
||||
|
||||
if i > 1 && i%bulk == 0 || length == i {
|
||||
num, err := d.InsertValue(q, mi, true, names, values[:nums])
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
cnt += num
|
||||
nums = 0
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
marks := make([]string, len(names))
|
||||
@ -264,21 +337,30 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
|
||||
qmarks := strings.Join(marks, ", ")
|
||||
columns := strings.Join(names, sep)
|
||||
|
||||
multi := len(values) / len(names)
|
||||
|
||||
if isMulti {
|
||||
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
if d.ins.HasReturningID(mi, &query) {
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
} else {
|
||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||
if res, err := q.Exec(query, values...); err == nil {
|
||||
if isMulti {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
return res.LastInsertId()
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -288,12 +370,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
|
||||
var setNames []string
|
||||
|
||||
// if specify cols length is zero, then commit all columns.
|
||||
if len(cols) == 0 {
|
||||
cols = mi.fields.dbcols
|
||||
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
} else {
|
||||
setNames = make([]string, 0, len(cols))
|
||||
}
|
||||
|
||||
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
||||
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user