1
0
mirror of https://github.com/astaxie/beego.git synced 2025-01-22 15:37:14 +00:00

orm add full regular go type support, such as int8, uint8, byte, rune. add date/datetime timezone support very well.

This commit is contained in:
slene 2013-08-13 17:16:12 +08:00
parent deb00809a5
commit 27b84841a7
17 changed files with 527 additions and 107 deletions

141
orm/db.go
View File

@ -49,7 +49,7 @@ type dbBase struct {
ins dbBaser
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
_, pkValue, _ := getExistPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
@ -71,9 +71,22 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
case TypeCharField, TypeTextField:
value = field.String()
case TypeFloatField, TypeDecimalField:
value = field.Float()
vu := field.Interface()
if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64()
} else {
value = field.Float()
}
case TypeDateField, TypeDateTimeField:
value = field.Interface()
if t, ok := value.(time.Time); ok {
if fi.fieldType == TypeDateField {
d.ins.TimeToDB(&t, DefaultTimeLoc)
} else {
d.ins.TimeToDB(&t, tz)
}
value = t
}
default:
switch {
case fi.fieldType&IsPostiveIntegerField > 0:
@ -101,15 +114,16 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
if fi.auto_now || fi.auto_now_add && insert {
tnow := time.Now()
if fi.fieldType == TypeDateField {
value = timeFormat(tnow, format_Date)
d.ins.TimeToDB(&tnow, DefaultTimeLoc)
} else {
value = timeFormat(tnow, format_DateTime)
d.ins.TimeToDB(&tnow, tz)
}
value = tnow
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow)
f.SetRaw(tnow.In(DefaultTimeLoc))
} else {
field.Set(reflect.ValueOf(tnow))
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
}
}
}
@ -145,8 +159,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
return stmt, query, err
}
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true)
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true, tz)
if err != nil {
return 0, err
}
@ -165,7 +179,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value)
}
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
@ -187,6 +201,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
d.ins.ReplaceMarks(&query)
if len(refs) == 21 {
fmt.Println(query, pkValue)
}
row := q.QueryRow(query, pkValue)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
@ -197,7 +215,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind)
}
@ -205,8 +223,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
return nil
}
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true)
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true, tz)
if err != nil {
return 0, err
}
@ -240,12 +258,12 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
setNames, setValues, err := d.collectValues(mi, ind, true, false)
setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
if err != nil {
return 0, err
}
@ -269,7 +287,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
return 0, nil
}
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
@ -293,7 +311,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
err := d.deleteRels(q, mi, []interface{}{pkValue})
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil {
return num, err
}
@ -306,7 +324,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
return 0, nil
}
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) {
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
@ -327,7 +345,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth)
}
where, args := tables.getCondSql(cond, false)
where, args := tables.getCondSql(cond, false, tz)
values = append(values, args...)
@ -356,13 +374,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error {
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case od_CASCADE:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
@ -372,7 +390,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
if fi.onDelete == od_SET_DEFAULT {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params)
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
@ -382,7 +400,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
return nil
}
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) {
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins)
if qs != nil {
tables.parseRelated(qs.related, qs.relDepth)
@ -394,7 +412,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote()
where, args := tables.getCondSql(cond, false)
where, args := tables.getCondSql(cond, false, tz)
join := tables.getJoinSql()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
@ -425,7 +443,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args)
marks := make([]string, len(args))
for i, _ := range marks {
marks[i] = "?"
}
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
@ -437,7 +459,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
if num > 0 {
err := d.deleteRels(q, mi, args)
err := d.deleteRels(q, mi, args, tz)
if err != nil {
return num, err
}
@ -451,7 +473,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) {
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -490,7 +512,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false)
where, args := tables.getCondSql(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, offset, rlimit)
join := tables.getJoinSql()
@ -539,7 +561,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
cacheM := make(map[string]*modelInfo)
trefs := refs
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)])
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
trefs = refs[len(mi.fields.dbcols):]
for _, tbl := range tables.tables {
@ -558,7 +580,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
mmi := fi.relModelInfo
field := reflect.Indirect(last.Field(fi.fieldIndex))
if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)])
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz)
for _, fi := range mmi.fields.fieldsReverse {
if fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil {
@ -592,11 +614,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
return cnt, nil
}
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) {
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false)
where, args := tables.getCondSql(cond, false, tz)
tables.getOrderSql(qs.orders)
join := tables.getJoinSql()
@ -612,9 +634,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return
}
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) {
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := ""
params := getFlatParams(fi, args)
params := getFlatParams(fi, args, tz)
if len(params) == 0 {
panic(fmt.Sprintf("operator `%s` need at least one args", operator))
@ -665,11 +687,11 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
return sql, params
}
func (d *dbBase) GenerateOperatorLeftCol(string, *string) {
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
// default not use
}
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -677,12 +699,12 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
field := ind.Field(fi.fieldIndex)
value, err := d.getValue(fi, val)
value, err := d.convertValueFromDB(fi, val, tz)
if err != nil {
panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
}
_, err = d.setValue(fi, value, &field)
_, err = d.setFieldValue(fi, value, &field)
if err != nil {
panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
@ -690,7 +712,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
}
}
func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil {
return nil, nil
}
@ -739,29 +761,32 @@ setValue:
}
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil {
switch v := val.(type) {
switch t := val.(type) {
case time.Time:
value = v
d.ins.TimeFromDB(&t, tz)
value = t
default:
s := StrTo(ToStr(v))
s := StrTo(ToStr(t))
str = &s
}
}
if str != nil {
s := str.String()
var format string
var (
t time.Time
err error
)
if fi.fieldType == TypeDateField {
format = format_Date
if len(s) > 10 {
s = s[:10]
}
t, err = time.ParseInLocation(format_Date, s, DefaultTimeLoc)
} else {
format = format_DateTime
if len(s) > 19 {
s = s[:19]
}
t, err = time.ParseInLocation(format_DateTime, s, tz)
}
t, err := timeParse(s, format)
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
tErr = err
goto end
@ -776,12 +801,16 @@ setValue:
if str != nil {
var err error
switch fieldType {
case TypeBitField:
_, err = str.Int8()
case TypeSmallIntegerField:
_, err = str.Int16()
case TypeIntegerField:
_, err = str.Int32()
case TypeBigIntegerField:
_, err = str.Int64()
case TypePostiveBitField:
_, err = str.Uint8()
case TypePositiveSmallIntegerField:
_, err = str.Uint16()
case TypePositiveIntegerField:
@ -835,7 +864,7 @@ end:
}
func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {
fieldType := fi.fieldType
isNative := fi.isFielder == false
@ -909,7 +938,7 @@ setValue:
return value, nil
}
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) {
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var (
maps []Params
@ -960,7 +989,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
}
}
where, args := tables.getCondSql(cond, false)
where, args := tables.getCondSql(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, qs.offset, qs.limit)
join := tables.getJoinSql()
@ -1007,7 +1036,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
value, err := d.convertValueFromDB(fi, val, tz)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
@ -1022,7 +1051,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
value, err := d.convertValueFromDB(fi, val, tz)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
@ -1036,7 +1065,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
value, err := d.convertValueFromDB(fi, val, tz)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
@ -1079,3 +1108,11 @@ func (d *dbBase) ReplaceMarks(query *string) {
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false
}
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"os"
"sync"
"time"
)
const defaultMaxIdle = 30
@ -82,6 +83,7 @@ type alias struct {
MaxIdle int
DB *sql.DB
DbBaser dbBaser
TZ *time.Location
}
func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
@ -120,6 +122,33 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
al.DB.SetMaxIdleConns(al.MaxIdle)
// orm timezone system match database
// default use Local
al.TZ = time.Local
switch al.Driver {
case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone")
var tz string
row.Scan(&tz)
if tz != "SYSTEM" {
t, err := time.Parse("-07:00", tz)
if err == nil {
al.TZ = t.Location()
}
}
case DR_Sqlite:
al.TZ = time.UTC
case DR_Postgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
}
}
err = al.DB.Ping()
if err != nil {
err = fmt.Errorf("register db `%s`, %s", name, err.Error())
@ -133,13 +162,22 @@ end:
}
}
func RegisterDriver(name string, typ DriverType) {
if t, ok := drivers[name]; ok == false {
drivers[name] = typ
func RegisterDriver(driverName string, typ DriverType) {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
} else {
if t != typ {
fmt.Println("name `%s` db driver already registered and is other type")
fmt.Println("driverName `%s` db driver already registered and is other type")
os.Exit(2)
}
}
}
func SetDataBaseTZ(name string, tz *time.Location) {
if al, ok := dataBaseCache.get(name); ok {
al.TZ = tz
} else {
err := fmt.Errorf("DataBase name `%s` not registered", name)
fmt.Println(err)
}
}

View File

@ -30,7 +30,7 @@ func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator]
}
func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) {
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol)

