From cc0eacbe023b95f74c240b35419c14722df45041 Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Sat, 8 Jun 2019 23:53:42 +0800 Subject: [PATCH 1/7] update --- orm/db_alias.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++--- orm/orm.go | 9 +++-- 2 files changed, 98 insertions(+), 6 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index a43e70e3..2f624da1 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -103,6 +104,89 @@ func (ac *_dbCache) getDefault() (al *alias) { return } +type DB struct { + *sync.RWMutex + DB *sql.DB + stmts map[string]*sql.Stmt +} + +func (d *DB) getStmt(query string) (*sql.Stmt, error) { + d.RLock() + if stmt, ok := d.stmts[query]; ok { + d.RUnlock() + return stmt, nil + } + + stmt, err := d.Prepare(query) + if err != nil { + return nil, err + } + d.Lock() + d.stmts[query] = stmt + d.Unlock() + return stmt, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.Exec(args) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.ExecContext(ctx, args) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.Query(args) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.QueryContext(ctx, args) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + stmt, err := d.getStmt(query) + if err != nil { + panic(err) + return nil + } + return stmt.QueryRow(args) + +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + + stmt, err := d.getStmt(query) + if err != nil { + panic(err) + return nil + } + return stmt.QueryRowContext(ctx, args) +} + type alias struct { Name string Driver DriverType @@ -110,7 +194,7 @@ type alias struct { DataSource string MaxIdleConns int MaxOpenConns int - DB *sql.DB + DB *DB DbBaser dbBaser TZ *time.Location Engine string @@ -176,7 +260,10 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = db + al.DB = &DB{ + DB: db, + stmts: make(map[string]*sql.Stmt), + } if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -272,7 +359,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { func SetMaxIdleConns(aliasName string, maxIdleConns int) { al := getDbAlias(aliasName) al.MaxIdleConns = maxIdleConns - al.DB.SetMaxIdleConns(maxIdleConns) + al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name @@ -296,7 +383,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } al, ok := dataBaseCache.get(name) if ok { - return al.DB, nil + return al.DB.DB, nil } return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) } diff --git a/orm/orm.go b/orm/orm.go index d322881b..0239428f 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -60,6 +60,7 @@ import ( "fmt" "os" "reflect" + "sync" "time" ) @@ -525,7 +526,7 @@ func (o *orm) Driver() Driver { // return sql.DBStats for current database func (o *orm) DBStats() *sql.DBStats { if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.Stats() + stats := o.alias.DB.DB.Stats() return &stats } @@ -558,7 +559,11 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { al.Name = aliasName al.DriverName = driverName - al.DB = db + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmts: make(map[string]*sql.Stmt), + } detectTZ(al) From 873f62edff047d41b2b53a86f1fb9963b40698db Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Sun, 9 Jun 2019 01:19:17 +0800 Subject: [PATCH 2/7] update --- orm/db_alias.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index 2f624da1..a581d82d 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -116,6 +116,7 @@ func (d *DB) getStmt(query string) (*sql.Stmt, error) { d.RUnlock() return stmt, nil } + d.RUnlock() stmt, err := d.Prepare(query) if err != nil { @@ -140,7 +141,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { if err != nil { return nil, err } - return stmt.Exec(args) + return stmt.Exec(args...) } func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { @@ -148,7 +149,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) if err != nil { return nil, err } - return stmt.ExecContext(ctx, args) + return stmt.ExecContext(ctx, args...) } func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { @@ -156,7 +157,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { if err != nil { return nil, err } - return stmt.Query(args) + return stmt.Query(args...) } func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { @@ -164,7 +165,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} if err != nil { return nil, err } - return stmt.QueryContext(ctx, args) + return stmt.QueryContext(ctx, args...) } func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { @@ -173,7 +174,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { panic(err) return nil } - return stmt.QueryRow(args) + return stmt.QueryRow(args...) } @@ -261,8 +262,9 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { al.Name = aliasName al.DriverName = driverName al.DB = &DB{ - DB: db, - stmts: make(map[string]*sql.Stmt), + RWMutex: new(sync.RWMutex), + DB: db, + stmts: make(map[string]*sql.Stmt), } if dr, ok := drivers[driverName]; ok { From 06692c3e27a93c09d6ac60ac1931ded2af1ba1b9 Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Mon, 17 Jun 2019 23:38:07 +0800 Subject: [PATCH 3/7] update --- orm/orm.go | 1 - 1 file changed, 1 deletion(-) diff --git a/orm/orm.go b/orm/orm.go index 0239428f..11e38fd9 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -529,7 +529,6 @@ func (o *orm) DBStats() *sql.DBStats { stats := o.alias.DB.DB.Stats() return &stats } - return nil } From 5d0c0a03d70d9b6fe32387e364e57d26f9161231 Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Fri, 28 Jun 2019 22:56:32 +0800 Subject: [PATCH 4/7] update --- orm/db_alias.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index a581d82d..9cb42748 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -172,7 +172,6 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { stmt, err := d.getStmt(query) if err != nil { panic(err) - return nil } return stmt.QueryRow(args...) @@ -183,7 +182,6 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac stmt, err := d.getStmt(query) if err != nil { panic(err) - return nil } return stmt.QueryRowContext(ctx, args) } From 40078cba2cc3e243801b731e9a50aeab20eff4a8 Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Fri, 28 Jun 2019 23:13:18 +0800 Subject: [PATCH 5/7] update --- orm/db_alias.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/orm/db_alias.go b/orm/db_alias.go index 9cb42748..51ce10f3 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -110,6 +110,14 @@ type DB struct { stmts map[string]*sql.Stmt } +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + func (d *DB) getStmt(query string) (*sql.Stmt, error) { d.RLock() if stmt, ok := d.stmts[query]; ok { From 2909ff336667627a34edd5c9625c1e6812a6a85f Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Fri, 28 Jun 2019 23:23:01 +0800 Subject: [PATCH 6/7] update --- orm/db_alias.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index 51ce10f3..74b5ec7d 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -105,17 +105,61 @@ func (ac *_dbCache) getDefault() (al *alias) { } type DB struct { + inTx bool + tx *sql.Tx *sync.RWMutex DB *sql.DB stmts map[string]*sql.Stmt } func (d *DB) Begin() (*sql.Tx, error) { - return d.DB.Begin() + if d.inTx { + return nil, ErrTxHasBegan + } + tx, err := d.DB.Begin() + if err != nil { + return nil, err + } + d.inTx = true + d.tx = tx + return d.tx, nil } func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - return d.DB.BeginTx(ctx, opts) + if d.inTx { + return nil, ErrTxHasBegan + } + tx, err := d.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + d.inTx = true + d.tx = tx + return d.tx, nil +} + +func (d *DB) Commit() error { + if !d.inTx { + return ErrTxDone + } + err := d.tx.Commit() + if err != nil { + return err + } + d.inTx = false + return nil +} + +func (d *DB) Rollback() error { + if !d.inTx { + return ErrTxDone + } + err := d.tx.Commit() + if err != nil { + return err + } + d.inTx = false + return nil } func (d *DB) getStmt(query string) (*sql.Stmt, error) { From 5bcde306ea05300ae36282eece1353ce2d82749f Mon Sep 17 00:00:00 2001 From: BaoyangChai Date: Fri, 28 Jun 2019 23:37:32 +0800 Subject: [PATCH 7/7] Revert "update" This reverts commit 2909ff336667627a34edd5c9625c1e6812a6a85f. --- orm/db_alias.go | 48 ++---------------------------------------------- 1 file changed, 2 insertions(+), 46 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index 74b5ec7d..51ce10f3 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -105,61 +105,17 @@ func (ac *_dbCache) getDefault() (al *alias) { } type DB struct { - inTx bool - tx *sql.Tx *sync.RWMutex DB *sql.DB stmts map[string]*sql.Stmt } func (d *DB) Begin() (*sql.Tx, error) { - if d.inTx { - return nil, ErrTxHasBegan - } - tx, err := d.DB.Begin() - if err != nil { - return nil, err - } - d.inTx = true - d.tx = tx - return d.tx, nil + return d.DB.Begin() } func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - if d.inTx { - return nil, ErrTxHasBegan - } - tx, err := d.DB.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - d.inTx = true - d.tx = tx - return d.tx, nil -} - -func (d *DB) Commit() error { - if !d.inTx { - return ErrTxDone - } - err := d.tx.Commit() - if err != nil { - return err - } - d.inTx = false - return nil -} - -func (d *DB) Rollback() error { - if !d.inTx { - return ErrTxDone - } - err := d.tx.Commit() - if err != nil { - return err - } - d.inTx = false - return nil + return d.DB.BeginTx(ctx, opts) } func (d *DB) getStmt(query string) (*sql.Stmt, error) {