1
0
mirror of https://github.com/astaxie/beego.git synced 2024-10-31 23:20:54 +00:00
Beego/orm/db_alias.go

398 lines
9.0 KiB
Go
Raw Normal View History

2014-08-18 08:41:43 +00:00
// Copyright 2014 beego Author. All Rights Reserved.
2014-07-03 15:40:21 +00:00
//
2014-08-18 08:41:43 +00:00
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
2014-07-03 15:40:21 +00:00
//
2014-08-18 08:41:43 +00:00
// http://www.apache.org/licenses/LICENSE-2.0
2014-07-03 15:40:21 +00:00
//
2014-08-18 08:41:43 +00:00
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
2013-07-30 12:32:38 +00:00
package orm
import (
2019-06-08 15:53:42 +00:00
"context"
2013-07-30 12:32:38 +00:00
"database/sql"
"fmt"
"reflect"
2013-07-30 12:32:38 +00:00
"sync"
"time"
2013-07-30 12:32:38 +00:00
)
2015-09-12 13:46:43 +00:00
// DriverType identifies the database dialect spoken by a registered driver.
type DriverType int

// Supported database dialects. Numbering starts at 1, so the zero value of
// DriverType matches none of them.
const (
	DRMySQL    DriverType = iota + 1 // mysql
	DRSqlite                         // sqlite
	DROracle                         // oracle
	DRPostgres                       // pgsql
	DRTiDB                           // TiDB
)
2014-01-17 09:04:15 +00:00
// database driver string.
2013-08-07 11:11:44 +00:00
type driver string
2014-01-17 09:04:15 +00:00
// get type constant int of current driver..
2013-08-07 11:11:44 +00:00
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
2014-01-17 09:04:15 +00:00
// get name of current driver
2013-08-07 11:11:44 +00:00
func (d driver) Name() string {
return string(d)
}
2014-01-17 09:04:15 +00:00
// check driver iis implemented Driver interface or not.
2013-08-07 11:11:44 +00:00
var _ Driver = new(driver)
2013-07-30 12:32:38 +00:00
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
2013-08-07 11:11:44 +00:00
drivers = map[string]DriverType{
2015-09-12 13:46:43 +00:00
"mysql": DRMySQL,
"postgres": DRPostgres,
"sqlite3": DRSqlite,
2015-09-10 08:31:53 +00:00
"tidb": DRTiDB,
2016-03-10 13:23:13 +00:00
"oracle": DROracle,
2017-04-24 13:36:07 +00:00
"oci8": DROracle, // github.com/mattn/go-oci8
"ora": DROracle, //https://github.com/rana/ora
2013-08-01 01:23:32 +00:00
}
2013-08-07 11:11:44 +00:00
dbBasers = map[DriverType]dbBaser{
2015-09-12 13:46:43 +00:00
DRMySQL: newdbBaseMysql(),
DRSqlite: newdbBaseSqlite(),
2015-09-17 15:47:26 +00:00
DROracle: newdbBaseOracle(),
2015-09-12 13:46:43 +00:00
DRPostgres: newdbBasePostgres(),
2015-09-10 08:31:53 +00:00
DRTiDB: newdbBaseTidb(),
2013-07-30 12:32:38 +00:00
}
)
2014-01-17 09:04:15 +00:00
// database alias cacher.
2013-07-30 12:32:38 +00:00
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
2014-01-17 09:04:15 +00:00
// add database alias with original name.
2013-07-30 12:32:38 +00:00
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
2016-09-01 15:28:34 +00:00
if _, ok := ac.cache[name]; !ok {
2013-07-30 12:32:38 +00:00
ac.cache[name] = al
added = true
}
return
}
2014-01-17 09:04:15 +00:00
// get database alias if cached.
2013-07-30 12:32:38 +00:00
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
al, ok = ac.cache[name]
return
}
2014-01-17 09:04:15 +00:00
// get default alias.
2013-07-30 12:32:38 +00:00
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
}
2019-06-08 15:53:42 +00:00
// DB wraps a *sql.DB and caches one prepared statement per query string, so
// repeated Exec/Query/QueryRow calls (and their Context variants) with the
// same SQL reuse a single *sql.Stmt instead of preparing it every time.
//
// The embedded RWMutex guards stmts. Begin, BeginTx, Prepare and
// PrepareContext bypass the cache and go straight to the wrapped *sql.DB.
//
// NOTE(review): cached statements are never evicted or closed, so stmts grows
// with every distinct query string for the lifetime of the DB.
type DB struct {
	*sync.RWMutex
	DB    *sql.DB
	stmts map[string]*sql.Stmt
}