View File

@ -1,5 +1,9 @@
package orm
import (
"fmt"
)
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
@ -25,6 +29,12 @@ func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator]
}
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}

View File

@ -3,6 +3,7 @@ package orm
import (
"fmt"
"strings"
"time"
)
type dbTable struct {
@ -266,7 +267,7 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return
}
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
@ -288,7 +289,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
where += "NOT "
}
if p.isCond {
w, ps := d.getCondSql(p.cond, true)
w, ps := d.getCondSql(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
@ -313,10 +314,10 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
operator = "exact"
}
operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args)
operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
d.base.GenerateOperatorLeftCol(operator, &leftCol)
d.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql)
params = append(params, args...)

View File

@ -24,7 +24,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return
}
func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) {
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:
for _, arg := range args {
@ -39,9 +39,9 @@ outFor:
case []byte:
case time.Time:
if fi != nil && fi.fieldType == TypeDateField {
arg = v.Format(format_Date)
arg = v.In(DefaultTimeLoc).Format(format_Date)
} else {
arg = v.Format(format_DateTime)
arg = v.In(tz).Format(format_DateTime)
}
default:
kind := val.Kind()
@ -65,7 +65,7 @@ outFor:
}
if len(args) > 0 {
p := getFlatParams(fi, args)
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor

View File

@ -22,12 +22,16 @@ const (
// time.Time
TypeDateTimeField
// int8
TypeBitField
// int16
TypeSmallIntegerField
// int32
TypeIntegerField
// int64
TypeBigIntegerField
// uint8
TypePostiveBitField
// uint16
TypePositiveSmallIntegerField
// uint32
@ -49,8 +53,8 @@ const (
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 7 << 8
IsRelField = ^-RelReverseMany >> 12 << 13
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
IsRelField = ^-RelReverseMany >> 14 << 15
IsFieldType = ^-RelReverseMany<<1 + 1
)

View File

@ -327,8 +327,8 @@ checkType:
case TypeDecimalField:
d1 := digits
d2 := decimals
v1, er1 := StrTo(d1).Int16()
v2, er2 := StrTo(d2).Int16()
v1, er1 := StrTo(d1).Int8()
v2, er2 := StrTo(d2).Int8()
if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end
@ -383,12 +383,16 @@ checkType:
_, err = v.Bool()
case TypeFloatField, TypeDecimalField:
_, err = v.Float64()
case TypeBitField:
_, err = v.Int8()
case TypeSmallIntegerField:
_, err = v.Int16()
case TypeIntegerField:
_, err = v.Int32()
case TypeBigIntegerField:
_, err = v.Int64()
case TypePostiveBitField:
_, err = v.Uint8()
case TypePositiveSmallIntegerField:
_, err = v.Uint16()
case TypePositiveIntegerField:

View File

@ -6,11 +6,60 @@ import (
"strings"
"time"
// _ "github.com/bylevel/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
type Data struct {
Id int `orm:"auto"`
Boolean bool
Char string `orm:"size(50)"`
Text string
Date time.Time `orm:"type(date)"`
DateTime time.Time
Byte byte
Rune rune
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
Float32 float32
Float64 float64
Decimal float64 `orm:"digits(8);decimals(4)"`
}
type DataNull struct {
Id int `orm:"auto"`
Boolean bool `orm:"null"`
Char string `orm:"size(50);null"`
Text string `orm:"null"`
Date time.Time `orm:"type(date);null"`
DateTime time.Time `orm:"null"`
Byte byte `orm:"null"`
Rune rune `orm:"null"`
Int int `orm:"null"`
Int8 int8 `orm:"null"`
Int16 int16 `orm:"null"`
Int32 int32 `orm:"null"`
Int64 int64 `orm:"null"`
Uint uint `orm:"null"`
Uint8 uint8 `orm:"null"`
Uint16 uint16 `orm:"null"`
Uint32 uint32 `orm:"null"`
Uint64 uint64 `orm:"null"`
Float32 float32 `orm:"null"`
Float64 float64 `orm:"null"`
Decimal float64 `orm:"digits(8);decimals(4);null"`
}
type User struct {
Id int `orm:"auto"`
UserName string `orm:"size(30);unique"`
@ -111,6 +160,8 @@ var initSQLs = map[string]string{
"DROP TABLE IF EXISTS `tag`;\n" +
"DROP TABLE IF EXISTS `post_tags`;\n" +
"DROP TABLE IF EXISTS `comment`;\n" +
"DROP TABLE IF EXISTS `data`;\n" +
"DROP TABLE IF EXISTS `data_null`;\n" +
"CREATE TABLE `user_profile` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `age` smallint NOT NULL,\n" +
@ -153,6 +204,52 @@ var initSQLs = map[string]string{
" `parent_id` integer,\n" +
" `created` datetime NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `data` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `boolean` bool NOT NULL,\n" +
" `char` varchar(50) NOT NULL,\n" +
" `text` longtext NOT NULL,\n" +
" `date` date NOT NULL,\n" +
" `date_time` datetime NOT NULL,\n" +
" `byte` tinyint unsigned NOT NULL,\n" +
" `rune` integer NOT NULL,\n" +
" `int` integer NOT NULL,\n" +
" `int8` tinyint NOT NULL,\n" +
" `int16` smallint NOT NULL,\n" +
" `int32` integer NOT NULL,\n" +
" `int64` bigint NOT NULL,\n" +
" `uint` integer unsigned NOT NULL,\n" +
" `uint8` tinyint unsigned NULL,\n" +
" `uint16` smallint unsigned NOT NULL,\n" +
" `uint32` integer unsigned NOT NULL,\n" +
" `uint64` bigint unsigned NOT NULL,\n" +
" `float32` double precision NOT NULL,\n" +
" `float64` double precision NOT NULL,\n" +
" `decimal` numeric(8,4) NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `data_null` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `boolean` bool,\n" +
" `char` varchar(50),\n" +
" `text` longtext,\n" +
" `date` date,\n" +
" `date_time` datetime,\n" +
" `byte` tinyint unsigned,\n" +
" `rune` integer,\n" +
" `int` integer,\n" +
" `int8` tinyint,\n" +
" `int16` smallint,\n" +
" `int32` integer,\n" +
" `int64` bigint,\n" +
" `uint` integer unsigned,\n" +
" `uint8` tinyint unsigned,\n" +
" `uint16` smallint unsigned,\n" +
" `uint32` integer unsigned,\n" +
" `uint64` bigint unsigned,\n" +
" `float32` double precision,\n" +
" `float64` double precision,\n" +
" `decimal` numeric(8,4)\n" +
") ENGINE=INNODB;\n" +
"CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" +
"CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" +
"CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" +
@ -165,6 +262,8 @@ DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment";
DROP TABLE IF EXISTS "data";
DROP TABLE IF EXISTS "data_null";
CREATE TABLE "user_profile" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"age" smallint NOT NULL,
@ -207,6 +306,52 @@ CREATE TABLE "comment" (
"parent_id" integer,
"created" datetime NOT NULL
);
CREATE TABLE "data" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"boolean" bool NOT NULL,
"char" varchar(50) NOT NULL,
"text" text NOT NULL,
"date" date NOT NULL,
"date_time" datetime NOT NULL,
"byte" tinyint unsigned NOT NULL,
"rune" integer NOT NULL,
"int" integer NOT NULL,
"int8" tinyint NOT NULL,
"int16" smallint NOT NULL,
"int32" integer NOT NULL,
"int64" bigint NOT NULL,
"uint" integer unsigned NOT NULL,
"uint8" tinyint unsigned NOT NULL,
"uint16" smallint unsigned NOT NULL,
"uint32" integer unsigned NOT NULL,
"uint64" bigint unsigned NOT NULL,
"float32" real NOT NULL,
"float64" real NOT NULL,
"decimal" decimal
);
CREATE TABLE "data_null" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"boolean" bool,
"char" varchar(50),
"text" text,
"date" date,
"date_time" datetime,
"byte" tinyint unsigned,
"rune" integer,
"int" integer,
"int8" tinyint,
"int16" smallint,
"int32" integer,
"int64" bigint,
"uint" integer unsigned,
"uint8" tinyint unsigned,
"uint16" smallint unsigned,
"uint32" integer unsigned,
"uint64" bigint unsigned,
"float32" real,
"float64" real,
"decimal" decimal
);
CREATE INDEX "user_141c6eec" ON "user" ("profile_id");
CREATE INDEX "post_fbfc09f1" ON "post" ("user_id");
CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id");
@ -220,6 +365,8 @@ DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment";
DROP TABLE IF EXISTS "data";
DROP TABLE IF EXISTS "data_null";
CREATE TABLE "user_profile" (
"id" serial NOT NULL PRIMARY KEY,
"age" smallint NOT NULL,
@ -262,6 +409,52 @@ CREATE TABLE "comment" (
"parent_id" integer,
"created" timestamp with time zone NOT NULL
);
CREATE TABLE "data" (
"id" serial NOT NULL PRIMARY KEY,
"boolean" bool NOT NULL,
"char" varchar(50) NOT NULL,
"text" text NOT NULL,
"date" date NOT NULL,
"date_time" timestamp with time zone NOT NULL,
"byte" smallint CHECK("byte" >= 0 AND "byte" <= 255) NOT NULL,
"rune" integer NOT NULL,
"int" integer NOT NULL,
"int8" smallint CHECK("int8" >= -127 AND "int8" <= 128) NOT NULL,
"int16" smallint NOT NULL,
"int32" integer NOT NULL,
"int64" bigint NOT NULL,
"uint" bigint CHECK("uint" >= 0) NOT NULL,
"uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255) NOT NULL,
"uint16" integer CHECK("uint16" >= 0) NOT NULL,
"uint32" bigint CHECK("uint32" >= 0) NOT NULL,
"uint64" bigint CHECK("uint64" >= 0) NOT NULL,
"float32" double precision NOT NULL,
"float64" double precision NOT NULL,
"decimal" numeric(8, 4)
);
CREATE TABLE "data_null" (
"id" serial NOT NULL PRIMARY KEY,
"boolean" bool,
"char" varchar(50),
"text" text,
"date" date,
"date_time" timestamp with time zone,
"byte" smallint CHECK("byte" >= 0 AND "byte" <= 255),
"rune" integer,
"int" integer,
"int8" smallint CHECK("int8" >= -127 AND "int8" <= 128),
"int16" smallint,
"int32" integer,
"int64" bigint,
"uint" bigint CHECK("uint" >= 0),
"uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255),
"uint16" integer CHECK("uint16" >= 0),
"uint32" bigint CHECK("uint32" >= 0),
"uint64" bigint CHECK("uint64" >= 0),
"float32" double precision,
"float64" double precision,
"decimal" numeric(8, 4)
);
CREATE INDEX "user_profile_id" ON "user" ("profile_id");
CREATE INDEX "post_user_id" ON "post" ("user_id");
CREATE INDEX "comment_post_id" ON "comment" ("post_id");
@ -269,6 +462,10 @@ CREATE INDEX "comment_parent_id" ON "comment" ("parent_id");
`}
func init() {
// err := os.Setenv("TZ", "+00:00")
// fmt.Println(err)
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))

View File

@ -43,15 +43,19 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
switch elm.Kind() {
case reflect.Int8:
ft = TypeBitField
case reflect.Int16:
ft = TypeSmallIntegerField
case reflect.Int32, reflect.Int:
ft = TypeIntegerField
case reflect.Int64:
ft = TypeBigIntegerField
case reflect.Uint8:
ft = TypePostiveBitField
case reflect.Uint16:
ft = TypePositiveSmallIntegerField
case reflect.Uint32:
case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField
case reflect.Uint64:
ft = TypePositiveBigIntegerField

View File

@ -55,7 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
func (o *orm) Read(md interface{}) error {
mi, ind := o.getMiInd(md)
err := o.alias.DbBaser.Read(o.db, mi, ind)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ)
if err != nil {
return err
}
@ -64,7 +64,7 @@ func (o *orm) Read(md interface{}) error {
func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
@ -78,7 +78,7 @@ func (o *orm) Insert(md interface{}) (int64, error) {
func (o *orm) Update(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Update(o.db, mi, ind)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ)
if err != nil {
return num, err
}
@ -87,7 +87,7 @@ func (o *orm) Update(md interface{}) (int64, error) {
func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil {
return num, err
}

View File

@ -28,7 +28,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
if name != o.mi.fullName {
panic(fmt.Sprintf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
}
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind)
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil {
return id, err
}

View File

@ -77,15 +77,15 @@ func (o querySet) SetCond(cond *Condition) QuerySeter {
}
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond)
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values)
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond)
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
func (o *querySet) PrepareInsert() (Inserter, error) {
@ -93,11 +93,11 @@ func (o *querySet) PrepareInsert() (Inserter, error) {
}
func (o *querySet) All(container interface{}) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
}
func (o *querySet) One(container interface{}) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
if err != nil {
return err
}
@ -111,15 +111,15 @@ func (o *querySet) One(container interface{}) error {
}
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results)
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results)
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result)
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {

View File

@ -60,7 +60,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
return o.orm.db.Exec(query, args...)
}
@ -96,7 +96,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {

View File

@ -15,6 +15,11 @@ import (
var _ = os.PathSeparator
var (
test_Date = format_Date + " -0700"
test_DateTime = format_DateTime + " -0700"
)
type T_Code int
const (
@ -141,7 +146,7 @@ func getCaller(skip int) string {
if cur == line {
flag = ">>"
}
code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.TrimSpace(string(lines[o+i])))
code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1))
if code != "" {
codes = append(codes, code)
}
@ -158,7 +163,11 @@ func throwFail(t *testing.T, err error, args ...interface{}) {
if err != nil {
con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
if len(args) > 0 {
con += fmt.Sprint(args...)
parts := make([]string, 0, len(args))
for _, arg := range args {
parts = append(parts, fmt.Sprintf("%v", arg))
}
con += " " + strings.Join(parts, ", ")
}
t.Error(con)
t.Fail()
@ -169,7 +178,11 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
if err != nil {
con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
if len(args) > 0 {
con += fmt.Sprint(args...)
parts := make([]string, 0, len(args))
for _, arg := range args {
parts = append(parts, fmt.Sprintf("%v", arg))
}
con += " " + strings.Join(parts, ", ")
}
t.Error(con)
t.FailNow()
@ -177,13 +190,100 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
}
func TestModelSyntax(t *testing.T) {
mi, ok := modelCache.get("user")
user := &User{}
ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type())
mi, ok := modelCache.getByFN(fn)
throwFail(t, AssertIs(ok, T_Equal, true))
mi, ok = modelCache.get("user")
throwFail(t, AssertIs(ok, T_Equal, true))
if ok {
throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true))
}
}
func TestDataTypes(t *testing.T) {
values := map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
"Date": time.Now(),
"DateTime": time.Now(),
"Byte": byte(1<<8 - 1),
"Rune": rune(1<<31 - 1),
"Int": int(1<<31 - 1),
"Int8": int8(1<<7 - 1),
"Int16": int16(1<<15 - 1),
"Int32": int32(1<<31 - 1),
"Int64": int64(1<<63 - 1),
"Uint": uint(1<<32 - 1),
"Uint8": uint8(1<<8 - 1),
"Uint16": uint16(1<<16 - 1),
"Uint32": uint32(1<<32 - 1),
"Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported
"Float32": float32(100.1234),
"Float64": float64(100.1234),
"Decimal": float64(100.1234),
}
d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range values {
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, T_Equal, 1))
d = Data{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, T_Equal, true), value, vu)
}
}
func TestNullDataTypes(t *testing.T) {
d := DataNull{}
if IsPostgres {
// can removed when this fixed
// https://github.com/lib/pq/pull/125
d.DateTime = time.Now()
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, T_Equal, 1))
d = DataNull{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
d = DataNull{Id: 2}
err = dORM.Read(&d)
throwFail(t, err)
}
func TestCRUD(t *testing.T) {
profile := NewProfile()
profile.Age = 30
@ -214,8 +314,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, T_Equal, 3))
throwFail(t, AssertIs(u.IsStaff, T_Equal, true))
throwFail(t, AssertIs(u.IsActive, T_Equal, true))
throwFail(t, AssertIs(u.Created, T_Equal, user.Created, format_Date))
throwFail(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), T_Equal, user.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), T_Equal, user.Updated.In(DefaultTimeLoc), test_DateTime))
user.UserName = "astaxie"
user.Profile = profile
@ -360,7 +460,9 @@ The program—and web server—godoc processes Go source files to extract docume
}
func TestExpr(t *testing.T) {
qs := dORM.QueryTable("User")
user := &User{}
qs := dORM.QueryTable(user)
qs = dORM.QueryTable("User")
qs = dORM.QueryTable("user")
num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count()
throwFail(t, err)
@ -369,6 +471,10 @@ func TestExpr(t *testing.T) {
num, err = qs.Filter("created", time.Now()).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
num, err = qs.Filter("created", time.Now().Format(format_Date)).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
}
func TestOperators(t *testing.T) {
@ -820,9 +926,11 @@ func TestRaw(t *testing.T) {
res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec()
throwFail(t, err)
num, err := res.RowsAffected()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if err == nil {
num, err := res.RowsAffected()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
}
}
}
}

View File

@ -3,6 +3,7 @@ package orm
import (
"database/sql"
"reflect"
"time"
)
type Driver interface {
@ -110,23 +111,25 @@ type txEnder interface {
}
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value) error
Insert(dbQuerier, *modelInfo, reflect.Value) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{})
GenerateOperatorLeftCol(string, *string)
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
}

View File

@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
@ -53,6 +58,11 @@ func (f StrTo) Int64() (int64, error) {
return int64(v), err
}
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
@ -85,6 +95,8 @@ func ToStr(value interface{}, args ...int) (s string) {
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32:
@ -93,6 +105,8 @@ func ToStr(value interface{}, args ...int) (s string) {
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32: