mirror of
https://github.com/astaxie/beego.git
synced 2024-11-01 00:00:54 +00:00
555 lines
13 KiB
Go
555 lines
13 KiB
Go
// Copyright 2014 beego Author. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package orm
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
lru "github.com/hashicorp/golang-lru"
|
|
|
|
"github.com/astaxie/beego/pkg/common"
|
|
)
|
|
|
|
// DriverType identifies the SQL dialect family of a registered database.
type DriverType int

// Supported database driver families.
const (
	// Numbering starts at 1 so that the zero DriverType never names a real
	// driver.
	DRMySQL    DriverType = iota + 1 // mysql
	DRSqlite                         // sqlite
	DROracle                         // oracle
	DRPostgres                       // pgsql
	DRTiDB                           // TiDB
)
|
|
|
|
// database driver string.
|
|
type driver string
|
|
|
|
// get type constant int of current driver..
|
|
func (d driver) Type() DriverType {
|
|
a, _ := dataBaseCache.get(string(d))
|
|
return a.Driver
|
|
}
|
|
|
|
// get name of current driver
|
|
func (d driver) Name() string {
|
|
return string(d)
|
|
}
|
|
|
|
// check driver iis implemented Driver interface or not.
|
|
var _ Driver = new(driver)
|
|
|
|
var (
|
|
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
|
|
drivers = map[string]DriverType{
|
|
"mysql": DRMySQL,
|
|
"postgres": DRPostgres,
|
|
"sqlite3": DRSqlite,
|
|
"tidb": DRTiDB,
|
|
"oracle": DROracle,
|
|
"oci8": DROracle, // github.com/mattn/go-oci8
|
|
"ora": DROracle, // https://github.com/rana/ora
|
|
}
|
|
dbBasers = map[DriverType]dbBaser{
|
|
DRMySQL: newdbBaseMysql(),
|
|
DRSqlite: newdbBaseSqlite(),
|
|
DROracle: newdbBaseOracle(),
|
|
DRPostgres: newdbBasePostgres(),
|
|
DRTiDB: newdbBaseTidb(),
|
|
}
|
|
)
|
|
|
|
// database alias cacher.
|
|
type _dbCache struct {
|
|
mux sync.RWMutex
|
|
cache map[string]*alias
|
|
}
|
|
|
|
// add database alias with original name.
|
|
func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
|
ac.mux.Lock()
|
|
defer ac.mux.Unlock()
|
|
if _, ok := ac.cache[name]; !ok {
|
|
ac.cache[name] = al
|
|
added = true
|
|
}
|
|
return
|
|
}
|
|
|
|
// get database alias if cached.
|
|
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
|
ac.mux.RLock()
|
|
defer ac.mux.RUnlock()
|
|
al, ok = ac.cache[name]
|
|
return
|
|
}
|
|
|
|
// get default alias.
|
|
func (ac *_dbCache) getDefault() (al *alias) {
|
|
al, _ = ac.get("default")
|
|
return
|
|
}
|
|
|
|
type DB struct {
|
|
*sync.RWMutex
|
|
DB *sql.DB
|
|
stmtDecorators *lru.Cache
|
|
stmtDecoratorsLimit int
|
|
}
|
|
|
|
var _ dbQuerier = new(DB)
|
|
var _ txer = new(DB)
|
|
|
|
func (d *DB) Begin() (*sql.Tx, error) {
|
|
return d.DB.Begin()
|
|
}
|
|
|
|
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
|
return d.DB.BeginTx(ctx, opts)
|
|
}
|
|
|
|
// su must call release to release *sql.Stmt after using
|
|
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
|
|
d.RLock()
|
|
c, ok := d.stmtDecorators.Get(query)
|
|
if ok {
|
|
c.(*stmtDecorator).acquire()
|
|
d.RUnlock()
|
|
return c.(*stmtDecorator), nil
|
|
}
|
|
d.RUnlock()
|
|
|
|
d.Lock()
|
|
c, ok = d.stmtDecorators.Get(query)
|
|
if ok {
|
|
c.(*stmtDecorator).acquire()
|
|
d.Unlock()
|
|
return c.(*stmtDecorator), nil
|
|
}
|
|
|
|
stmt, err := d.Prepare(query)
|
|
if err != nil {
|
|
d.Unlock()
|
|
return nil, err
|
|
}
|
|
sd := newStmtDecorator(stmt)
|
|
sd.acquire()
|
|
d.stmtDecorators.Add(query, sd)
|
|
d.Unlock()
|
|
|
|
return sd, nil
|
|
}
|
|
|
|
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
|
|
return d.DB.Prepare(query)
|
|
}
|
|
|
|
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
|
return d.DB.PrepareContext(ctx, query)
|
|
}
|
|
|
|
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
return d.ExecContext(context.Background(), query, args...)
|
|
}
|
|
|
|
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
if d.stmtDecorators == nil {
|
|
return d.DB.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
sd, err := d.getStmtDecorator(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
stmt := sd.getStmt()
|
|
defer sd.release()
|
|
return stmt.ExecContext(ctx, args...)
|
|
}
|
|
|
|
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
return d.QueryContext(context.Background(), query, args...)
|
|
}
|
|
|
|
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
if d.stmtDecorators == nil {
|
|
return d.DB.QueryContext(ctx, query, args...)
|
|
}
|
|
|
|
sd, err := d.getStmtDecorator(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
stmt := sd.getStmt()
|
|
defer sd.release()
|
|
return stmt.QueryContext(ctx, args...)
|
|
}
|
|
|
|
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
|
|
return d.QueryRowContext(context.Background(), query, args...)
|
|
}
|
|
|
|
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
|
if d.stmtDecorators == nil {
|
|
return d.DB.QueryRowContext(ctx, query, args...)
|
|
}
|
|
|
|
sd, err := d.getStmtDecorator(query)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
stmt := sd.getStmt()
|
|
defer sd.release()
|
|
return stmt.QueryRowContext(ctx, args...)
|
|
}
|
|
|
|
type TxDB struct {
|
|
tx *sql.Tx
|
|
}
|
|
|
|
var _ dbQuerier = new(TxDB)
|
|
var _ txEnder = new(TxDB)
|
|
|
|
func (t *TxDB) Commit() error {
|
|
return t.tx.Commit()
|
|
}
|
|
|
|
func (t *TxDB) Rollback() error {
|
|
return t.tx.Rollback()
|
|
}
|
|
|
|
var _ dbQuerier = new(TxDB)
|
|
var _ txEnder = new(TxDB)
|
|
|
|
func (t *TxDB) Prepare(query string) (*sql.Stmt, error) {
|
|
return t.PrepareContext(context.Background(), query)
|
|
}
|
|
|
|
func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
|
return t.tx.PrepareContext(ctx, query)
|
|
}
|
|
|
|
func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
return t.ExecContext(context.Background(), query, args...)
|
|
}
|
|
|
|
func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
return t.tx.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
return t.QueryContext(context.Background(), query, args...)
|
|
}
|
|
|
|
func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
return t.tx.QueryContext(ctx, query, args...)
|
|
}
|
|
|
|
func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row {
|
|
return t.QueryRowContext(context.Background(), query, args...)
|
|
}
|
|
|
|
func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
|
return t.tx.QueryRowContext(ctx, query, args...)
|
|
}
|
|
|
|
type alias struct {
|
|
Name string
|
|
Driver DriverType
|
|
DriverName string
|
|
DataSource string
|
|
MaxIdleConns int
|
|
MaxOpenConns int
|
|
ConnMaxLifetime time.Duration
|
|
DB *DB
|
|
DbBaser dbBaser
|
|
TZ *time.Location
|
|
Engine string
|
|
}
|
|
|
|
func detectTZ(al *alias) {
|
|
// orm timezone system match database
|
|
// default use Local
|
|
al.TZ = DefaultTimeLoc
|
|
|
|
if al.DriverName == "sphinx" {
|
|
return
|
|
}
|
|
|
|
switch al.Driver {
|
|
case DRMySQL:
|
|
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
|
|
var tz string
|
|
row.Scan(&tz)
|
|
if len(tz) >= 8 {
|
|
if tz[0] != '-' {
|
|
tz = "+" + tz
|
|
}
|
|
t, err := time.Parse("-07:00:00", tz)
|
|
if err == nil {
|
|
if t.Location().String() != "" {
|
|
al.TZ = t.Location()
|
|
}
|
|
} else {
|
|
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
|
}
|
|
}
|
|
|
|
// get default engine from current database
|
|
row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
|
|
var engine string
|
|
var tx bool
|
|
row.Scan(&engine, &tx)
|
|
|
|
if engine != "" {
|
|
al.Engine = engine
|
|
} else {
|
|
al.Engine = "INNODB"
|
|
}
|
|
|
|
case DRSqlite, DROracle:
|
|
al.TZ = time.UTC
|
|
|
|
case DRPostgres:
|
|
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
|
|
var tz string
|
|
row.Scan(&tz)
|
|
loc, err := time.LoadLocation(tz)
|
|
if err == nil {
|
|
al.TZ = loc
|
|
} else {
|
|
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) {
|
|
existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
|
|
if _, ok := dataBaseCache.get(aliasName); ok {
|
|
return nil, existErr
|
|
}
|
|
|
|
al, err := newAliasWithDb(aliasName, driverName, db, params...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !dataBaseCache.add(aliasName, al) {
|
|
return nil, existErr
|
|
}
|
|
|
|
return al, nil
|
|
}
|
|
|
|
func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) {
|
|
kvs := common.NewKVs(params...)
|
|
|
|
var stmtCache *lru.Cache
|
|
var stmtCacheSize int
|
|
|
|
maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int)
|
|
if maxStmtCacheSize > 0 {
|
|
_stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize)
|
|
if errC != nil {
|
|
return nil, errC
|
|
} else {
|
|
stmtCache = _stmtCache
|
|
stmtCacheSize = maxStmtCacheSize
|
|
}
|
|
}
|
|
|
|
al := new(alias)
|
|
al.Name = aliasName
|
|
al.DriverName = driverName
|
|
al.DB = &DB{
|
|
RWMutex: new(sync.RWMutex),
|
|
DB: db,
|
|
stmtDecorators: stmtCache,
|
|
stmtDecoratorsLimit: stmtCacheSize,
|
|
}
|
|
|
|
if dr, ok := drivers[driverName]; ok {
|
|
al.DbBaser = dbBasers[dr]
|
|
al.Driver = dr
|
|
} else {
|
|
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
|
}
|
|
|
|
err := db.Ping()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
|
}
|
|
|
|
detectTZ(al)
|
|
|
|
kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) {
|
|
if m, ok := value.(int); ok {
|
|
SetMaxIdleConns(al, m)
|
|
}
|
|
}).IfContains(maxOpenConnectionsKey, func(value interface{}) {
|
|
if m, ok := value.(int); ok {
|
|
SetMaxOpenConns(al, m)
|
|
}
|
|
}).IfContains(connMaxLifetimeKey, func(value interface{}) {
|
|
if m, ok := value.(time.Duration); ok {
|
|
SetConnMaxLifetime(al, m)
|
|
}
|
|
})
|
|
|
|
return al, nil
|
|
}
|
|
|
|
// AddAliasWthDB add a aliasName for the drivename
|
|
func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) error {
|
|
_, err := addAliasWthDB(aliasName, driverName, db, params...)
|
|
return err
|
|
}
|
|
|
|
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
|
|
func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error {
|
|
var (
|
|
err error
|
|
db *sql.DB
|
|
al *alias
|
|
)
|
|
|
|
db, err = sql.Open(driverName, dataSource)
|
|
if err != nil {
|
|
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
|
goto end
|
|
}
|
|
|
|
al, err = addAliasWthDB(aliasName, driverName, db, hints...)
|
|
if err != nil {
|
|
goto end
|
|
}
|
|
|
|
al.DataSource = dataSource
|
|
|
|
end:
|
|
if err != nil {
|
|
if db != nil {
|
|
db.Close()
|
|
}
|
|
DebugLog.Println(err.Error())
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
|
|
func RegisterDriver(driverName string, typ DriverType) error {
|
|
if t, ok := drivers[driverName]; !ok {
|
|
drivers[driverName] = typ
|
|
} else {
|
|
if t != typ {
|
|
return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetDataBaseTZ Change the database default used timezone
|
|
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
|
if al, ok := dataBaseCache.get(aliasName); ok {
|
|
al.TZ = tz
|
|
} else {
|
|
return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
|
|
func SetMaxIdleConns(al *alias, maxIdleConns int) {
|
|
al.MaxIdleConns = maxIdleConns
|
|
al.DB.DB.SetMaxIdleConns(maxIdleConns)
|
|
}
|
|
|
|
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
|
|
func SetMaxOpenConns(al *alias, maxOpenConns int) {
|
|
al.MaxOpenConns = maxOpenConns
|
|
al.DB.DB.SetMaxOpenConns(maxOpenConns)
|
|
}
|
|
|
|
func SetConnMaxLifetime(al *alias, lifeTime time.Duration) {
|
|
al.ConnMaxLifetime = lifeTime
|
|
al.DB.DB.SetConnMaxLifetime(lifeTime)
|
|
}
|
|
|
|
// GetDB Get *sql.DB from registered database by db alias name.
|
|
// Use "default" as alias name if you not set.
|
|
func GetDB(aliasNames ...string) (*sql.DB, error) {
|
|
var name string
|
|
if len(aliasNames) > 0 {
|
|
name = aliasNames[0]
|
|
} else {
|
|
name = "default"
|
|
}
|
|
al, ok := dataBaseCache.get(name)
|
|
if ok {
|
|
return al.DB.DB, nil
|
|
}
|
|
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
|
|
}
|
|
|
|
// stmtDecorator pairs a cached prepared statement with a count of its
// in-flight users, so destroy closes the statement only after every user has
// released it.
type stmtDecorator struct {
	wg   sync.WaitGroup
	stmt *sql.Stmt
}

// getStmt returns the wrapped prepared statement.
func (s *stmtDecorator) getStmt() *sql.Stmt { return s.stmt }

// acquire registers one more user of the statement. Each acquire must be
// matched by exactly one release.
//
// Callers invoke this while holding the DB lock, so it does nothing beyond
// incrementing the counter.
func (s *stmtDecorator) acquire() { s.wg.Add(1) }

// release marks one user of the statement as finished.
func (s *stmtDecorator) release() { s.wg.Done() }

// destroy returns immediately. A background goroutine waits until all
// current users have released the statement, then closes it and discards the
// close error.
func (s *stmtDecorator) destroy() {
	waitThenClose := func() {
		s.wg.Wait()
		_ = s.stmt.Close()
	}
	go waitThenClose()
}

// newStmtDecorator wraps sqlStmt with a zero user count.
func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
	sd := new(stmtDecorator)
	sd.stmt = sqlStmt
	return sd
}
|
|
|
|
func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) {
|
|
cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) {
|
|
value.(*stmtDecorator).destroy()
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cache, nil
|
|
}
|