mirror of
https://github.com/astaxie/beego.git
synced 2024-12-23 02:50:49 +00:00
orm 1. add api: NewOrmWithDB, AddAliasWthDB; 2. RawSeter -> add api: RowsToMap, RowsToStruct; 3. RawSeter -> change api: Values, ValuesList, ValuesFlat add optional params comumns.
This commit is contained in:
parent
8296713ba4
commit
9384e87083
@ -1350,6 +1350,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// flag of update joined record.
|
||||
func (d *dbBase) SupportUpdateJoin() bool {
|
||||
return true
|
||||
|
126
orm/db_alias.go
126
orm/db_alias.go
@ -3,7 +3,6 @@ package orm
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
@ -13,11 +12,11 @@ import (
|
||||
type DriverType int
|
||||
|
||||
const (
|
||||
_ DriverType = iota // int enum type
|
||||
DR_MySQL // mysql
|
||||
DR_Sqlite // sqlite
|
||||
DR_Oracle // oracle
|
||||
DR_Postgres // pgsql
|
||||
_ DriverType = iota // int enum type
|
||||
DR_MySQL // mysql
|
||||
DR_Sqlite // sqlite
|
||||
DR_Oracle // oracle
|
||||
DR_Postgres // pgsql
|
||||
)
|
||||
|
||||
// database driver string.
|
||||
@ -96,40 +95,15 @@ type alias struct {
|
||||
Engine string
|
||||
}
|
||||
|
||||
// Setting the database connect params. Use the database driver self dataSource args.
|
||||
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
al := new(alias)
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
al.DataSource = dataSource
|
||||
|
||||
var (
|
||||
err error
|
||||
)
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
err = fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
goto end
|
||||
}
|
||||
|
||||
if dataBaseCache.add(aliasName, al) == false {
|
||||
err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
|
||||
goto end
|
||||
}
|
||||
|
||||
al.DB, err = sql.Open(driverName, dataSource)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
func detectTZ(al *alias) {
|
||||
// orm timezone system match database
|
||||
// default use Local
|
||||
al.TZ = time.Local
|
||||
|
||||
if al.DriverName == "sphinx" {
|
||||
return
|
||||
}
|
||||
|
||||
switch al.Driver {
|
||||
case DR_MySQL:
|
||||
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
|
||||
@ -173,6 +147,60 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
|
||||
al := new(alias)
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
al.DB = db
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
}
|
||||
|
||||
err := db.Ping()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
||||
}
|
||||
|
||||
if dataBaseCache.add(aliasName, al) == false {
|
||||
return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
|
||||
}
|
||||
|
||||
return al, nil
|
||||
}
|
||||
|
||||
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
|
||||
_, err := addAliasWthDB(aliasName, driverName, db)
|
||||
return err
|
||||
}
|
||||
|
||||
// Setting the database connect params. Use the database driver self dataSource args.
|
||||
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
|
||||
var (
|
||||
err error
|
||||
db *sql.DB
|
||||
al *alias
|
||||
)
|
||||
|
||||
db, err = sql.Open(driverName, dataSource)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
al, err = addAliasWthDB(aliasName, driverName, db)
|
||||
if err != nil {
|
||||
goto end
|
||||
}
|
||||
|
||||
al.DataSource = dataSource
|
||||
|
||||
detectTZ(al)
|
||||
|
||||
for i, v := range params {
|
||||
switch i {
|
||||
@ -183,39 +211,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
}
|
||||
}
|
||||
|
||||
err = al.DB.Ping()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
end:
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(2)
|
||||
if db != nil {
|
||||
db.Close()
|
||||
}
|
||||
DebugLog.Println(err.Error())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Register a database driver use specify driver name, this can be definition the driver is which database type.
|
||||
func RegisterDriver(driverName string, typ DriverType) {
|
||||
func RegisterDriver(driverName string, typ DriverType) error {
|
||||
if t, ok := drivers[driverName]; ok == false {
|
||||
drivers[driverName] = typ
|
||||
} else {
|
||||
if t != typ {
|
||||
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
|
||||
os.Exit(2)
|
||||
return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the database default used timezone
|
||||
func SetDataBaseTZ(aliasName string, tz *time.Location) {
|
||||
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||
al.TZ = tz
|
||||
} else {
|
||||
fmt.Sprintf("DataBase name `%s` not registered\n", aliasName)
|
||||
os.Exit(2)
|
||||
return fmt.Errorf("DataBase name `%s` not registered\n", aliasName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the max idle conns for *sql.DB, use specify database alias name
|
||||
|
33
orm/orm.go
33
orm/orm.go
@ -439,6 +439,12 @@ func (o *orm) Driver() Driver {
|
||||
return driver(o.alias.Name)
|
||||
}
|
||||
|
||||
func (o *orm) GetDB() dbQuerier {
|
||||
panic(ErrNotImplement)
|
||||
// not enough
|
||||
return o.db
|
||||
}
|
||||
|
||||
// create new orm
|
||||
func NewOrm() Ormer {
|
||||
BootStrap() // execute only once
|
||||
@ -450,3 +456,30 @@ func NewOrm() Ormer {
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// create a new ormer object with specify *sql.DB for query
|
||||
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
|
||||
var al *alias
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al = new(alias)
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
}
|
||||
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
|
||||
o := new(orm)
|
||||
o.alias = al
|
||||
|
||||
if Debug {
|
||||
o.db = newDbQueryLog(o.alias, db)
|
||||
} else {
|
||||
o.db = db
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
@ -197,6 +197,36 @@ func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||
panic(ErrNotImplement)
|
||||
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, result, keyCol, valueCol, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||
panic(ErrNotImplement)
|
||||
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, ptrStruct, keyCol, valueCol, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// create new QuerySeter.
|
||||
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
|
||||
o := new(querySet)
|
||||
|
191
orm/orm_raw.go
191
orm/orm_raw.go
@ -518,7 +518,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
|
||||
var (
|
||||
maps []Params
|
||||
lists []ParamsList
|
||||
@ -552,20 +552,38 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
indexs []int
|
||||
)
|
||||
|
||||
for rs.Next() {
|
||||
if cnt == 0 {
|
||||
if columns, err := rs.Columns(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
if len(needCols) > 0 {
|
||||
indexs = make([]int, 0, len(needCols))
|
||||
} else {
|
||||
indexs = make([]int, 0, len(columns))
|
||||
}
|
||||
|
||||
cols = columns
|
||||
refs = make([]interface{}, len(cols))
|
||||
for i, _ := range refs {
|
||||
var ref sql.NullString
|
||||
refs[i] = &ref
|
||||
|
||||
if len(needCols) > 0 {
|
||||
for _, c := range needCols {
|
||||
if c == cols[i] {
|
||||
indexs = append(indexs, i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
indexs = append(indexs, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -577,7 +595,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
switch typ {
|
||||
case 1:
|
||||
params := make(Params, len(cols))
|
||||
for i, ref := range refs {
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
params[cols[i]] = value.String
|
||||
@ -588,7 +607,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
maps = append(maps, params)
|
||||
case 2:
|
||||
params := make(ParamsList, 0, len(cols))
|
||||
for _, ref := range refs {
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
params = append(params, value.String)
|
||||
@ -598,7 +618,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
}
|
||||
lists = append(lists, params)
|
||||
case 3:
|
||||
for _, ref := range refs {
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
list = append(list, value.String)
|
||||
@ -623,19 +644,163 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
|
||||
var (
|
||||
maps Params
|
||||
ind *reflect.Value
|
||||
)
|
||||
|
||||
typ := 0
|
||||
switch container.(type) {
|
||||
case *Params:
|
||||
typ = 1
|
||||
default:
|
||||
typ = 2
|
||||
vl := reflect.ValueOf(container)
|
||||
id := reflect.Indirect(vl)
|
||||
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
|
||||
}
|
||||
|
||||
ind = &id
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
|
||||
var rs *sql.Rows
|
||||
if r, err := o.orm.db.Query(query, args...); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
rs = r
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
)
|
||||
|
||||
var (
|
||||
keyIndex = -1
|
||||
valueIndex = -1
|
||||
)
|
||||
|
||||
for rs.Next() {
|
||||
if cnt == 0 {
|
||||
if columns, err := rs.Columns(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
cols = columns
|
||||
refs = make([]interface{}, len(cols))
|
||||
for i, _ := range refs {
|
||||
if keyCol == cols[i] {
|
||||
keyIndex = i
|
||||
}
|
||||
|
||||
if typ == 1 || keyIndex == i {
|
||||
var ref sql.NullString
|
||||
refs[i] = &ref
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
if valueCol == cols[i] {
|
||||
valueIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
if keyIndex == -1 || valueIndex == -1 {
|
||||
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rs.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if cnt == 0 {
|
||||
switch typ {
|
||||
case 1:
|
||||
maps = make(Params)
|
||||
}
|
||||
}
|
||||
|
||||
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
|
||||
|
||||
switch typ {
|
||||
case 1:
|
||||
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
maps[key] = value.String
|
||||
} else {
|
||||
maps[key] = nil
|
||||
}
|
||||
|
||||
default:
|
||||
if id := ind.FieldByName(camelString(key)); id.IsValid() {
|
||||
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
if typ == 1 {
|
||||
v, _ := container.(*Params)
|
||||
*v = maps
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// query data to []map[string]interface
|
||||
func (o *rawSet) Values(container *[]Params) (int64, error) {
|
||||
return o.readValues(container)
|
||||
func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query data to [][]interface
|
||||
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
|
||||
return o.readValues(container)
|
||||
func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query data to []interface
|
||||
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
|
||||
return o.readValues(container)
|
||||
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||
return o.queryRowsTo(result, keyCol, valueCol)
|
||||
}
|
||||
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
|
||||
}
|
||||
|
||||
// return prepared raw statement for used in times.
|
||||
|
20
orm/types.go
20
orm/types.go
@ -37,6 +37,7 @@ type Ormer interface {
|
||||
Rollback() error
|
||||
Raw(string, ...interface{}) RawSeter
|
||||
Driver() Driver
|
||||
GetDB() dbQuerier
|
||||
}
|
||||
|
||||
// insert prepared statement
|
||||
@ -64,6 +65,8 @@ type QuerySeter interface {
|
||||
Values(*[]Params, ...string) (int64, error)
|
||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||
ValuesFlat(*ParamsList, string) (int64, error)
|
||||
RowsToMap(*Params, string, string) (int64, error)
|
||||
RowsToStruct(interface{}, string, string) (int64, error)
|
||||
}
|
||||
|
||||
// model to model query struct
|
||||
@ -87,9 +90,11 @@ type RawSeter interface {
|
||||
QueryRow(...interface{}) error
|
||||
QueryRows(...interface{}) (int64, error)
|
||||
SetArgs(...interface{}) RawSeter
|
||||
Values(*[]Params) (int64, error)
|
||||
ValuesList(*[]ParamsList) (int64, error)
|
||||
ValuesFlat(*ParamsList) (int64, error)
|
||||
Values(*[]Params, ...string) (int64, error)
|
||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||
ValuesFlat(*ParamsList, ...string) (int64, error)
|
||||
RowsToMap(*Params, string, string) (int64, error)
|
||||
RowsToStruct(interface{}, string, string) (int64, error)
|
||||
Prepare() (RawPreparer, error)
|
||||
}
|
||||
|
||||
@ -109,6 +114,14 @@ type dbQuerier interface {
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// type DB interface {
|
||||
// Begin() (*sql.Tx, error)
|
||||
// Prepare(query string) (stmtQuerier, error)
|
||||
// Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
// Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
// QueryRow(query string, args ...interface{}) *sql.Row
|
||||
// }
|
||||
|
||||
// transaction beginner
|
||||
type txer interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
@ -139,6 +152,7 @@ type dbBaser interface {
|
||||
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
||||
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
|
||||
MaxLimit() uint64
|
||||
TableQuote() string
|
||||
ReplaceMarks(*string)
|
||||
|
Loading…
Reference in New Issue
Block a user