// Begin starts a transaction on the wrapped *sql.DB.
func (d *DB) Begin() (*sql.Tx, error) {
	return d.DB.Begin()
}

// BeginTx starts a transaction on the wrapped *sql.DB using ctx and opts.
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	return d.DB.BeginTx(ctx, opts)
}

// getStmt returns the cached prepared statement for query, preparing and
// caching it on first use. A failed preparation is not cached.
//
// Preparation runs without holding the lock, so two goroutines may race to
// prepare the same query. The write-locked re-check below keeps the first
// cached statement and closes the duplicate, so no *sql.Stmt is left open
// and unreferenced.
func (d *DB) getStmt(query string) (*sql.Stmt, error) {
	d.RLock()
	stmt, ok := d.stmts[query]
	d.RUnlock()
	if ok {
		return stmt, nil
	}

	stmt, err := d.Prepare(query)
	if err != nil {
		return nil, err
	}

	d.Lock()
	defer d.Unlock()
	if cached, ok := d.stmts[query]; ok {
		// Best effort: the duplicate was never handed out, and the caller
		// receives the cached statement either way.
		_ = stmt.Close()
		return cached, nil
	}
	d.stmts[query] = stmt
	return stmt, nil
}

// Prepare creates a new prepared statement on the wrapped *sql.DB. It does
// not use or populate the statement cache; the caller owns the result.
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
	return d.DB.Prepare(query)
}

// PrepareContext is Prepare with a context; it also bypasses the cache.
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return d.DB.PrepareContext(ctx, query)
}

// Exec executes query with args through the cached prepared statement.
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
	stmt, err := d.getStmt(query)
	if err != nil {
		return nil, err
	}
	return stmt.Exec(args...)
}

// ExecContext is Exec with a context. ctx applies to execution only;
// preparing a statement that is not cached yet does not observe ctx.
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	stmt, err := d.getStmt(query)
	if err != nil {
		return nil, err
	}
	return stmt.ExecContext(ctx, args...)
}

// Query runs query with args through the cached prepared statement.
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
	stmt, err := d.getStmt(query)
	if err != nil {
		return nil, err
	}
	return stmt.Query(args...)
}

// QueryContext is Query with a context. ctx applies to execution only;
// preparing a statement that is not cached yet does not observe ctx.
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	stmt, err := d.getStmt(query)
	if err != nil {
		return nil, err
	}
	return stmt.QueryContext(ctx, args...)
}

// QueryRow runs query with args through the cached prepared statement and
// returns at most one row.
//
// If the statement cannot be prepared, the query is handed to the wrapped
// *sql.DB instead, so the failure is reported by Row.Scan (the usual
// sql.DB.QueryRow contract) rather than by a panic.
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
	stmt, err := d.getStmt(query)
	if err != nil {
		return d.DB.QueryRow(query, args...)
	}
	return stmt.QueryRow(args...)
}

