diff --git a/orm/db_alias.go b/orm/db_alias.go index 74b5ec7d..51ce10f3 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -105,61 +105,17 @@ func (ac *_dbCache) getDefault() (al *alias) { } type DB struct { - inTx bool - tx *sql.Tx *sync.RWMutex DB *sql.DB stmts map[string]*sql.Stmt } func (d *DB) Begin() (*sql.Tx, error) { - if d.inTx { - return nil, ErrTxHasBegan - } - tx, err := d.DB.Begin() - if err != nil { - return nil, err - } - d.inTx = true - d.tx = tx - return d.tx, nil + return d.DB.Begin() } func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - if d.inTx { - return nil, ErrTxHasBegan - } - tx, err := d.DB.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - d.inTx = true - d.tx = tx - return d.tx, nil -} - -func (d *DB) Commit() error { - if !d.inTx { - return ErrTxDone - } - err := d.tx.Commit() - if err != nil { - return err - } - d.inTx = false - return nil -} - -func (d *DB) Rollback() error { - if !d.inTx { - return ErrTxDone - } - err := d.tx.Commit() - if err != nil { - return err - } - d.inTx = false - return nil + return d.DB.BeginTx(ctx, opts) } func (d *DB) getStmt(query string) (*sql.Stmt, error) {