diff --git a/orm/db.go b/orm/db.go index f12e76fb..60e53765 100644 --- a/orm/db.go +++ b/orm/db.go @@ -1350,6 +1350,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } +func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { + return 0, nil +} + // flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true diff --git a/orm/db_alias.go b/orm/db_alias.go index d50b6ebd..22066514 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -3,7 +3,6 @@ package orm import ( "database/sql" "fmt" - "os" "reflect" "sync" "time" @@ -13,11 +12,11 @@ import ( type DriverType int const ( - _ DriverType = iota // int enum type - DR_MySQL // mysql - DR_Sqlite // sqlite - DR_Oracle // oracle - DR_Postgres // pgsql + _ DriverType = iota // int enum type + DR_MySQL // mysql + DR_Sqlite // sqlite + DR_Oracle // oracle + DR_Postgres // pgsql ) // database driver string. @@ -96,40 +95,15 @@ type alias struct { Engine string } -// Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DataSource = dataSource - - var ( - err error - ) - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - err = fmt.Errorf("driver name `%s` have not registered", driverName) - goto end - } - - if dataBaseCache.add(aliasName, al) == false { - err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName) - goto end - } - - al.DB, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - +func detectTZ(al *alias) { // orm timezone system match database // default use Local al.TZ = time.Local + if al.DriverName == "sphinx" { + return + } + switch al.Driver { case DR_MySQL: row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") @@ -173,6 +147,60 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) { DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) } } +} + +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = db + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if dataBaseCache.add(aliasName, al) == false { + return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + _, err := addAliasWthDB(aliasName, driverName, db) + return err +} + +// Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) for i, v := range params { switch i { @@ -183,39 +211,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) { } } - err = al.DB.Ping() - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - end: if err != nil { - fmt.Println(err.Error()) - os.Exit(2) + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) } + + return err } // Register a database driver use specify driver name, this can be definition the driver is which database type. -func RegisterDriver(driverName string, typ DriverType) { +func RegisterDriver(driverName string, typ DriverType) error { if t, ok := drivers[driverName]; ok == false { drivers[driverName] = typ } else { if t != typ { - fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName) - os.Exit(2) + return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName) } } + return nil } // Change the database default used timezone -func SetDataBaseTZ(aliasName string, tz *time.Location) { +func SetDataBaseTZ(aliasName string, tz *time.Location) error { if al, ok := dataBaseCache.get(aliasName); ok { al.TZ = tz } else { - fmt.Sprintf("DataBase name `%s` not registered\n", aliasName) - os.Exit(2) + return fmt.Errorf("DataBase name `%s` not registered\n", aliasName) } + return nil } // Change the max idle conns for *sql.DB, use specify database alias name diff --git a/orm/orm.go b/orm/orm.go index 00439399..25857fa8 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -439,6 +439,12 @@ func (o *orm) Driver() Driver { return driver(o.alias.Name) } +func (o *orm) GetDB() dbQuerier { + panic(ErrNotImplement) + // not enough + return o.db +} + // create new orm func NewOrm() Ormer { BootStrap() // execute only once @@ -450,3 +456,30 @@ func NewOrm() Ormer { } return o } + +// create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + var al *alias + + if dr, ok := drivers[driverName]; ok { + al = new(alias) + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + al.Name = aliasName + al.DriverName = driverName + + o := new(orm) + o.alias = al + + if Debug { + o.db = newDbQueryLog(o.alias, db) + } else { + o.db = db + } + + return o, nil +} diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index ad8a9374..d16f8eb5 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -197,6 +197,36 @@ func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) + return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, result, keyCol, valueCol, o.orm.alias.TZ) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) + return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, ptrStruct, keyCol, valueCol, o.orm.alias.TZ) +} + // create new QuerySeter. func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { o := new(querySet) diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 3f5fb162..a968e347 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -518,7 +518,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { return cnt, nil } -func (o *rawSet) readValues(container interface{}) (int64, error) { +func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { var ( maps []Params lists []ParamsList @@ -552,20 +552,38 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { defer rs.Close() var ( - refs []interface{} - cnt int64 - cols []string + refs []interface{} + cnt int64 + cols []string + indexs []int ) + for rs.Next() { if cnt == 0 { if columns, err := rs.Columns(); err != nil { return 0, err } else { + if len(needCols) > 0 { + indexs = make([]int, 0, len(needCols)) + } else { + indexs = make([]int, 0, len(columns)) + } + cols = columns refs = make([]interface{}, len(cols)) for i, _ := range refs { var ref sql.NullString refs[i] = &ref + + if len(needCols) > 0 { + for _, c := range needCols { + if c == cols[i] { + indexs = append(indexs, i) + } + } + } else { + indexs = append(indexs, i) + } } } } @@ -577,7 +595,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { switch typ { case 1: params := make(Params, len(cols)) - for i, ref := range refs { + for _, i := range indexs { + ref := refs[i] value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) if value.Valid { params[cols[i]] = value.String @@ -588,7 +607,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { maps = append(maps, params) case 2: params := make(ParamsList, 0, len(cols)) - for _, ref := range refs { + for _, i := range indexs { + ref := refs[i] value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) if value.Valid { params = append(params, value.String) @@ -598,7 +618,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { } lists = append(lists, params) case 3: - for _, ref := range refs { + for _, i := range indexs { + ref := refs[i] value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) if value.Valid { list = append(list, value.String) @@ -623,19 +644,163 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { return cnt, nil } +func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { + var ( + maps Params + ind *reflect.Value + ) + + typ := 0 + switch container.(type) { + case *Params: + typ = 1 + default: + typ = 2 + vl := reflect.ValueOf(container) + id := reflect.Indirect(vl) + if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { + panic(fmt.Errorf(" RowsTo unsupport type `%T` need ptr struct", container)) + } + + ind = &id + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + var rs *sql.Rows + if r, err := o.orm.db.Query(query, args...); err != nil { + return 0, err + } else { + rs = r + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + ) + + var ( + keyIndex = -1 + valueIndex = -1 + ) + + for rs.Next() { + if cnt == 0 { + if columns, err := rs.Columns(); err != nil { + return 0, err + } else { + cols = columns + refs = make([]interface{}, len(cols)) + for i, _ := range refs { + if keyCol == cols[i] { + keyIndex = i + } + + if typ == 1 || keyIndex == i { + var ref sql.NullString + refs[i] = &ref + } else { + var ref interface{} + refs[i] = &ref + } + + if valueCol == cols[i] { + valueIndex = i + } + } + + if keyIndex == -1 || valueIndex == -1 { + panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) + } + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 { + switch typ { + case 1: + maps = make(Params) + } + } + + key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String + + switch typ { + case 1: + value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) + if value.Valid { + maps[key] = value.String + } else { + maps[key] = nil + } + + default: + if id := ind.FieldByName(camelString(key)); id.IsValid() { + o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) + } + } + + cnt++ + } + + if typ == 1 { + v, _ := container.(*Params) + *v = maps + } + + return cnt, nil +} + // query data to []map[string]interface -func (o *rawSet) Values(container *[]Params) (int64, error) { - return o.readValues(container) +func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { + return o.readValues(container, cols) } // query data to [][]interface -func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { - return o.readValues(container) +func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) } // query data to []interface -func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { - return o.readValues(container) +func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(result, keyCol, valueCol) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(ptrStruct, keyCol, valueCol) } // return prepared raw statement for used in times. diff --git a/orm/types.go b/orm/types.go index 76e53017..4361c62c 100644 --- a/orm/types.go +++ b/orm/types.go @@ -37,6 +37,7 @@ type Ormer interface { Rollback() error Raw(string, ...interface{}) RawSeter Driver() Driver + GetDB() dbQuerier } // insert prepared statement @@ -64,6 +65,8 @@ type QuerySeter interface { Values(*[]Params, ...string) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error) + RowsToMap(*Params, string, string) (int64, error) + RowsToStruct(interface{}, string, string) (int64, error) } // model to model query struct @@ -87,9 +90,11 @@ type RawSeter interface { QueryRow(...interface{}) error QueryRows(...interface{}) (int64, error) SetArgs(...interface{}) RawSeter - Values(*[]Params) (int64, error) - ValuesList(*[]ParamsList) (int64, error) - ValuesFlat(*ParamsList) (int64, error) + Values(*[]Params, ...string) (int64, error) + ValuesList(*[]ParamsList, ...string) (int64, error) + ValuesFlat(*ParamsList, ...string) (int64, error) + RowsToMap(*Params, string, string) (int64, error) + RowsToStruct(interface{}, string, string) (int64, error) Prepare() (RawPreparer, error) } @@ -109,6 +114,14 @@ type dbQuerier interface { QueryRow(query string, args ...interface{}) *sql.Row } +// type DB interface { +// Begin() (*sql.Tx, error) +// Prepare(query string) (stmtQuerier, error) +// Exec(query string, args ...interface{}) (sql.Result, error) +// Query(query string, args ...interface{}) (*sql.Rows, error) +// QueryRow(query string, args ...interface{}) *sql.Row +// } + // transaction beginner type txer interface { Begin() (*sql.Tx, error) @@ -139,6 +152,7 @@ type dbBaser interface { GenerateOperatorLeftCol(*fieldInfo, string, *string) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string)