// QueryRowContext is QueryRow with a context. args are forwarded with
// args... so each value binds to its own placeholder.
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	stmt, err := d.getStmt(query)
	if err != nil {
		return d.DB.QueryRowContext(ctx, query, args...)
	}
	return stmt.QueryRowContext(ctx, args...)
}
2013-07-30 12:32:38 +00:00
// alias is everything registered for one database alias name: how it was
// opened, the wrapped connection pool, the dbBaser for its dialect, and the
// settings detected or configured for it.
type alias struct {
// Name is the alias name this entry is registered under in dataBaseCache.
Name string
// Driver is the DriverType looked up from drivers by DriverName.
Driver DriverType
// DriverName is the database/sql driver name given at registration.
DriverName string
// DataSource is the DSN passed to RegisterDataBase; it is not set when the
// alias is added through AddAliasWthDB.
DataSource string
// MaxIdleConns and MaxOpenConns record the last values passed to
// SetMaxIdleConns / SetMaxOpenConns (zero if never set).
MaxIdleConns int
MaxOpenConns int
2019-06-08 15:53:42 +00:00
// DB wraps the *sql.DB together with its prepared-statement cache.
DB *DB
// DbBaser is taken from dbBasers for Driver.
DbBaser dbBaser
// TZ is the timezone used for this database, set by detectTZ or
// SetDataBaseTZ; addAliasWthDB itself leaves it nil.
TZ *time.Location
// Engine is the default MySQL storage engine found by detectTZ ("INNODB"
// when nothing was read); it is not set for other drivers in this file.
Engine string
2013-07-30 12:32:38 +00:00
}
// detectTZ fills in al.TZ (and, for MySQL, al.Engine) by querying the
// database itself. It is best-effort: every Scan error is ignored, so a failed
// query leaves the scanned variable at its zero value and the fallbacks below
// apply. al.TZ starts as DefaultTimeLoc and keeps that value for the "sphinx"
// driver name and for driver types without a case here (e.g. DRTiDB).
func detectTZ(al *alias) {
// orm timezone system match database
// default use Local
al.TZ = DefaultTimeLoc
// No detection queries are issued for the sphinx driver.
if al.DriverName == "sphinx" {
return
}
switch al.Driver {
2015-09-12 13:46:43 +00:00
case DRMySQL:
2014-01-10 08:50:03 +00:00
// Expected to yield the server's UTC offset as [-]HH:MM:SS; the length and
// leading '-' checks below rely on that shape.
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
2014-01-10 08:50:03 +00:00
if len(tz) >= 8 {
// Prefix non-negative offsets with '+' so the value parses with the
// signed "-07:00:00" layout.
if tz[0] != '-' {
tz = "+" + tz
2013-12-21 13:00:29 +00:00
}
2014-01-10 08:50:03 +00:00
t, err := time.Parse("-07:00:00", tz)
if err == nil {
// NOTE(review): time.Parse returns Local only when the offset matches
// Local's; otherwise it fabricates a fixed-offset zone, which may be
// unnamed, in which case al.TZ is left unchanged here — confirm intent.
if t.Location().String() != "" {
al.TZ = t.Location()
}
2014-01-10 08:50:03 +00:00
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
// get default engine from current database
// ENGINE is scanned first, so engine is populated even if scanning
// TRANSACTIONS into tx fails. NOTE(review): TRANSACTIONS is presumably a
// YES/NO string, which may not convert to bool; tx is not used afterwards.
row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
var engine string
var tx bool
row.Scan(&engine, &tx)
if engine != "" {
al.Engine = engine
} else {
2014-07-04 08:57:37 +00:00
al.Engine = "INNODB"
}
2016-03-10 13:23:13 +00:00
// SQLite and Oracle are always treated as UTC; nothing is queried.
case DRSqlite, DROracle:
al.TZ = time.UTC
2015-09-12 13:46:43 +00:00
case DRPostgres:
// NOTE(review): if Scan fails, tz stays "" and time.LoadLocation("")
// returns UTC, so a failed query yields UTC rather than DefaultTimeLoc.
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
2014-01-10 08:50:03 +00:00
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
}
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
2019-06-08 15:53:42 +00:00
al.DB = &DB{
2019-06-08 17:19:17 +00:00
RWMutex: new(sync.RWMutex),
DB: db,
stmts: make(map[string]*sql.Stmt),
2019-06-08 15:53:42 +00:00
}
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
2017-03-17 17:24:45 +00:00
if !dataBaseCache.add(aliasName, al) {
2014-03-10 12:50:54 +00:00
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
}
2015-09-12 13:46:43 +00:00
// AddAliasWthDB registers an already-opened *sql.DB under aliasName, using
// the dialect registered for driverName. It returns an error if driverName is
// unknown, db fails Ping, or aliasName is already registered.
//
// After registration the alias's timezone (and, for MySQL, its default
// storage engine) is detected with detectTZ, exactly as RegisterDataBase
// does. Otherwise aliases added this way keep a nil TZ and an empty Engine.
// SetDataBaseTZ can still override the detected timezone afterwards.
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
	al, err := addAliasWthDB(aliasName, driverName, db)
	if err != nil {
		return err
	}
	detectTZ(al)
	return nil
}
2015-09-12 13:46:43 +00:00
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db)
if err != nil {
goto end
}
al.DataSource = dataSource
detectTZ(al)
for i, v := range params {
switch i {
case 0:
SetMaxIdleConns(al.Name, v)
case 1:
SetMaxOpenConns(al.Name, v)
}
}
2013-07-30 12:32:38 +00:00
end:
if err != nil {
if db != nil {
db.Close()
}
DebugLog.Println(err.Error())
2013-07-30 12:32:38 +00:00
}
return err
2013-07-30 12:32:38 +00:00
}
2015-09-12 13:46:43 +00:00
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
2017-03-17 17:24:45 +00:00
if t, ok := drivers[driverName]; !ok {
drivers[driverName] = typ
2013-07-30 12:32:38 +00:00
} else {
2013-07-31 14:11:22 +00:00
if t != typ {
2017-04-30 14:41:23 +00:00
return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
2013-07-31 14:11:22 +00:00
}
2013-07-30 12:32:38 +00:00
}
return nil
2013-07-30 12:32:38 +00:00
}
2015-09-12 13:46:43 +00:00
// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
} else {
2017-04-30 14:41:23 +00:00
return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
}
return nil
}
2015-09-12 13:46:43 +00:00
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
2019-06-08 15:53:42 +00:00
al.DB.DB.SetMaxIdleConns(maxIdleConns)
}
2015-09-12 13:46:43 +00:00
// SetMaxOpenConns limits the open connections of the *sql.DB registered
// under aliasName, and records the limit on the alias.
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
	al := getDbAlias(aliasName)
	al.MaxOpenConns = maxOpenConns
	// The limit must be applied to the wrapped *sql.DB (al.DB.DB). al.DB is
	// this package's *DB wrapper, which defines no SetMaxOpenConns method.
	// Looking the method up on the wrapper found nothing, so the limit was
	// silently never applied. The reflective call is retained from the
	// original Go 1.2-era shim; it is equivalent to calling
	// al.DB.DB.SetMaxOpenConns(maxOpenConns) directly.
	if fun := reflect.ValueOf(al.DB.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
		fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
	}
}
2014-03-10 12:50:54 +00:00
2015-09-12 13:46:43 +00:00
// GetDB Get *sql.DB from registered database by db alias name.
2014-03-10 12:50:54 +00:00
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
if len(aliasNames) > 0 {
name = aliasNames[0]
} else {
name = "default"
}
2015-09-12 13:46:43 +00:00
al, ok := dataBaseCache.get(name)
if ok {
2019-06-08 15:53:42 +00:00
return al.DB.DB, nil
2014-03-10 12:50:54 +00:00
}
2017-04-30 14:41:23 +00:00
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
2014-03-10 12:50:54 +00:00
}