diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 9ddcbea9..8117eaab 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -12,17 +12,6 @@ type dbIndex struct { Sql string } -func getDbAlias(name string) *alias { - if al, ok := dataBaseCache.get(name); ok { - return al - } else { - fmt.Println(fmt.Sprintf("unknown DataBase alias name %s", name)) - os.Exit(2) - } - - return nil -} - func getDbDropSql(al *alias) (sqls []string) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") @@ -180,7 +169,14 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex sql += "\n)" if al.Driver == DR_MySQL { - sql += " ENGINE=INNODB" + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine } sql += ";" diff --git a/orm/db_alias.go b/orm/db_alias.go index 34971a44..44f55681 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -4,11 +4,13 @@ import ( "database/sql" "fmt" "os" + "reflect" "sync" "time" ) -const defaultMaxIdle = 30 +const defaultMaxIdleConns = 30 +const defaultMaxOpenConns = 50 type DriverType int @@ -76,26 +78,36 @@ func (ac *_dbCache) getDefault() (al *alias) { } type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdle int - DB *sql.DB - DbBaser dbBaser - TZ *time.Location + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + DB *sql.DB + DbBaser dbBaser + TZ *time.Location + Engine string } -func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { - if maxIdle <= 0 { - maxIdle = defaultMaxIdle +// Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) { + maxIdleConns := defaultMaxIdleConns + maxOpenConns := defaultMaxOpenConns + + for i, v := range params { + switch i { + case 0: + maxIdleConns = v + case 1: + maxOpenConns = v + } } al := new(alias) - al.Name = name + al.Name = aliasName al.DriverName = driverName al.DataSource = dataSource - al.MaxIdle = maxIdle var ( err error @@ -109,19 +121,17 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { goto end } - if dataBaseCache.add(name, al) == false { - err = fmt.Errorf("db name `%s` already registered, cannot reuse", name) + if dataBaseCache.add(aliasName, al) == false { + err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName) goto end } al.DB, err = sql.Open(driverName, dataSource) if err != nil { - err = fmt.Errorf("register db `%s`, %s", name, err.Error()) + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) goto end } - al.DB.SetMaxIdleConns(al.MaxIdle) - // orm timezone system match database // default use Local al.TZ = time.Local @@ -137,8 +147,22 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { al.TZ = t.Location() } } + + // get default engine from current database + row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") + var engine string + var tx bool + row.Scan(&engine, &tx) + + if engine != "" { + al.Engine = engine + } else { + engine = "INNODB" + } + case DR_Sqlite: al.TZ = time.UTC + case DR_Postgres: row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") var tz string @@ -149,9 +173,12 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { } } + SetMaxIdleConns(al.Name, maxIdleConns) + SetMaxOpenConns(al.Name, maxOpenConns) + err = al.DB.Ping() if err != nil { - err = fmt.Errorf("register db `%s`, %s", name, err.Error()) + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) goto end } @@ -162,6 +189,7 @@ end: } } +// Register a database driver use specify driver name, this can be definition the driver is which database type. func RegisterDriver(driverName string, typ DriverType) { if t, ok := drivers[driverName]; ok == false { drivers[driverName] = typ @@ -173,11 +201,29 @@ func RegisterDriver(driverName string, typ DriverType) { } } -func SetDataBaseTZ(name string, tz *time.Location) { - if al, ok := dataBaseCache.get(name); ok { +// Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) { + if al, ok := dataBaseCache.get(aliasName); ok { al.TZ = tz } else { - fmt.Sprintf("DataBase name `%s` not registered\n", name) + fmt.Sprintf("DataBase name `%s` not registered\n", aliasName) os.Exit(2) } } + +// Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.MaxIdleConns = maxIdleConns + al.DB.SetMaxIdleConns(maxIdleConns) +} + +// Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.MaxOpenConns = maxOpenConns + // for tip go 1.2 + if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { + fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) + } +} diff --git a/orm/db_utils.go b/orm/db_utils.go index 4dcdaf18..af242516 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -6,6 +6,15 @@ import ( "time" ) +func getDbAlias(name string) *alias { + if al, ok := dataBaseCache.get(name); ok { + return al + } else { + panic(fmt.Errorf("unknown DataBase alias name %s", name)) + } + return nil +} + func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { fi := mi.fields.pk diff --git a/orm/models_test.go b/orm/models_test.go index 113dd14d..1e1420a9 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -223,4 +223,10 @@ go test -v github.com/astaxie/beego/orm } RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + alias := getDbAlias("default") + if alias.Driver == DR_MySQL { + alias.Engine = "INNODB" + } + } diff --git a/orm/models_utils.go b/orm/models_utils.go index 2dcbd646..76ee0b6f 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -26,6 +26,20 @@ func getTableName(val reflect.Value) string { return snakeString(ind.Type().Name()) } +func getTableEngine(val reflect.Value) string { + fun := val.MethodByName("TableEngine") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 { + val := vals[0] + if val.Kind() == reflect.String { + return val.String() + } + } + } + return "" +} + func getTableIndex(val reflect.Value) [][]string { fun := val.MethodByName("TableIndex") if fun.IsValid() {