diff --git a/orm/db_alias.go b/orm/db_alias.go index 51ce10f3..74b5ec7d 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -105,17 +105,61 @@ func (ac *_dbCache) getDefault() (al *alias) { } type DB struct { + inTx bool + tx *sql.Tx *sync.RWMutex DB *sql.DB stmts map[string]*sql.Stmt } func (d *DB) Begin() (*sql.Tx, error) { - return d.DB.Begin() + if d.inTx { + return nil, ErrTxHasBegan + } + tx, err := d.DB.Begin() + if err != nil { + return nil, err + } + d.inTx = true + d.tx = tx + return d.tx, nil } func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - return d.DB.BeginTx(ctx, opts) + if d.inTx { + return nil, ErrTxHasBegan + } + tx, err := d.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + d.inTx = true + d.tx = tx + return d.tx, nil +} + +func (d *DB) Commit() error { + if !d.inTx { + return ErrTxDone + } + err := d.tx.Commit() + if err != nil { + return err + } + d.inTx = false + return nil +} + +func (d *DB) Rollback() error { + if !d.inTx { + return ErrTxDone + } + err := d.tx.Commit() + if err != nil { + return err + } + d.inTx = false + return nil } func (d *DB) getStmt(query string) (*sql.Stmt, error) {