diff --git a/orm/db.go b/orm/db.go index d2f8a5b2..1ed83145 100644 --- a/orm/db.go +++ b/orm/db.go @@ -634,18 +634,36 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - pkName, pkValue, ok := getExistPk(mi, ind) - if ok == false { - return 0, ErrMissPK +func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + var whereCols []string + var args []interface{} + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return 0, err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if ok == false { + return 0, ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) } Q := d.ins.TableQuote() - query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q) + sep := fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, pkValue) + res, err := q.Exec(query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { @@ -659,7 +677,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) } } - err := d.deleteRels(q, mi, []interface{}{pkValue}, tz) + err := d.deleteRels(q, mi, args, tz) if err != nil { return num, err } diff --git a/orm/orm.go b/orm/orm.go index 390d300f..f416fc08 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -234,9 +234,10 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { } // delete model in database -func (o *orm) Delete(md interface{}) (int64, error) { +// cols shows the delete conditions values read from. deafult is pk +func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) + num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ,cols) if err != nil { return num, err } diff --git a/orm/orm_test.go b/orm/orm_test.go index 5e288039..fbf4768d 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -577,6 +577,10 @@ func TestCRUD(t *testing.T) { err = dORM.Read(&ub) throwFail(t, err) throwFail(t, AssertIs(ub.Name, "name")) + + num, err = dORM.Delete(&ub, "name") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) } func TestInsertTestData(t *testing.T) { diff --git a/orm/types.go b/orm/types.go index 8c17271d..60a0f1df 100644 --- a/orm/types.go +++ b/orm/types.go @@ -71,7 +71,7 @@ type Ormer interface { // num, err = Ormer.Update(&user, "Langs", "Extra") Update(md interface{}, cols ...string) (int64, error) // delete model in database - Delete(md interface{}) (int64, error) + Delete(md interface{}, cols ...string) (int64, error) // load related models to md model. // args are limit, offset int and order string. // @@ -401,7 +401,7 @@ type dbBaser interface { InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) SupportUpdateJoin() bool UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)