From b766f65c268c3d55aeea4a3380b651660c0029e4 Mon Sep 17 00:00:00 2001 From: slene Date: Mon, 6 Jan 2014 11:07:03 +0800 Subject: [PATCH] #436 support insert multi --- orm/db.go | 119 ++++++++++++++++++++++++++++++++++++++------ orm/db_tables.go | 2 - orm/orm.go | 81 +++++++++++++++++++++--------- orm/orm_querym2m.go | 12 ++--- orm/types.go | 4 +- 5 files changed, 167 insertions(+), 51 deletions(-) diff --git a/orm/db.go b/orm/db.go index 66ae498e..c6e92ec9 100644 --- a/orm/db.go +++ b/orm/db.go @@ -51,7 +51,13 @@ type dbBase struct { var _ dbBaser = new(dbBase) -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { + var columns []string + + if names != nil { + columns = *names + } + for _, column := range cols { var fi *fieldInfo if fi, _ = mi.fields.GetByAny(column); fi != nil { @@ -64,11 +70,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) if err != nil { - return nil, nil, err + return nil, err } - columns = append(columns, column) + + if names != nil { + columns = append(columns, column) + } + values = append(values, value) } + + if names != nil { + *names = columns + } + return } @@ -166,7 +181,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, } func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - _, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) + values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err } @@ -192,7 +207,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo // if specify cols length > 0, then use it for where condition. if len(cols) > 0 { var err error - whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) + whereCols = make([]string, 0, len(cols)) + args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) if err != nil { return err } @@ -202,7 +218,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo if ok == false { return ErrMissPK } - whereCols = append(whereCols, pkColumn) + whereCols = []string{pkColumn} args = append(args, pkValue) } @@ -244,15 +260,72 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) + names := make([]string, 0, len(mi.fields.dbcols)-1) + values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) if err != nil { return 0, err } - return d.InsertValue(q, mi, names, values) + return d.InsertValue(q, mi, false, names, values) } -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { + var ( + cnt int64 + nums int + values []interface{} + names []string + ) + + // typ := reflect.Indirect(mi.addrField).Type() + + length := sind.Len() + + for i := 1; i <= length; i++ { + + ind := reflect.Indirect(sind.Index(i - 1)) + + // Is this needed ? + // if !ind.Type().AssignableTo(typ) { + // return cnt, ErrArgs + // } + + if i == 1 { + vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + if err != nil { + return cnt, err + } + values = make([]interface{}, bulk*len(vus)) + nums += copy(values, vus) + + } else { + + vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + if err != nil { + return cnt, err + } + + if len(vus) != len(names) { + return cnt, ErrArgs + } + + nums += copy(values[nums:], vus) + } + + if i > 1 && i%bulk == 0 || length == i { + num, err := d.InsertValue(q, mi, true, names, values[:nums]) + if err != nil { + return cnt, err + } + cnt += num + nums = 0 + } + } + + return cnt, nil +} + +func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -264,21 +337,30 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values qmarks := strings.Join(marks, ", ") columns := strings.Join(names, sep) + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) - if d.ins.HasReturningID(mi, &query) { - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err - } else { + if isMulti || !d.ins.HasReturningID(mi, &query) { if res, err := q.Exec(query, values...); err == nil { + if isMulti { + return res.RowsAffected() + } return res.LastInsertId() } else { return 0, err } + } else { + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err } } @@ -288,12 +370,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, ErrMissPK } + var setNames []string + // if specify cols length is zero, then commit all columns. if len(cols) == 0 { cols = mi.fields.dbcols + setNames = make([]string, 0, len(mi.fields.dbcols)-1) + } else { + setNames = make([]string, 0, len(cols)) } - setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) + setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) if err != nil { return 0, err } diff --git a/orm/db_tables.go b/orm/db_tables.go index 5a78cf21..f5cacf38 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -214,8 +214,6 @@ loopFor: fi, ok = mmi.fields.GetByAny(ex) } - // fmt.Println(ex, fi.name, fiN) - _ = okN if ok { diff --git a/orm/orm.go b/orm/orm.go index 0069aa1d..9e3c3565 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -25,6 +25,7 @@ var ( ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") ErrNotImplement = errors.New("have not implement") ) @@ -39,11 +40,11 @@ type orm struct { var _ Ormer = new(orm) -func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { +func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() - if val.Kind() != reflect.Ptr { + if needPtr && val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } name := getFullName(typ) @@ -62,7 +63,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { } func (o *orm) Read(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return err @@ -71,25 +72,63 @@ func (o *orm) Read(md interface{}, cols ...string) error { } func (o *orm) Insert(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } - if id > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { - ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.Field(mi.fields.pk.fieldIndex).SetInt(id) - } - } - } + + o.setPk(mi, ind, id) + return id, nil } +func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.Field(mi.fields.pk.fieldIndex).SetInt(id) + } + } +} + +func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { + var cnt int64 + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + switch sind.Kind() { + case reflect.Array, reflect.Slice: + if sind.Len() == 0 { + return cnt, ErrArgs + } + default: + return cnt, ErrArgs + } + + if bulk <= 1 { + for i := 0; i < sind.Len(); i++ { + ind := sind.Index(i) + mi, _ := o.getMiInd(ind.Interface(), false) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return cnt, err + } + + o.setPk(mi, ind, id) + + cnt += 1 + } + } else { + mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + } + return cnt, nil +} + func (o *orm) Update(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err @@ -98,25 +137,19 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { } func (o *orm) Delete(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) if err != nil { return num, err } if num > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { - ind.Field(mi.fields.pk.fieldIndex).SetUint(0) - } else { - ind.Field(mi.fields.pk.fieldIndex).SetInt(0) - } - } + o.setPk(mi, ind, 0) } return num, nil } func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) switch { @@ -197,7 +230,7 @@ func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { } func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index 876fc37e..6f0544d0 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -44,7 +44,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { names := []string{mfi.column, rfi.column} - var nums int64 + values := make([]interface{}, 0, len(models)*2) + for _, md := range models { ind := reflect.Indirect(reflect.ValueOf(md)) @@ -59,16 +60,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } } - values := []interface{}{v1, v2} - _, err := dbase.InsertValue(orm.db, mi, names, values) - if err != nil { - return nums, err - } + values = append(values, v1, v2) - nums += 1 } - return nums, nil + return dbase.InsertValue(orm.db, mi, true, names, values) } func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { diff --git a/orm/types.go b/orm/types.go index 4749124c..a6487fc0 100644 --- a/orm/types.go +++ b/orm/types.go @@ -21,6 +21,7 @@ type Fielder interface { type Ormer interface { Read(interface{}, ...string) error Insert(interface{}) (int64, error) + InsertMulti(int, interface{}) (int64, error) Update(interface{}, ...string) (int64, error) Delete(interface{}) (int64, error) LoadRelated(interface{}, string, ...interface{}) (int64, error) @@ -109,7 +110,8 @@ type txEnder interface { type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) + InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)