diff --git a/orm/db_alias.go b/orm/db_alias.go index a43e70e3..51ce10f3 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -103,6 +104,96 @@ func (ac *_dbCache) getDefault() (al *alias) { return } +type DB struct { + *sync.RWMutex + DB *sql.DB + stmts map[string]*sql.Stmt +} + +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + +func (d *DB) getStmt(query string) (*sql.Stmt, error) { + d.RLock() + if stmt, ok := d.stmts[query]; ok { + d.RUnlock() + return stmt, nil + } + d.RUnlock() + + stmt, err := d.Prepare(query) + if err != nil { + return nil, err + } + d.Lock() + d.stmts[query] = stmt + d.Unlock() + return stmt, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.Exec(args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.ExecContext(ctx, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.Query(args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.QueryContext(ctx, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + stmt, err := d.getStmt(query) + if err != nil { + panic(err) + } + return stmt.QueryRow(args...) + +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + + stmt, err := d.getStmt(query) + if err != nil { + panic(err) + } + return stmt.QueryRowContext(ctx, args) +} + type alias struct { Name string Driver DriverType @@ -110,7 +201,7 @@ type alias struct { DataSource string MaxIdleConns int MaxOpenConns int - DB *sql.DB + DB *DB DbBaser dbBaser TZ *time.Location Engine string @@ -176,7 +267,11 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = db + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmts: make(map[string]*sql.Stmt), + } if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -272,7 +367,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { func SetMaxIdleConns(aliasName string, maxIdleConns int) { al := getDbAlias(aliasName) al.MaxIdleConns = maxIdleConns - al.DB.SetMaxIdleConns(maxIdleConns) + al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name @@ -296,7 +391,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } al, ok := dataBaseCache.get(name) if ok { - return al.DB, nil + return al.DB.DB, nil } return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) } diff --git a/orm/orm.go b/orm/orm.go index d322881b..11e38fd9 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -60,6 +60,7 @@ import ( "fmt" "os" "reflect" + "sync" "time" ) @@ -525,10 +526,9 @@ func (o *orm) Driver() Driver { // return sql.DBStats for current database func (o *orm) DBStats() *sql.DBStats { if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.Stats() + stats := o.alias.DB.DB.Stats() return &stats } - return nil } @@ -558,7 +558,11 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { al.Name = aliasName al.DriverName = driverName - al.DB = db + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmts: make(map[string]*sql.Stmt), + } detectTZ(al)