1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 20:51:29 +00:00

#436 support insert multi

This commit is contained in:
slene 2014-01-06 11:07:03 +08:00
parent 6f3a759ba5
commit b766f65c26
5 changed files with 167 additions and 51 deletions

117
orm/db.go
View File

@ -51,7 +51,13 @@ type dbBase struct {
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
var columns []string
if names != nil {
columns = *names
}
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil { if fi, _ = mi.fields.GetByAny(column); fi != nil {
@ -64,11 +70,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if names != nil {
columns = append(columns, column) columns = append(columns, column)
}
values = append(values, value) values = append(values, value)
} }
if names != nil {
*names = columns
}
return return
} }
@ -166,7 +181,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
} }
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -192,7 +207,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// if specify cols length > 0, then use it for where condition. // if specify cols length > 0, then use it for where condition.
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -202,7 +218,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if ok == false { if ok == false {
return ErrMissPK return ErrMissPK
} }
whereCols = append(whereCols, pkColumn) whereCols = []string{pkColumn}
args = append(args, pkValue) args = append(args, pkValue)
} }
@ -244,15 +260,72 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
} }
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) names := make([]string, 0, len(mi.fields.dbcols)-1)
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, names, values) return d.InsertValue(q, mi, false, names, values)
} }
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
values []interface{}
names []string
)
// typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len()
for i := 1; i <= length; i++ {
ind := reflect.Indirect(sind.Index(i - 1))
// Is this needed ?
// if !ind.Type().AssignableTo(typ) {
// return cnt, ErrArgs
// }
if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return cnt, err
}
values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus)
} else {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return cnt, err
}
if len(vus) != len(names) {
return cnt, ErrArgs
}
nums += copy(values[nums:], vus)
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
cnt += num
nums = 0
}
}
return cnt, nil
}
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -264,21 +337,30 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
qmarks := strings.Join(marks, ", ") qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep) columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := q.Exec(query, values...); err == nil { if res, err := q.Exec(query, values...); err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId() return res.LastInsertId()
} else { } else {
return 0, err return 0, err
} }
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} }
} }
@ -288,12 +370,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, ErrMissPK return 0, ErrMissPK
} }
var setNames []string
// if specify cols length is zero, then commit all columns. // if specify cols length is zero, then commit all columns.
if len(cols) == 0 { if len(cols) == 0 {
cols = mi.fields.dbcols cols = mi.fields.dbcols
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
} else {
setNames = make([]string, 0, len(cols))
} }
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -214,8 +214,6 @@ loopFor:
fi, ok = mmi.fields.GetByAny(ex) fi, ok = mmi.fields.GetByAny(ex)
} }
// fmt.Println(ex, fi.name, fiN)
_ = okN _ = okN
if ok { if ok {

View File

@ -25,6 +25,7 @@ var (
ErrMultiRows = errors.New("<QuerySeter> return multi rows") ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found") ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed") ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement") ErrNotImplement = errors.New("have not implement")
) )
@ -39,11 +40,11 @@ type orm struct {
var _ Ormer = new(orm) var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind = reflect.Indirect(val) ind = reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
if val.Kind() != reflect.Ptr { if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
} }
name := getFullName(typ) name := getFullName(typ)
@ -62,7 +63,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
} }
func (o *orm) Read(md interface{}, cols ...string) error { func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return err return err
@ -71,12 +72,18 @@ func (o *orm) Read(md interface{}, cols ...string) error {
} }
func (o *orm) Insert(md interface{}) (int64, error) { func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
if id > 0 {
o.setPk(mi, ind, id)
return id, nil
}
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
@ -84,12 +91,44 @@ func (o *orm) Insert(md interface{}) (int64, error) {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id) ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
} }
} }
}
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
} }
return id, nil default:
return cnt, ErrArgs
}
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i)
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt += 1
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
} }
func (o *orm) Update(md interface{}, cols ...string) (int64, error) { func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return num, err return num, err
@ -98,25 +137,19 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
} }
func (o *orm) Delete(md interface{}) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return num, err return num, err
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { o.setPk(mi, ind, 0)
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
} }
return num, nil return num, nil
} }
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
switch { switch {
@ -197,7 +230,7 @@ func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
} }
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind) _, _, exist := getExistPk(mi, ind)

View File

@ -44,7 +44,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column} names := []string{mfi.column, rfi.column}
var nums int64 values := make([]interface{}, 0, len(models)*2)
for _, md := range models { for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md)) ind := reflect.Indirect(reflect.ValueOf(md))
@ -59,16 +60,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
} }
} }
values := []interface{}{v1, v2} values = append(values, v1, v2)
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
} }
nums += 1 return dbase.InsertValue(orm.db, mi, true, names, values)
}
return nums, nil
} }
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {

View File

@ -21,6 +21,7 @@ type Fielder interface {
type Ormer interface { type Ormer interface {
Read(interface{}, ...string) error Read(interface{}, ...string) error
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error) Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error) LoadRelated(interface{}, string, ...interface{}) (int64, error)
@ -109,7 +110,8 @@ type txEnder interface {
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)