1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-26 05:11:31 +00:00
This commit is contained in:
BaoyangChai 2019-06-08 23:53:42 +08:00
parent 206a7ed1fc
commit cc0eacbe02
2 changed files with 98 additions and 6 deletions

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
@ -103,6 +104,89 @@ func (ac *_dbCache) getDefault() (al *alias) {
return return
} }
type DB struct {
*sync.RWMutex
DB *sql.DB
stmts map[string]*sql.Stmt
}
func (d *DB) getStmt(query string) (*sql.Stmt, error) {
d.RLock()
if stmt, ok := d.stmts[query]; ok {
d.RUnlock()
return stmt, nil
}
stmt, err := d.Prepare(query)
if err != nil {
return nil, err
}
d.Lock()
d.stmts[query] = stmt
d.Unlock()
return stmt, nil
}
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
return d.DB.Prepare(query)
}
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return d.DB.PrepareContext(ctx, query)
}
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.Exec(args)
}
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args)
}
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.Query(args)
}
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.QueryContext(ctx, args)
}
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query)
if err != nil {
panic(err)
return nil
}
return stmt.QueryRow(args)
}
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query)
if err != nil {
panic(err)
return nil
}
return stmt.QueryRowContext(ctx, args)
}
type alias struct { type alias struct {
Name string Name string
Driver DriverType Driver DriverType
@ -110,7 +194,7 @@ type alias struct {
DataSource string DataSource string
MaxIdleConns int MaxIdleConns int
MaxOpenConns int MaxOpenConns int
DB *sql.DB DB *DB
DbBaser dbBaser DbBaser dbBaser
TZ *time.Location TZ *time.Location
Engine string Engine string
@ -176,7 +260,10 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias) al := new(alias)
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = db al.DB = &DB{
DB: db,
stmts: make(map[string]*sql.Stmt),
}
if dr, ok := drivers[driverName]; ok { if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr] al.DbBaser = dbBasers[dr]
@ -272,7 +359,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
func SetMaxIdleConns(aliasName string, maxIdleConns int) { func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns) al.DB.DB.SetMaxIdleConns(maxIdleConns)
} }
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
@ -296,7 +383,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} }
al, ok := dataBaseCache.get(name) al, ok := dataBaseCache.get(name)
if ok { if ok {
return al.DB, nil return al.DB.DB, nil
} }
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
} }

View File

@ -60,6 +60,7 @@ import (
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"sync"
"time" "time"
) )
@ -525,7 +526,7 @@ func (o *orm) Driver() Driver {
// return sql.DBStats for current database // return sql.DBStats for current database
func (o *orm) DBStats() *sql.DBStats { func (o *orm) DBStats() *sql.DBStats {
if o.alias != nil && o.alias.DB != nil { if o.alias != nil && o.alias.DB != nil {
stats := o.alias.DB.Stats() stats := o.alias.DB.DB.Stats()
return &stats return &stats
} }
@ -558,7 +559,11 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = db al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmts: make(map[string]*sql.Stmt),
}
detectTZ(al) detectTZ(al)