From 32da446eb1d8785b70037d648ca74261cb20fd2f Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Sun, 19 Jul 2020 23:46:42 +0800 Subject: [PATCH 1/4] refactor orm --- pkg/orm/db_alias.go | 53 +++++++ pkg/orm/orm.go | 307 +++++++++++++++++++++++++--------------- pkg/orm/orm_object.go | 4 +- pkg/orm/orm_querym2m.go | 2 +- pkg/orm/orm_queryset.go | 4 +- pkg/orm/orm_raw.go | 4 +- pkg/orm/orm_test.go | 34 ++--- pkg/orm/types.go | 162 +++++++++++++-------- 8 files changed, 370 insertions(+), 200 deletions(-) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index bf6c350c..b2a72f56 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -111,6 +111,9 @@ type DB struct { stmtDecorators *lru.Cache } +var _ dbQuerier = new(DB) +var _ txer = new(DB) + func (d *DB) Begin() (*sql.Tx, error) { return d.DB.Begin() } @@ -220,6 +223,56 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac return stmt.QueryRowContext(ctx, args) } +type TxDB struct { + tx *sql.Tx +} + +var _ dbQuerier = new(TxDB) +var _ txEnder = new(TxDB) + +func (t *TxDB) Commit() error { + return t.tx.Commit() +} + +func (t *TxDB) Rollback() error { + return t.tx.Rollback() +} + +var _ dbQuerier = new(TxDB) +var _ txEnder = new(TxDB) + +func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { + return t.PrepareContext(context.Background(),query) +} + +func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return t.tx.PrepareContext(ctx, query) +} + +func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) { + return t.ExecContext(context.Background(), query, args...) +} + +func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return t.QueryContext(context.Background(),query,args...) +} + +func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { + return t.QueryRowContext(context.Background(),query,args...) +} + +func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + type alias struct { Name string Driver DriverType diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index bd2cb783..3db75751 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -62,6 +62,8 @@ import ( "reflect" "sync" "time" + + "github.com/astaxie/beego/logs" ) // DebugQueries define the debug @@ -76,8 +78,7 @@ var ( DefaultRowsLimit = -1 DefaultRelsDepth = 2 DefaultTimeLoc = time.Local - ErrTxHasBegan = errors.New(" transaction already begin") - ErrTxDone = errors.New(" transaction not begin") + ErrTxDone = errors.New(" transaction already done") ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") ErrStmtClosed = errors.New(" stmt already closed") @@ -91,16 +92,16 @@ type Params map[string]interface{} // ParamsList stores paramslist type ParamsList []interface{} -type orm struct { +type ormBase struct { alias *alias db dbQuerier - isTx bool } -var _ Ormer = new(orm) +var _ DQL = new(ormBase) +var _ DML = new(ormBase) // get model info and model reflect value -func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { +func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() @@ -115,7 +116,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect } // get field info from model info by given field name -func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { +func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) @@ -124,33 +125,42 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { } // read data to model -func (o *orm) Read(md interface{}, cols ...string) error { +func (o *ormBase) Read(md interface{}, cols ...string) error { + return o.ReadWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form -func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { +func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { + return o.ReadForUpdateWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist -func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { +func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) +} +func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create - id, err := o.Insert(md) - return (err == nil), id, err + id, err := o.InsertWithCtx(ctx, md) + return err == nil, id, err } id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { id = int64(vid.Uint()) } else if mi.fields.pk.rel { - return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) } else { id = vid.Int() } @@ -159,7 +169,10 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i } // insert model data to database -func (o *orm) Insert(md interface{}) (int64, error) { +func (o *ormBase) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} +func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { @@ -172,7 +185,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { } // set auto pk field -func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { +func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) @@ -183,7 +196,10 @@ func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { } // insert some models to database -func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { +func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) { + return o.InsertMultiWithCtx(context.Background(), bulk, mds) +} +func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { var cnt int64 sind := reflect.Indirect(reflect.ValueOf(mds)) @@ -218,7 +234,10 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { } // InsertOrUpdate data to database -func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { +func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) { + return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...) +} +func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { @@ -232,14 +251,20 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64 // update model to database. // cols set the columns those want to update. -func (o *orm) Update(md interface{}, cols ...string) (int64, error) { +func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { + return o.UpdateWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) } // delete model in database // cols shows the delete conditions values read from. default is pk -func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { +func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { + return o.DeleteWithCtx(context.Background(), md, cols...) +} +func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) if err != nil { @@ -252,7 +277,10 @@ func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { } // create a models to models queryer -func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { +func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { + return o.QueryM2MWithCtx(context.Background(), md, name) +} +func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -274,7 +302,10 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. -func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { +func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + return o.LoadRelatedWithCtx(context.Background(), md, name, args...) +} +func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) qs := qseter.(*querySet) @@ -341,14 +372,17 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int // qs := orm.QueryRelated(post,"Tag") // qs.All(&[]*Tag{}) // -func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { +func (o *ormBase) QueryRelated(md interface{}, name string) QuerySeter { + return o.QueryRelatedWithCtx(context.Background(), md, name) +} +func (o *ormBase) QueryRelatedWithCtx(ctx context.Context, md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) return qs } // get QuerySeter for related models to md model -func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -380,7 +414,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, } // get reverse relation QuerySeter -func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { +func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelReverseOne, RelReverseMany: default: @@ -401,7 +435,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS } // get relation QuerySeter -func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { +func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: default: @@ -423,7 +457,10 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), -func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { +func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +} +func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -442,11 +479,136 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { return } -// switch to another registered database driver by given name. -func (o *orm) Using(name string) error { - if o.isTx { - panic(fmt.Errorf(" transaction has been start, cannot change db")) +// return a raw query seter for raw sql string. +func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { + return o.RawWithCtx(context.Background(), query, args...) +} +func (o *ormBase) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +// return current using database Driver +func (o *ormBase) Driver() Driver { + return driver(o.alias.Name) +} + +// return sql.DBStats for current database +func (o *ormBase) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats } + return nil +} + +type orm struct { + ormBase +} + +var _ Ormer = new(orm) + +func (o *orm) Begin() (TxOrmer, error) { + return o.BeginWithCtx(context.Background()) +} + +func (o *orm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(ctx, nil) +} + +func (o *orm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(context.Background(), opts) +} + +func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + _txOrm := &txOrm{ + ormBase: ormBase{ + alias: o.alias, + db: &TxDB{tx: tx}, + }, + isClosed: false, + } + + var taskTxOrm TxOrmer = _txOrm + return taskTxOrm, nil +} + +func (o *orm) DoTx(task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtx(context.Background(), task) +} + +func (o *orm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(ctx, nil, task) +} + +func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(context.Background(), opts, task) +} + +func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { + _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + panicked := true + defer func() { + if panicked || err != nil { + e := _txOrm.Rollback() + logs.Error("rollback transaction failed: %v", e) + } else { + e := _txOrm.Commit() + logs.Error("commit transaction failed: %v", e) + } + }() + + var taskTxOrm = _txOrm + err = task(taskTxOrm) + panicked = false + return err +} + +type txOrm struct { + ormBase + isClosed bool + closeMutex sync.Mutex +} + +var _ TxOrmer = new(txOrm) + +func (t *txOrm) Commit() error { + t.closeMutex.Lock() + defer t.closeMutex.Unlock() + + if t.isClosed { + return ErrTxDone + } + t.isClosed = true + + return t.db.(txEnder).Commit() +} + +func (t *txOrm) Rollback() error { + t.closeMutex.Lock() + defer t.closeMutex.Unlock() + + if t.isClosed { + return ErrTxDone + } + t.isClosed = true + + return t.db.(txEnder).Rollback() +} + +// NewOrm create new orm +func NewOrm() Ormer { + BootStrap() // execute only once + + o := new(orm) + name := `default` if al, ok := dataBaseCache.get(name); ok { o.alias = al if Debug { @@ -455,92 +617,9 @@ func (o *orm) Using(name string) error { o.db = al.DB } } else { - return fmt.Errorf(" unknown db alias name `%s`", name) + panic(fmt.Errorf(" unknown db alias name `%s`", name)) } - return nil -} -// begin transaction -func (o *orm) Begin() error { - return o.BeginTx(context.Background(), nil) -} - -func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { - if o.isTx { - return ErrTxHasBegan - } - var tx *sql.Tx - tx, err := o.db.(txer).BeginTx(ctx, opts) - if err != nil { - return err - } - o.isTx = true - if Debug { - o.db.(*dbQueryLog).SetDB(tx) - } else { - o.db = tx - } - return nil -} - -// commit transaction -func (o *orm) Commit() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Commit() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// rollback transaction -func (o *orm) Rollback() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Rollback() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// return a raw query seter for raw sql string. -func (o *orm) Raw(query string, args ...interface{}) RawSeter { - return newRawSet(o, query, args) -} - -// return current using database Driver -func (o *orm) Driver() Driver { - return driver(o.alias.Name) -} - -// return sql.DBStats for current database -func (o *orm) DBStats() *sql.DBStats { - if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.DB.Stats() - return &stats - } - return nil -} - -// NewOrm create new orm -func NewOrm() Ormer { - BootStrap() // execute only once - - o := new(orm) - err := o.Using("default") - if err != nil { - panic(err) - } return o } diff --git a/pkg/orm/orm_object.go b/pkg/orm/orm_object.go index de3181ce..6f9798d3 100644 --- a/pkg/orm/orm_object.go +++ b/pkg/orm/orm_object.go @@ -22,7 +22,7 @@ import ( // an insert queryer struct type insertSet struct { mi *modelInfo - orm *orm + orm *ormBase stmt stmtQuerier closed bool } @@ -70,7 +70,7 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { +func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi diff --git a/pkg/orm/orm_querym2m.go b/pkg/orm/orm_querym2m.go index 6a270a0d..17e1b5d1 100644 --- a/pkg/orm/orm_querym2m.go +++ b/pkg/orm/orm_querym2m.go @@ -129,7 +129,7 @@ func (o *queryM2M) Count() (int64, error) { var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md qm2m.mi = mi diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go index 878b836b..83168de7 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/orm/orm_queryset.go @@ -72,7 +72,7 @@ type querySet struct { orders []string distinct bool forupdate bool - orm *orm + orm *ormBase ctx context.Context forContext bool } @@ -292,7 +292,7 @@ func (o querySet) WithContext(ctx context.Context) QuerySeter { } // create new QuerySeter. -func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { +func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) o.mi = mi o.orm = orm diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go index 3325a7ea..5e05eded 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/orm/orm_raw.go @@ -63,7 +63,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { type rawSet struct { query string args []interface{} - orm *orm + orm *ormBase } var _ RawSeter = new(rawSet) @@ -858,7 +858,7 @@ func (o *rawSet) Prepare() (RawPreparer, error) { return newRawPreparer(o) } -func newRawSet(orm *orm, query string, args []interface{}) RawSeter { +func newRawSet(orm *ormBase, query string, args []interface{}) RawSeter { o := new(rawSet) o.query = query o.args = args diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index eac7b33a..b7b2d9a7 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2026,24 +2026,24 @@ func TestTransaction(t *testing.T) { // this test worked when database support transaction o := NewOrm() - err := o.Begin() + to, err := o.Begin() throwFail(t, err) var names = []string{"1", "2", "3"} var tag Tag tag.Name = names[0] - id, err := o.Insert(&tag) + id, err := to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + num, err := to.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) throwFail(t, err) throwFail(t, AssertIs(num, 1)) switch { case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + res, err := to.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) if err == nil { id, err = res.LastInsertId() @@ -2052,22 +2052,22 @@ func TestTransaction(t *testing.T) { } } - err = o.Rollback() + err = to.Rollback() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) - err = o.Begin() + to, err = o.Begin() throwFail(t, err) tag.Name = "commit" - id, err = o.Insert(&tag) + id, err = to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - o.Commit() + to.Commit() throwFail(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() @@ -2086,15 +2086,15 @@ func TestTransactionIsolationLevel(t *testing.T) { o2 := NewOrm() // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to1, err := o1.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to2, err := o2.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) // o1 insert tag var tag Tag tag.Name = "test-transaction" - id, err := o1.Insert(&tag) + id, err := to1.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) @@ -2104,15 +2104,15 @@ func TestTransactionIsolationLevel(t *testing.T) { throwFail(t, AssertIs(num, 0)) // o1 commit - o1.Commit() + to1.Commit() // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err = to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) // o2 commit and query tag table, get the result - o2.Commit() + to2.Commit() num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -2125,14 +2125,14 @@ func TestTransactionIsolationLevel(t *testing.T) { func TestBeginTxWithContextCanceled(t *testing.T) { o := NewOrm() ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) + to, _ := o.BeginWithCtx(ctx) + id, err := to.Insert(&Tag{Name: "test-context"}) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) // cancel the context before commit to make it error cancel() - err = o.Commit() + err = to.Commit() throwFail(t, AssertIs(err, context.Canceled)) } diff --git a/pkg/orm/types.go b/pkg/orm/types.go index 2fd10774..b7a38826 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -35,35 +35,43 @@ type Fielder interface { RawValue() interface{} } -// Ormer define the orm interface -type Ormer interface { - // read data to model - // for example: - // this will find User by Id field - // u = &User{Id: user.Id} - // err = Ormer.Read(u) - // this will find User by UserName field - // u = &User{UserName: "astaxie", Password: "pass"} - // err = Ormer.Read(u, "UserName") - Read(md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. - // Some databases are not support this feature. - ReadForUpdate(md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist - ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) +type TxBeginner interface { + //self control transaction + Begin() (TxOrmer, error) + BeginWithCtx(ctx context.Context) (TxOrmer, error) + BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) + BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) + + //closure control transaction + DoTx(task func(txOrm TxOrmer) error) error + DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error + DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error + DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error +} + +type TxCommitter interface { + Commit() error + Rollback() error +} + +//Data Manipulation Language +type DML interface { // insert model data to database // for example: // user := new(User) // id, err = Ormer.Insert(user) // user must be a pointer and Insert will set user's pk field - Insert(interface{}) (int64, error) + Insert(md interface{}) (int64, error) + InsertWithCtx(ctx context.Context, md interface{}) (int64, error) // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // if colu type is integer : can use(+-*/), string : convert(colu,"value") // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") // if colu type is integer : can use(+-*/), string : colu || "value" InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) // insert some models to database InsertMulti(bulk int, mds interface{}) (int64, error) + InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) // update model to database. // cols set the columns those want to update. // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns @@ -74,63 +82,93 @@ type Ormer interface { // user.Extra.Data = "orm" // num, err = Ormer.Update(&user, "Langs", "Extra") Update(md interface{}, cols ...string) (int64, error) + UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) // delete model in database Delete(md interface{}, cols ...string) (int64, error) + DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) + + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter +} + +// Data Query Language +type DQL interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate( md interface{}, cols ...string) error + ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) + // load related models to md model. // args are limit, offset int and order string. // // example: // Ormer.LoadRelated(post,"Tags") // for _,tag := range post.Tags{...} - //args[0] bool true useDefaultRelsDepth ; false depth 0 - //args[0] int loadRelationDepth - //args[1] int limit default limit 1000 - //args[2] int offset default offset 0 - //args[3] string order for example : "-Id" + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated( md interface{}, name string, args ...interface{}) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer // for example: // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M(md interface{}, name string) QueryM2Mer + QueryM2M( md interface{}, name string) QueryM2Mer + QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. - Using(name string) error - // begin transaction - // for example: - // o := NewOrm() - // err := o.Begin() - // ... - // err = o.Rollback() - Begin() error - // begin transaction with provided context and option - // the provided context is used until the transaction is committed or rolled back. - // if the context is canceled, the transaction will be rolled back. - // the provided TxOptions is optional and may be nil if defaults should be used. - // if a non-default isolation level is used that the driver doesn't support, an error will be returned. - // for example: - // o := NewOrm() - // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - // ... - // err = o.Rollback() - BeginTx(ctx context.Context, opts *sql.TxOptions) error - // commit transaction - Commit() error - // rollback transaction - Rollback() error - // return a raw query seter for raw sql string. - // for example: - // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() - // // update user testing's name to slene - Raw(query string, args ...interface{}) RawSeter - Driver() Driver + // Using(name string) error + DBStats() *sql.DBStats } +type DriverGetter interface { + Driver() Driver +} + +type Ormer interface { + DQL + DML + DriverGetter + TxBeginner +} + +type TxOrmer interface { + DQL + DML + DriverGetter + TxCommitter +} + // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) @@ -229,7 +267,7 @@ type QuerySeter interface { // }) // user slene's name will change to slene2 Update(values Params) (int64, error) // delete from table - //for example: + // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) @@ -314,8 +352,8 @@ type QueryM2Mer interface { // remove models following the origin model relationship // only delete rows from m2m table // for example: - //tag3 := &Tag{Id:5,Name: "TestTag3"} - //num, err = m2m.Remove(tag3) + // tag3 := &Tag{Id:5,Name: "TestTag3"} + // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool @@ -337,10 +375,10 @@ type RawPreparer interface { // sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) // rs := Ormer.Raw(sql, 1) type RawSeter interface { - //execute sql and get result + // execute sql and get result Exec() (sql.Result, error) - //query data and map to container - //for example: + // query data and map to container + // for example: // var name string // var id int // rs.QueryRow(&id,&name) // id==2 name=="slene" @@ -396,11 +434,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - //QueryContext(args ...interface{}) (*sql.Rows, error) + // QueryContext(args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier From aefe21b63af934b346960b4a99d560936e2b0b49 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 17:25:27 +0800 Subject: [PATCH 2/4] complete error log --- pkg/orm/orm.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 3db75751..24632270 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -558,10 +558,14 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task defer func() { if panicked || err != nil { e := _txOrm.Rollback() - logs.Error("rollback transaction failed: %v", e) + if e != nil { + logs.Error("rollback transaction failed: %v,%v", e, panicked) + } } else { e := _txOrm.Commit() - logs.Error("commit transaction failed: %v", e) + if e != nil { + logs.Error("commit transaction failed: %v,%v", e, panicked) + } } }() From 4aad313de7fbf4da9fd74e89d1e722f2702a29b2 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 17:34:58 +0800 Subject: [PATCH 3/4] do not judge tx status in txOrm --- pkg/orm/orm.go | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go index 24632270..8ef761f4 100644 --- a/pkg/orm/orm.go +++ b/pkg/orm/orm.go @@ -530,7 +530,6 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO alias: o.alias, db: &TxDB{tx: tx}, }, - isClosed: false, } var taskTxOrm TxOrmer = _txOrm @@ -577,33 +576,15 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task type txOrm struct { ormBase - isClosed bool - closeMutex sync.Mutex } var _ TxOrmer = new(txOrm) func (t *txOrm) Commit() error { - t.closeMutex.Lock() - defer t.closeMutex.Unlock() - - if t.isClosed { - return ErrTxDone - } - t.isClosed = true - return t.db.(txEnder).Commit() } func (t *txOrm) Rollback() error { - t.closeMutex.Lock() - defer t.closeMutex.Unlock() - - if t.isClosed { - return ErrTxDone - } - t.isClosed = true - return t.db.(txEnder).Rollback() } From b6f7d30f9f6192d6046e6b9db0f6a4fd261a829e Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 20 Jul 2020 19:10:57 +0800 Subject: [PATCH 4/4] fix unit test --- pkg/orm/orm_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go index b7b2d9a7..54ecc0fd 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/orm/orm_test.go @@ -2099,7 +2099,7 @@ func TestTransactionIsolationLevel(t *testing.T) { throwFail(t, AssertIs(id > 0, true)) // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err := to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0))