1
0
mirror of https://github.com/astaxie/beego.git synced 2025-01-22 11:37:12 +00:00

orm 1. add api: NewOrmWithDB, AddAliasWthDB; 2. RawSeter -> add api: RowsToMap, RowsToStruct; 3. RawSeter -> change api: Values, ValuesList, ValuesFlat add optional params comumns.

This commit is contained in:
slene 2014-01-27 01:48:00 +08:00
parent 8296713ba4
commit 9384e87083
6 changed files with 338 additions and 66 deletions

View File

@ -1350,6 +1350,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil
}
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
return 0, nil
}
// flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool {
return true

View File

@ -3,7 +3,6 @@ package orm
import (
"database/sql"
"fmt"
"os"
"reflect"
"sync"
"time"
@ -13,11 +12,11 @@ import (
type DriverType int
const (
_ DriverType = iota // int enum type
DR_MySQL // mysql
DR_Sqlite // sqlite
DR_Oracle // oracle
DR_Postgres // pgsql
_ DriverType = iota // int enum type
DR_MySQL // mysql
DR_Sqlite // sqlite
DR_Oracle // oracle
DR_Postgres // pgsql
)
// database driver string.
@ -96,40 +95,15 @@ type alias struct {
Engine string
}
// Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DataSource = dataSource
var (
err error
)
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end
}
if dataBaseCache.add(aliasName, al) == false {
err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
goto end
}
al.DB, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
func detectTZ(al *alias) {
// orm timezone system match database
// default use Local
al.TZ = time.Local
if al.DriverName == "sphinx" {
return
}
switch al.Driver {
case DR_MySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
@ -173,6 +147,60 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
}
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DB = db
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
if dataBaseCache.add(aliasName, al) == false {
return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
}
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
// Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db)
if err != nil {
goto end
}
al.DataSource = dataSource
detectTZ(al)
for i, v := range params {
switch i {
@ -183,39 +211,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
}
}
err = al.DB.Ping()
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
end:
if err != nil {
fmt.Println(err.Error())
os.Exit(2)
if db != nil {
db.Close()
}
DebugLog.Println(err.Error())
}
return err
}
// Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) {
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
} else {
if t != typ {
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
os.Exit(2)
return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
}
}
return nil
}
// Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) {
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
} else {
fmt.Sprintf("DataBase name `%s` not registered\n", aliasName)
os.Exit(2)
return fmt.Errorf("DataBase name `%s` not registered\n", aliasName)
}
return nil
}
// Change the max idle conns for *sql.DB, use specify database alias name

View File

@ -439,6 +439,12 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
func (o *orm) GetDB() dbQuerier {
panic(ErrNotImplement)
// not enough
return o.db
}
// create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
@ -450,3 +456,30 @@ func NewOrm() Ormer {
}
return o
}
// create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias
if dr, ok := drivers[driverName]; ok {
al = new(alias)
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
al.Name = aliasName
al.DriverName = driverName
o := new(orm)
o.alias = al
if Debug {
o.db = newDbQueryLog(o.alias, db)
} else {
o.db = db
}
return o, nil
}

View File

@ -197,6 +197,36 @@ func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, result, keyCol, valueCol, o.orm.alias.TZ)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, ptrStruct, keyCol, valueCol, o.orm.alias.TZ)
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)

View File

@ -518,7 +518,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
return cnt, nil
}
func (o *rawSet) readValues(container interface{}) (int64, error) {
func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
var (
maps []Params
lists []ParamsList
@ -552,20 +552,38 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
defer rs.Close()
var (
refs []interface{}
cnt int64
cols []string
refs []interface{}
cnt int64
cols []string
indexs []int
)
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
return 0, err
} else {
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else {
indexs = make([]int, 0, len(columns))
}
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
var ref sql.NullString
refs[i] = &ref
if len(needCols) > 0 {
for _, c := range needCols {
if c == cols[i] {
indexs = append(indexs, i)
}
}
} else {
indexs = append(indexs, i)
}
}
}
}
@ -577,7 +595,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
switch typ {
case 1:
params := make(Params, len(cols))
for i, ref := range refs {
for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid {
params[cols[i]] = value.String
@ -588,7 +607,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for _, ref := range refs {
for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid {
params = append(params, value.String)
@ -598,7 +618,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
}
lists = append(lists, params)
case 3:
for _, ref := range refs {
for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid {
list = append(list, value.String)
@ -623,19 +644,163 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
return cnt, nil
}
func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
var (
maps Params
ind *reflect.Value
)
typ := 0
switch container.(type) {
case *Params:
typ = 1
default:
typ = 2
vl := reflect.ValueOf(container)
id := reflect.Indirect(vl)
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
}
ind = &id
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
var (
refs []interface{}
cnt int64
cols []string
)
var (
keyIndex = -1
valueIndex = -1
)
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
return 0, err
} else {
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
if keyCol == cols[i] {
keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
} else {
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 {
switch typ {
case 1:
maps = make(Params)
}
}
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
switch typ {
case 1:
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
if value.Valid {
maps[key] = value.String
} else {
maps[key] = nil
}
default:
if id := ind.FieldByName(camelString(key)); id.IsValid() {
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
}
}
cnt++
}
if typ == 1 {
v, _ := container.(*Params)
*v = maps
}
return cnt, nil
}
// query data to []map[string]interface
func (o *rawSet) Values(container *[]Params) (int64, error) {
return o.readValues(container)
func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query data to [][]interface
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
return o.readValues(container)
func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
return o.readValues(container)
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(result, keyCol, valueCol)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
}
// return prepared raw statement for used in times.

View File

@ -37,6 +37,7 @@ type Ormer interface {
Rollback() error
Raw(string, ...interface{}) RawSeter
Driver() Driver
GetDB() dbQuerier
}
// insert prepared statement
@ -64,6 +65,8 @@ type QuerySeter interface {
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
}
// model to model query struct
@ -87,9 +90,11 @@ type RawSeter interface {
QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error)
ValuesList(*[]ParamsList) (int64, error)
ValuesFlat(*ParamsList) (int64, error)
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, ...string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
Prepare() (RawPreparer, error)
}
@ -109,6 +114,14 @@ type dbQuerier interface {
QueryRow(query string, args ...interface{}) *sql.Row
}
// type DB interface {
// Begin() (*sql.Tx, error)
// Prepare(query string) (stmtQuerier, error)
// Exec(query string, args ...interface{}) (sql.Result, error)
// Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryRow(query string, args ...interface{}) *sql.Row
// }
// transaction beginner
type txer interface {
Begin() (*sql.Tx, error)
@ -139,6 +152,7 @@ type dbBaser interface {
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)