1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-29 12:11:28 +00:00

Merge branch 'master' of https://github.com/astaxie/beego into spelling

This commit is contained in:
benlovell 2013-08-14 09:07:51 +02:00
commit 1977d87d55
23 changed files with 689 additions and 135 deletions

View File

@ -13,7 +13,7 @@ import (
"time" "time"
) )
const VERSION = "0.8.0" const VERSION = "0.9.0"
var ( var (
BeeApp *App BeeApp *App

View File

@ -1,6 +1,10 @@
# beego orm # beego orm
a powerful orm framework [![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest)
A powerful orm framework for go.
It is heavily influenced by Django ORM, SQLAlchemy.
now, beta, unstable, may be changing some api make your app build failed. now, beta, unstable, may be changing some api make your app build failed.
@ -14,12 +18,25 @@ Passed all test, but need more feedback.
**Features:** **Features:**
... * full go type support
* easy for usage, simple CRUD operation
* auto join with relation table
* cross DataBase compatible query
* Raw SQL query / mapper without orm model
* full test keep stable and strong
more features please read the docs
**Install:** **Install:**
go get github.com/astaxie/beego/orm go get github.com/astaxie/beego/orm
## Changelog
* 2013-08-13: update test for database types
* 2013-08-13: go type support, such as int8, uint8, byte, rune
* 2013-08-13: date / datetime timezone support very well
## Quick Start ## Quick Start
#### Simple Usage #### Simple Usage
@ -143,5 +160,3 @@ more details and examples in docs and test
- some unrealized api - some unrealized api
- examples - examples
- docs - docs
##

137
orm/db.go
View File

@ -49,7 +49,7 @@ type dbBase struct {
ins dbBaser ins dbBaser
} }
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
_, pkValue, _ := getExistPk(mi, ind) _, pkValue, _ := getExistPk(mi, ind)
for _, column := range mi.fields.orders { for _, column := range mi.fields.orders {
fi := mi.fields.columns[column] fi := mi.fields.columns[column]
@ -71,9 +71,22 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
case TypeCharField, TypeTextField: case TypeCharField, TypeTextField:
value = field.String() value = field.String()
case TypeFloatField, TypeDecimalField: case TypeFloatField, TypeDecimalField:
value = field.Float() vu := field.Interface()
if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64()
} else {
value = field.Float()
}
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
value = field.Interface() value = field.Interface()
if t, ok := value.(time.Time); ok {
if fi.fieldType == TypeDateField {
d.ins.TimeToDB(&t, DefaultTimeLoc)
} else {
d.ins.TimeToDB(&t, tz)
}
value = t
}
default: default:
switch { switch {
case fi.fieldType&IsPostiveIntegerField > 0: case fi.fieldType&IsPostiveIntegerField > 0:
@ -101,15 +114,16 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
if fi.auto_now || fi.auto_now_add && insert { if fi.auto_now || fi.auto_now_add && insert {
tnow := time.Now() tnow := time.Now()
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
value = timeFormat(tnow, format_Date) d.ins.TimeToDB(&tnow, DefaultTimeLoc)
} else { } else {
value = timeFormat(tnow, format_DateTime) d.ins.TimeToDB(&tnow, tz)
} }
value = tnow
if fi.isFielder { if fi.isFielder {
f := field.Addr().Interface().(Fielder) f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow) f.SetRaw(tnow.In(DefaultTimeLoc))
} else { } else {
field.Set(reflect.ValueOf(tnow)) field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
} }
} }
} }
@ -145,8 +159,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
return stmt, query, err return stmt, query, err
} }
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true) _, values, err := d.collectValues(mi, ind, true, true, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -165,7 +179,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value)
} }
} }
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
pkColumn, pkValue, ok := getExistPk(mi, ind) pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
return ErrMissPK return ErrMissPK
@ -197,7 +211,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm) mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs) d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind) ind.Set(mind)
} }
@ -205,8 +219,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
return nil return nil
} }
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true) names, values, err := d.collectValues(mi, ind, true, true, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -240,12 +254,12 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
} }
} }
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
} }
setNames, setValues, err := d.collectValues(mi, ind, true, false) setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -269,7 +283,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
return 0, nil return 0, nil
} }
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
@ -293,7 +307,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
ind.Field(mi.fields.pk.fieldIndex).SetInt(0) ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
} }
err := d.deleteRels(q, mi, []interface{}{pkValue}) err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
@ -306,7 +320,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
return 0, nil return 0, nil
} }
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) { func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
for col, val := range params { for col, val := range params {
@ -327,7 +341,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
} }
where, args := tables.getCondSql(cond, false) where, args := tables.getCondSql(cond, false, tz)
values = append(values, args...) values = append(values, args...)
@ -356,13 +370,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error { func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
switch fi.onDelete { switch fi.onDelete {
case od_CASCADE: case od_CASCADE:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond) _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil { if err != nil {
return err return err
} }
@ -372,7 +386,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
if fi.onDelete == od_SET_DEFAULT { if fi.onDelete == od_SET_DEFAULT {
params[fi.column] = fi.initial.String() params[fi.column] = fi.initial.String()
} }
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params) _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil { if err != nil {
return err return err
} }
@ -382,7 +396,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
return nil return nil
} }
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) { func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
@ -394,7 +408,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
where, args := tables.getCondSql(cond, false) where, args := tables.getCondSql(cond, false, tz)
join := tables.getJoinSql() join := tables.getJoinSql()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
@ -425,7 +439,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args) marks := make([]string, len(args))
for i, _ := range marks {
marks[i] = "?"
}
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -437,7 +455,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
if num > 0 { if num > 0 {
err := d.deleteRels(q, mi, args) err := d.deleteRels(q, mi, args, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
@ -451,7 +469,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) { func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -490,7 +508,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false) where, args := tables.getCondSql(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders) orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, offset, rlimit) limit := tables.getLimitSql(mi, offset, rlimit)
join := tables.getJoinSql() join := tables.getJoinSql()
@ -539,7 +557,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
cacheM := make(map[string]*modelInfo) cacheM := make(map[string]*modelInfo)
trefs := refs trefs := refs
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)]) d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
trefs = refs[len(mi.fields.dbcols):] trefs = refs[len(mi.fields.dbcols):]
for _, tbl := range tables.tables { for _, tbl := range tables.tables {
@ -558,7 +576,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
mmi := fi.relModelInfo mmi := fi.relModelInfo
field := reflect.Indirect(last.Field(fi.fieldIndex)) field := reflect.Indirect(last.Field(fi.fieldIndex))
if field.IsValid() { if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)]) d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz)
for _, fi := range mmi.fields.fieldsReverse { for _, fi := range mmi.fields.fieldsReverse {
if fi.reverseFieldInfo.mi == lastm { if fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil { if fi.reverseFieldInfo != nil {
@ -592,11 +610,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
return cnt, nil return cnt, nil
} }
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) { func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false) where, args := tables.getCondSql(cond, false, tz)
tables.getOrderSql(qs.orders) tables.getOrderSql(qs.orders)
join := tables.getJoinSql() join := tables.getJoinSql()
@ -612,9 +630,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return return
} }
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) { func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" sql := ""
params := getFlatParams(fi, args) params := getFlatParams(fi, args, tz)
if len(params) == 0 { if len(params) == 0 {
panic(fmt.Sprintf("operator `%s` need at least one args", operator)) panic(fmt.Sprintf("operator `%s` need at least one args", operator))
@ -665,11 +683,11 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
return sql, params return sql, params
} }
func (d *dbBase) GenerateOperatorLeftCol(string, *string) { func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
// default not use
} }
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols { for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -677,12 +695,12 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
field := ind.Field(fi.fieldIndex) field := ind.Field(fi.fieldIndex)
value, err := d.getValue(fi, val) value, err := d.convertValueFromDB(fi, val, tz)
if err != nil { if err != nil {
panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error())) panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
} }
_, err = d.setValue(fi, value, &field) _, err = d.setFieldValue(fi, value, &field)
if err != nil { if err != nil {
panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error())) panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
@ -690,7 +708,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
} }
} }
func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) { func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil { if val == nil {
return nil, nil return nil, nil
} }
@ -739,29 +757,32 @@ setValue:
} }
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil { if str == nil {
switch v := val.(type) { switch t := val.(type) {
case time.Time: case time.Time:
value = v d.ins.TimeFromDB(&t, tz)
value = t
default: default:
s := StrTo(ToStr(v)) s := StrTo(ToStr(t))
str = &s str = &s
} }
} }
if str != nil { if str != nil {
s := str.String() s := str.String()
var format string var (
t time.Time
err error
)
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
format = format_Date
if len(s) > 10 { if len(s) > 10 {
s = s[:10] s = s[:10]
} }
t, err = time.ParseInLocation(format_Date, s, DefaultTimeLoc)
} else { } else {
format = format_DateTime
if len(s) > 19 { if len(s) > 19 {
s = s[:19] s = s[:19]
} }
t, err = time.ParseInLocation(format_DateTime, s, tz)
} }
t, err := timeParse(s, format)
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
tErr = err tErr = err
goto end goto end
@ -776,12 +797,16 @@ setValue:
if str != nil { if str != nil {
var err error var err error
switch fieldType { switch fieldType {
case TypeBitField:
_, err = str.Int8()
case TypeSmallIntegerField: case TypeSmallIntegerField:
_, err = str.Int16() _, err = str.Int16()
case TypeIntegerField: case TypeIntegerField:
_, err = str.Int32() _, err = str.Int32()
case TypeBigIntegerField: case TypeBigIntegerField:
_, err = str.Int64() _, err = str.Int64()
case TypePostiveBitField:
_, err = str.Uint8()
case TypePositiveSmallIntegerField: case TypePositiveSmallIntegerField:
_, err = str.Uint16() _, err = str.Uint16()
case TypePositiveIntegerField: case TypePositiveIntegerField:
@ -835,7 +860,7 @@ end:
} }
func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) { func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {
fieldType := fi.fieldType fieldType := fi.fieldType
isNative := fi.isFielder == false isNative := fi.isFielder == false
@ -909,7 +934,7 @@ setValue:
return value, nil return value, nil
} }
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) { func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var ( var (
maps []Params maps []Params
@ -960,7 +985,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
} }
} }
where, args := tables.getCondSql(cond, false) where, args := tables.getCondSql(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders) orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, qs.offset, qs.limit) limit := tables.getLimitSql(mi, qs.offset, qs.limit)
join := tables.getJoinSql() join := tables.getJoinSql()
@ -1007,7 +1032,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
val := reflect.Indirect(reflect.ValueOf(ref)).Interface() val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val) value, err := d.convertValueFromDB(fi, val, tz)
if err != nil { if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
} }
@ -1022,7 +1047,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
val := reflect.Indirect(reflect.ValueOf(ref)).Interface() val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val) value, err := d.convertValueFromDB(fi, val, tz)
if err != nil { if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
} }
@ -1036,7 +1061,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
val := reflect.Indirect(reflect.ValueOf(ref)).Interface() val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val) value, err := d.convertValueFromDB(fi, val, tz)
if err != nil { if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
} }
@ -1079,3 +1104,11 @@ func (d *dbBase) ReplaceMarks(query *string) {
func (d *dbBase) HasReturningID(*modelInfo, *string) bool { func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"sync" "sync"
"time"
) )
const defaultMaxIdle = 30 const defaultMaxIdle = 30
@ -82,6 +83,7 @@ type alias struct {
MaxIdle int MaxIdle int
DB *sql.DB DB *sql.DB
DbBaser dbBaser DbBaser dbBaser
TZ *time.Location
} }
func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
@ -120,6 +122,33 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
al.DB.SetMaxIdleConns(al.MaxIdle) al.DB.SetMaxIdleConns(al.MaxIdle)
// orm timezone system match database
// default use Local
al.TZ = time.Local
switch al.Driver {
case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone")
var tz string
row.Scan(&tz)
if tz != "SYSTEM" {
t, err := time.Parse("-07:00", tz)
if err == nil {
al.TZ = t.Location()
}
}
case DR_Sqlite:
al.TZ = time.UTC
case DR_Postgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
}
}
err = al.DB.Ping() err = al.DB.Ping()
if err != nil { if err != nil {
err = fmt.Errorf("register db `%s`, %s", name, err.Error()) err = fmt.Errorf("register db `%s`, %s", name, err.Error())
@ -133,13 +162,22 @@ end:
} }
} }
func RegisterDriver(name string, typ DriverType) { func RegisterDriver(driverName string, typ DriverType) {
if t, ok := drivers[name]; ok == false { if t, ok := drivers[driverName]; ok == false {
drivers[name] = typ drivers[driverName] = typ
} else { } else {
if t != typ { if t != typ {
fmt.Println("name `%s` db driver already registered and is other type") fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
os.Exit(2) os.Exit(2)
} }
} }
} }
func SetDataBaseTZ(name string, tz *time.Location) {
if al, ok := dataBaseCache.get(name); ok {
al.TZ = tz
} else {
fmt.Sprintf("DataBase name `%s` not registered\n", name)
os.Exit(2)
}
}

View File

@ -30,7 +30,7 @@ func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator] return postgresOperators[operator]
} }
func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) { func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator { switch operator {
case "contains", "startswith", "endswith": case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol) *leftCol = fmt.Sprintf("%s::text", *leftCol)

View File

@ -1,5 +1,9 @@
package orm package orm
import (
"fmt"
)
var sqliteOperators = map[string]string{ var sqliteOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'", "iexact": "LIKE ? ESCAPE '\\'",
@ -25,6 +29,12 @@ func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
} }
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
func (d *dbBaseSqlite) SupportUpdateJoin() bool { func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false return false
} }

View File

@ -3,6 +3,7 @@ package orm
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
) )
type dbTable struct { type dbTable struct {
@ -266,7 +267,7 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return return
} }
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) { func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
return return
} }
@ -288,7 +289,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
where += "NOT " where += "NOT "
} }
if p.isCond { if p.isCond {
w, ps := d.getCondSql(p.cond, true) w, ps := d.getCondSql(p.cond, true, tz)
if w != "" { if w != "" {
w = fmt.Sprintf("( %s) ", w) w = fmt.Sprintf("( %s) ", w)
} }
@ -313,10 +314,10 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
operator = "exact" operator = "exact"
} }
operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args) operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
d.base.GenerateOperatorLeftCol(operator, &leftCol) d.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql) where += fmt.Sprintf("%s %s ", leftCol, operSql)
params = append(params, args...) params = append(params, args...)

View File

@ -24,7 +24,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return return
} }
func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) { func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor: outFor:
for _, arg := range args { for _, arg := range args {
@ -39,9 +39,9 @@ outFor:
case []byte: case []byte:
case time.Time: case time.Time:
if fi != nil && fi.fieldType == TypeDateField { if fi != nil && fi.fieldType == TypeDateField {
arg = v.Format(format_Date) arg = v.In(DefaultTimeLoc).Format(format_Date)
} else { } else {
arg = v.Format(format_DateTime) arg = v.In(tz).Format(format_DateTime)
} }
default: default:
kind := val.Kind() kind := val.Kind()
@ -65,7 +65,7 @@ outFor:
} }
if len(args) > 0 { if len(args) > 0 {
p := getFlatParams(fi, args) p := getFlatParams(fi, args, tz)
params = append(params, p...) params = append(params, p...)
} }
continue outFor continue outFor

View File

@ -164,27 +164,91 @@ type Profile struct {
``` ```
## Struct Field 类型与数据库的对应 ## 模型字段与数据库类型的对应
现在 orm 支持下面的字段形式 在此列出 orm 推荐的对应数据库类型,自动建表功能也会以此为标准。
| go type | field type | mysql type 默认所有的字段都是 **NOT NULL**
| :--- | :--- | :---
| bool | TypeBooleanField | tinyint
| string | TypeCharField | varchar
| string | TypeTextField | longtext
| time.Time | TypeDateField | date
| time.TIme | TypeDateTimeField | datetime
| int16 |TypeSmallIntegerField | int(4)
| int, int32 |TypeIntegerField | int(11)
| int64 |TypeBigIntegerField | bigint(20)
| uint, uint16 |TypePositiveSmallIntegerField | int(4) unsigned
| uint32 |TypePositiveIntegerField | int(11) unsigned
| uint64 |TypePositiveBigIntegerField | bigint(20) unsigned
| float32, float64 | TypeFloatField | double
| float32, float64 | TypeDecimalField | double(digits, decimals)
关系型的字段,其字段类型取决于对应的主键。 #### MySQL
| go |mysql
| :--- | :---
| bool | bool
| string - 设置 size 时 | varchar(size)
| string | longtext
| time.Time - 设置 type 为 date 时 | date
| time.TIme | datetime
| byte | tinyint unsigned
| rune | integer
| int | integer
| int8 | tinyint
| int16 | smallint
| int32 | integer
| int64 | bigint
| uint | integer unsigned
| uint8 | tinyint unsigned
| uint16 | smallint unsigned
| uint32 | integer unsigned
| uint64 | bigint unsigned
| float32 | double precision
| float64 | double precision
| float64 - 设置 digits, decimals 时 | numeric(digits, decimals)
#### Sqlite3
| go | sqlite3
| :--- | :---
| bool | bool
| string - 设置 size 时 | varchar(size)
| string | text
| time.Time - 设置 type 为 date 时 | date
| time.TIme | datetime
| byte | tinyint unsigned
| rune | integer
| int | integer
| int8 | tinyint
| int16 | smallint
| int32 | integer
| int64 | bigint
| uint | integer unsigned
| uint8 | tinyint unsigned
| uint16 | smallint unsigned
| uint32 | integer unsigned
| uint64 | bigint unsigned
| float32 | real
| float64 | real
| float64 - 设置 digits, decimals 时 | decimal
#### PostgreSQL
| go | postgres
| :--- | :---
| bool | bool
| string - 设置 size 时 | varchar(size)
| string | text
| time.Time - 设置 type 为 date 时 | date
| time.TIme | timestamp with time zone
| byte | smallint CHECK("column" >= 0 AND "column" <= 255)
| rune | integer
| int | integer
| int8 | smallint CHECK("column" >= -127 AND "column" <= 128)
| int16 | smallint
| int32 | integer
| int64 | bigint
| uint | bigint CHECK("column" >= 0)
| uint8 | smallint CHECK("column" >= 0 AND "column" <= 255)
| uint16 | integer CHECK("column" >= 0)
| uint32 | bigint CHECK("column" >= 0)
| uint64 | bigint CHECK("column" >= 0)
| float32 | double precision
| float64 | double precision
| float64 - 设置 digits, decimals 时 | numeric(digits, decimals)
## 关系型字段
其字段类型取决于对应的主键。
* RelForeignKey * RelForeignKey
* RelOneToOne * RelOneToOne

View File

@ -80,7 +80,7 @@ import (
#### RegisterDriver #### RegisterDriver
三种数据库类型 三种默认数据库类型
```go ```go
orm.DR_MySQL orm.DR_MySQL
@ -93,7 +93,7 @@ orm.DR_Postgres
// 参数2 数据库类型 // 参数2 数据库类型
// 这个用来设置 driverName 对应的数据库类型 // 这个用来设置 driverName 对应的数据库类型
// mysql / sqlite3 / postgres 这三种是默认已经注册过的,所以可以无需设置 // mysql / sqlite3 / postgres 这三种是默认已经注册过的,所以可以无需设置
orm.RegisterDriver("mysql", orm.DR_MySQL) orm.RegisterDriver("mymysql", orm.DR_MySQL)
``` ```
#### RegisterDataBase #### RegisterDataBase
@ -108,6 +108,56 @@ orm 必须注册一个名称为 `default` 的数据库,用以作为默认使
orm.RegisterDataBase("default", "mysql", "root:root@/orm_test?charset=utf8", 30) orm.RegisterDataBase("default", "mysql", "root:root@/orm_test?charset=utf8", 30)
``` ```
#### 时区设置
orm 默认使用 time.Local 本地时区
* 作用于 orm 自动创建的时间
* 从数据库中取回的时间转换成 orm 本地时间
如果需要的话,你也可以进行更改
```go
// 设置为 UTC 时间
orm.DefaultTimeLoc = time.UTC
```
orm 在进行 RegisterDataBase 的同时,会获取数据库使用的时区,然后在 time.Time 类型存取的时做相应转换,以匹配时间系统,从而保证时间不会出错。
**注意:** 鉴于 Sqlite3 的设计,存取默认都为 UTC 时间
## RegisterModel
如果使用 orm.QuerySeter 进行高级查询的话,这个是必须的。
反之,如果只使用 Raw 查询和 map struct是无需这一步的。您可以去查看 [Raw SQL 查询](Raw.md)
将你定义的 Model 进行注册,最佳设计是有单独的 models.go 文件,在他的 init 函数中进行注册。
迷你版 models.go
```go
package main
import "github.com/astaxie/beego/orm"
type User struct {
Id int `orm:"auto"`
name string
}
func init(){
orm.RegisterModel(new(User))
}
```
RegisterModel 也可以同时注册多个 model
```go
orm.RegisterModel(new(User), new(Profile), new(Post))
```
## ORM 接口使用 ## ORM 接口使用
使用 orm 必然接触的 Ormer 接口,我们来熟悉一下 使用 orm 必然接触的 Ormer 接口,我们来熟悉一下

View File

@ -15,7 +15,7 @@ qs = o.QueryTable(user) // 返回 QuerySeter
``` ```
## expr ## expr
QuerySeter 中用于描述字段和 sql 操作符使用简单的 expr 查询方法 QuerySeter 中用于描述字段和 sql 操作符使用简单的 expr 查询方法
字段组合的前后顺序依照表的关系,比如 User 表拥有 Profile 的外键,那么对 User 表查询对应的 Profile.Age 为条件,则使用 `Profile__Age` 注意,字段的分隔符号使用双下划线 `__`,除了描述字段, expr 的尾部可以增加操作符以执行对应的 sql 操作。比如 `Profile__Age__gt` 代表 Profile.Age > 18 的条件查询。 字段组合的前后顺序依照表的关系,比如 User 表拥有 Profile 的外键,那么对 User 表查询对应的 Profile.Age 为条件,则使用 `Profile__Age` 注意,字段的分隔符号使用双下划线 `__`,除了描述字段, expr 的尾部可以增加操作符以执行对应的 sql 操作。比如 `Profile__Age__gt` 代表 Profile.Age > 18 的条件查询。

View File

@ -1,8 +1,11 @@
## 文档目录 ## 文档目录
1. [Orm 使用方法](Orm.md) 1. [Orm 使用方法](Orm.md)
- [数据库的设置](Orm.md#数据库的设置) - [数据库的设置](Orm.md#数据库的设置)
* [驱动类型设置](Orm.md#registerdriver)
* [参数设置](Orm.md#registerdataBase)
* [时区设置](Orm.md#时区设置)
- [注册 ORM 使用的模型](Orm.md#registermodel)
- [ORM 接口使用](Orm.md#orm-接口使用) - [ORM 接口使用](Orm.md#orm-接口使用)
- [调试模式打印查询语句](Orm.md#调试模式打印查询语句) - [调试模式打印查询语句](Orm.md#调试模式打印查询语句)
2. [对象的CRUD操作](Object.md) 2. [对象的CRUD操作](Object.md)
@ -15,6 +18,12 @@
6. [模型定义](Models.md) 6. [模型定义](Models.md)
- [Struct Tag 设置参数](Models.md#struct-tag-设置参数) - [Struct Tag 设置参数](Models.md#struct-tag-设置参数)
- [表关系设置](Models.md#表关系设置) - [表关系设置](Models.md#表关系设置)
- [Struct Field 类型与数据库的对应](Models.md#struct-field-类型与数据库的对应) - [模型字段与数据库类型的对应](Models.md#模型字段与数据库类型的对应)
7. Custom Fields 7. Custom Fields
8. Faq 8. Faq
### 文档更新记录
* 2013-08-13: ORM 的 [时区设置](Orm.md#时区设置)
* 2013-08-13: [模型字段与数据库类型的对应](Models.md#模型字段与数据库类型的对应) 推荐的数据库对应使用的类型

View File

@ -22,12 +22,16 @@ const (
// time.Time // time.Time
TypeDateTimeField TypeDateTimeField
// int8
TypeBitField
// int16 // int16
TypeSmallIntegerField TypeSmallIntegerField
// int32 // int32
TypeIntegerField TypeIntegerField
// int64 // int64
TypeBigIntegerField TypeBigIntegerField
// uint8
TypePostiveBitField
// uint16 // uint16
TypePositiveSmallIntegerField TypePositiveSmallIntegerField
// uint32 // uint32
@ -49,8 +53,8 @@ const (
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 7 << 8 IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
IsRelField = ^-RelReverseMany >> 12 << 13 IsRelField = ^-RelReverseMany >> 14 << 15
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )

View File

@ -327,8 +327,8 @@ checkType:
case TypeDecimalField: case TypeDecimalField:
d1 := digits d1 := digits
d2 := decimals d2 := decimals
v1, er1 := StrTo(d1).Int16() v1, er1 := StrTo(d1).Int8()
v2, er2 := StrTo(d2).Int16() v2, er2 := StrTo(d2).Int8()
if er1 != nil || er2 != nil { if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end goto end
@ -383,12 +383,16 @@ checkType:
_, err = v.Bool() _, err = v.Bool()
case TypeFloatField, TypeDecimalField: case TypeFloatField, TypeDecimalField:
_, err = v.Float64() _, err = v.Float64()
case TypeBitField:
_, err = v.Int8()
case TypeSmallIntegerField: case TypeSmallIntegerField:
_, err = v.Int16() _, err = v.Int16()
case TypeIntegerField: case TypeIntegerField:
_, err = v.Int32() _, err = v.Int32()
case TypeBigIntegerField: case TypeBigIntegerField:
_, err = v.Int64() _, err = v.Int64()
case TypePostiveBitField:
_, err = v.Uint8()
case TypePositiveSmallIntegerField: case TypePositiveSmallIntegerField:
_, err = v.Uint16() _, err = v.Uint16()
case TypePositiveIntegerField: case TypePositiveIntegerField:

View File

@ -6,11 +6,60 @@ import (
"strings" "strings"
"time" "time"
// _ "github.com/bylevel/pq"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
type Data struct {
Id int `orm:"auto"`
Boolean bool
Char string `orm:"size(50)"`
Text string
Date time.Time `orm:"type(date)"`
DateTime time.Time
Byte byte
Rune rune
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
Float32 float32
Float64 float64
Decimal float64 `orm:"digits(8);decimals(4)"`
}
type DataNull struct {
Id int `orm:"auto"`
Boolean bool `orm:"null"`
Char string `orm:"size(50);null"`
Text string `orm:"null"`
Date time.Time `orm:"type(date);null"`
DateTime time.Time `orm:"null"`
Byte byte `orm:"null"`
Rune rune `orm:"null"`
Int int `orm:"null"`
Int8 int8 `orm:"null"`
Int16 int16 `orm:"null"`
Int32 int32 `orm:"null"`
Int64 int64 `orm:"null"`
Uint uint `orm:"null"`
Uint8 uint8 `orm:"null"`
Uint16 uint16 `orm:"null"`
Uint32 uint32 `orm:"null"`
Uint64 uint64 `orm:"null"`
Float32 float32 `orm:"null"`
Float64 float64 `orm:"null"`
Decimal float64 `orm:"digits(8);decimals(4);null"`
}
type User struct { type User struct {
Id int `orm:"auto"` Id int `orm:"auto"`
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`
@ -111,6 +160,8 @@ var initSQLs = map[string]string{
"DROP TABLE IF EXISTS `tag`;\n" + "DROP TABLE IF EXISTS `tag`;\n" +
"DROP TABLE IF EXISTS `post_tags`;\n" + "DROP TABLE IF EXISTS `post_tags`;\n" +
"DROP TABLE IF EXISTS `comment`;\n" + "DROP TABLE IF EXISTS `comment`;\n" +
"DROP TABLE IF EXISTS `data`;\n" +
"DROP TABLE IF EXISTS `data_null`;\n" +
"CREATE TABLE `user_profile` (\n" + "CREATE TABLE `user_profile` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `age` smallint NOT NULL,\n" + " `age` smallint NOT NULL,\n" +
@ -153,6 +204,52 @@ var initSQLs = map[string]string{
" `parent_id` integer,\n" + " `parent_id` integer,\n" +
" `created` datetime NOT NULL\n" + " `created` datetime NOT NULL\n" +
") ENGINE=INNODB;\n" + ") ENGINE=INNODB;\n" +
"CREATE TABLE `data` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `boolean` bool NOT NULL,\n" +
" `char` varchar(50) NOT NULL,\n" +
" `text` longtext NOT NULL,\n" +
" `date` date NOT NULL,\n" +
" `date_time` datetime NOT NULL,\n" +
" `byte` tinyint unsigned NOT NULL,\n" +
" `rune` integer NOT NULL,\n" +
" `int` integer NOT NULL,\n" +
" `int8` tinyint NOT NULL,\n" +
" `int16` smallint NOT NULL,\n" +
" `int32` integer NOT NULL,\n" +
" `int64` bigint NOT NULL,\n" +
" `uint` integer unsigned NOT NULL,\n" +
" `uint8` tinyint unsigned NULL,\n" +
" `uint16` smallint unsigned NOT NULL,\n" +
" `uint32` integer unsigned NOT NULL,\n" +
" `uint64` bigint unsigned NOT NULL,\n" +
" `float32` double precision NOT NULL,\n" +
" `float64` double precision NOT NULL,\n" +
" `decimal` numeric(8,4) NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `data_null` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `boolean` bool,\n" +
" `char` varchar(50),\n" +
" `text` longtext,\n" +
" `date` date,\n" +
" `date_time` datetime,\n" +
" `byte` tinyint unsigned,\n" +
" `rune` integer,\n" +
" `int` integer,\n" +
" `int8` tinyint,\n" +
" `int16` smallint,\n" +
" `int32` integer,\n" +
" `int64` bigint,\n" +
" `uint` integer unsigned,\n" +
" `uint8` tinyint unsigned,\n" +
" `uint16` smallint unsigned,\n" +
" `uint32` integer unsigned,\n" +
" `uint64` bigint unsigned,\n" +
" `float32` double precision,\n" +
" `float64` double precision,\n" +
" `decimal` numeric(8,4)\n" +
") ENGINE=INNODB;\n" +
"CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" + "CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" +
"CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" + "CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" +
"CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" + "CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" +
@ -165,6 +262,8 @@ DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag"; DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags"; DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment"; DROP TABLE IF EXISTS "comment";
DROP TABLE IF EXISTS "data";
DROP TABLE IF EXISTS "data_null";
CREATE TABLE "user_profile" ( CREATE TABLE "user_profile" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"age" smallint NOT NULL, "age" smallint NOT NULL,
@ -207,6 +306,52 @@ CREATE TABLE "comment" (
"parent_id" integer, "parent_id" integer,
"created" datetime NOT NULL "created" datetime NOT NULL
); );
CREATE TABLE "data" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"boolean" bool NOT NULL,
"char" varchar(50) NOT NULL,
"text" text NOT NULL,
"date" date NOT NULL,
"date_time" datetime NOT NULL,
"byte" tinyint unsigned NOT NULL,
"rune" integer NOT NULL,
"int" integer NOT NULL,
"int8" tinyint NOT NULL,
"int16" smallint NOT NULL,
"int32" integer NOT NULL,
"int64" bigint NOT NULL,
"uint" integer unsigned NOT NULL,
"uint8" tinyint unsigned NOT NULL,
"uint16" smallint unsigned NOT NULL,
"uint32" integer unsigned NOT NULL,
"uint64" bigint unsigned NOT NULL,
"float32" real NOT NULL,
"float64" real NOT NULL,
"decimal" decimal
);
CREATE TABLE "data_null" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"boolean" bool,
"char" varchar(50),
"text" text,
"date" date,
"date_time" datetime,
"byte" tinyint unsigned,
"rune" integer,
"int" integer,
"int8" tinyint,
"int16" smallint,
"int32" integer,
"int64" bigint,
"uint" integer unsigned,
"uint8" tinyint unsigned,
"uint16" smallint unsigned,
"uint32" integer unsigned,
"uint64" bigint unsigned,
"float32" real,
"float64" real,
"decimal" decimal
);
CREATE INDEX "user_141c6eec" ON "user" ("profile_id"); CREATE INDEX "user_141c6eec" ON "user" ("profile_id");
CREATE INDEX "post_fbfc09f1" ON "post" ("user_id"); CREATE INDEX "post_fbfc09f1" ON "post" ("user_id");
CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id"); CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id");
@ -220,6 +365,8 @@ DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag"; DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags"; DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment"; DROP TABLE IF EXISTS "comment";
DROP TABLE IF EXISTS "data";
DROP TABLE IF EXISTS "data_null";
CREATE TABLE "user_profile" ( CREATE TABLE "user_profile" (
"id" serial NOT NULL PRIMARY KEY, "id" serial NOT NULL PRIMARY KEY,
"age" smallint NOT NULL, "age" smallint NOT NULL,
@ -262,6 +409,52 @@ CREATE TABLE "comment" (
"parent_id" integer, "parent_id" integer,
"created" timestamp with time zone NOT NULL "created" timestamp with time zone NOT NULL
); );
CREATE TABLE "data" (
"id" serial NOT NULL PRIMARY KEY,
"boolean" bool NOT NULL,
"char" varchar(50) NOT NULL,
"text" text NOT NULL,
"date" date NOT NULL,
"date_time" timestamp with time zone NOT NULL,
"byte" smallint CHECK("byte" >= 0 AND "byte" <= 255) NOT NULL,
"rune" integer NOT NULL,
"int" integer NOT NULL,
"int8" smallint CHECK("int8" >= -127 AND "int8" <= 128) NOT NULL,
"int16" smallint NOT NULL,
"int32" integer NOT NULL,
"int64" bigint NOT NULL,
"uint" bigint CHECK("uint" >= 0) NOT NULL,
"uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255) NOT NULL,
"uint16" integer CHECK("uint16" >= 0) NOT NULL,
"uint32" bigint CHECK("uint32" >= 0) NOT NULL,
"uint64" bigint CHECK("uint64" >= 0) NOT NULL,
"float32" double precision NOT NULL,
"float64" double precision NOT NULL,
"decimal" numeric(8, 4)
);
CREATE TABLE "data_null" (
"id" serial NOT NULL PRIMARY KEY,
"boolean" bool,
"char" varchar(50),
"text" text,
"date" date,
"date_time" timestamp with time zone,
"byte" smallint CHECK("byte" >= 0 AND "byte" <= 255),
"rune" integer,
"int" integer,
"int8" smallint CHECK("int8" >= -127 AND "int8" <= 128),
"int16" smallint,
"int32" integer,
"int64" bigint,
"uint" bigint CHECK("uint" >= 0),
"uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255),
"uint16" integer CHECK("uint16" >= 0),
"uint32" bigint CHECK("uint32" >= 0),
"uint64" bigint CHECK("uint64" >= 0),
"float32" double precision,
"float64" double precision,
"decimal" numeric(8, 4)
);
CREATE INDEX "user_profile_id" ON "user" ("profile_id"); CREATE INDEX "user_profile_id" ON "user" ("profile_id");
CREATE INDEX "post_user_id" ON "post" ("user_id"); CREATE INDEX "post_user_id" ON "post" ("user_id");
CREATE INDEX "comment_post_id" ON "comment" ("post_id"); CREATE INDEX "comment_post_id" ON "comment" ("post_id");
@ -269,6 +462,10 @@ CREATE INDEX "comment_parent_id" ON "comment" ("parent_id");
`} `}
func init() { func init() {
// err := os.Setenv("TZ", "+00:00")
// fmt.Println(err)
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))

View File

@ -43,15 +43,19 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Kind() {
case reflect.Int8:
ft = TypeBitField
case reflect.Int16: case reflect.Int16:
ft = TypeSmallIntegerField ft = TypeSmallIntegerField
case reflect.Int32, reflect.Int: case reflect.Int32, reflect.Int:
ft = TypeIntegerField ft = TypeIntegerField
case reflect.Int64: case reflect.Int64:
ft = TypeBigIntegerField ft = TypeBigIntegerField
case reflect.Uint8:
ft = TypePostiveBitField
case reflect.Uint16: case reflect.Uint16:
ft = TypePositiveSmallIntegerField ft = TypePositiveSmallIntegerField
case reflect.Uint32: case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField ft = TypePositiveIntegerField
case reflect.Uint64: case reflect.Uint64:
ft = TypePositiveBigIntegerField ft = TypePositiveBigIntegerField

View File

@ -55,7 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
func (o *orm) Read(md interface{}) error { func (o *orm) Read(md interface{}) error {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
err := o.alias.DbBaser.Read(o.db, mi, ind) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return err return err
} }
@ -64,7 +64,7 @@ func (o *orm) Read(md interface{}) error {
func (o *orm) Insert(md interface{}) (int64, error) { func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
@ -78,7 +78,7 @@ func (o *orm) Insert(md interface{}) (int64, error) {
func (o *orm) Update(md interface{}) (int64, error) { func (o *orm) Update(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Update(o.db, mi, ind) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return num, err return num, err
} }
@ -87,7 +87,7 @@ func (o *orm) Update(md interface{}) (int64, error) {
func (o *orm) Delete(md interface{}) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return num, err return num, err
} }

View File

@ -28,7 +28,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
if name != o.mi.fullName { if name != o.mi.fullName {
panic(fmt.Sprintf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name)) panic(fmt.Sprintf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
} }
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind) id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }

View File

@ -77,15 +77,15 @@ func (o querySet) SetCond(cond *Condition) QuerySeter {
} }
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond) return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
func (o *querySet) Update(values Params) (int64, error) { func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values) return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
} }
func (o *querySet) Delete() (int64, error) { func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond) return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
func (o *querySet) PrepareInsert() (Inserter, error) { func (o *querySet) PrepareInsert() (Inserter, error) {
@ -93,11 +93,11 @@ func (o *querySet) PrepareInsert() (Inserter, error) {
} }
func (o *querySet) All(container interface{}) (int64, error) { func (o *querySet) All(container interface{}) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
} }
func (o *querySet) One(container interface{}) error { func (o *querySet) One(container interface{}) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
if err != nil { if err != nil {
return err return err
} }
@ -111,15 +111,15 @@ func (o *querySet) One(container interface{}) error {
} }
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
} }
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {

View File

@ -60,7 +60,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
return o.orm.db.Exec(query, args...) return o.orm.db.Exec(query, args...)
} }
@ -96,7 +96,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil { if r, err := o.orm.db.Query(query, args...); err != nil {

View File

@ -15,6 +15,11 @@ import (
var _ = os.PathSeparator var _ = os.PathSeparator
var (
test_Date = format_Date + " -0700"
test_DateTime = format_DateTime + " -0700"
)
type T_Code int type T_Code int
const ( const (
@ -141,7 +146,7 @@ func getCaller(skip int) string {
if cur == line { if cur == line {
flag = ">>" flag = ">>"
} }
code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.TrimSpace(string(lines[o+i]))) code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1))
if code != "" { if code != "" {
codes = append(codes, code) codes = append(codes, code)
} }
@ -158,7 +163,11 @@ func throwFail(t *testing.T, err error, args ...interface{}) {
if err != nil { if err != nil {
con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
if len(args) > 0 { if len(args) > 0 {
con += fmt.Sprint(args...) parts := make([]string, 0, len(args))
for _, arg := range args {
parts = append(parts, fmt.Sprintf("%v", arg))
}
con += " " + strings.Join(parts, ", ")
} }
t.Error(con) t.Error(con)
t.Fail() t.Fail()
@ -169,7 +178,11 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
if err != nil { if err != nil {
con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
if len(args) > 0 { if len(args) > 0 {
con += fmt.Sprint(args...) parts := make([]string, 0, len(args))
for _, arg := range args {
parts = append(parts, fmt.Sprintf("%v", arg))
}
con += " " + strings.Join(parts, ", ")
} }
t.Error(con) t.Error(con)
t.FailNow() t.FailNow()
@ -177,13 +190,100 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
} }
func TestModelSyntax(t *testing.T) { func TestModelSyntax(t *testing.T) {
mi, ok := modelCache.get("user") user := &User{}
ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type())
mi, ok := modelCache.getByFN(fn)
throwFail(t, AssertIs(ok, T_Equal, true))
mi, ok = modelCache.get("user")
throwFail(t, AssertIs(ok, T_Equal, true)) throwFail(t, AssertIs(ok, T_Equal, true))
if ok { if ok {
throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true)) throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true))
} }
} }
func TestDataTypes(t *testing.T) {
values := map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
"Date": time.Now(),
"DateTime": time.Now(),
"Byte": byte(1<<8 - 1),
"Rune": rune(1<<31 - 1),
"Int": int(1<<31 - 1),
"Int8": int8(1<<7 - 1),
"Int16": int16(1<<15 - 1),
"Int32": int32(1<<31 - 1),
"Int64": int64(1<<63 - 1),
"Uint": uint(1<<32 - 1),
"Uint8": uint8(1<<8 - 1),
"Uint16": uint16(1<<16 - 1),
"Uint32": uint32(1<<32 - 1),
"Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported
"Float32": float32(100.1234),
"Float64": float64(100.1234),
"Decimal": float64(100.1234),
}
d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range values {
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, T_Equal, 1))
d = Data{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, T_Equal, true), value, vu)
}
}
func TestNullDataTypes(t *testing.T) {
d := DataNull{}
if IsPostgres {
// can removed when this fixed
// https://github.com/lib/pq/pull/125
d.DateTime = time.Now()
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, T_Equal, 1))
d = DataNull{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
d = DataNull{Id: 2}
err = dORM.Read(&d)
throwFail(t, err)
}
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
profile := NewProfile() profile := NewProfile()
profile.Age = 30 profile.Age = 30
@ -214,8 +314,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, T_Equal, 3)) throwFail(t, AssertIs(u.Status, T_Equal, 3))
throwFail(t, AssertIs(u.IsStaff, T_Equal, true)) throwFail(t, AssertIs(u.IsStaff, T_Equal, true))
throwFail(t, AssertIs(u.IsActive, T_Equal, true)) throwFail(t, AssertIs(u.IsActive, T_Equal, true))
throwFail(t, AssertIs(u.Created, T_Equal, user.Created, format_Date)) throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), T_Equal, user.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime)) throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), T_Equal, user.Updated.In(DefaultTimeLoc), test_DateTime))
user.UserName = "astaxie" user.UserName = "astaxie"
user.Profile = profile user.Profile = profile
@ -360,7 +460,9 @@ The program—and web server—godoc processes Go source files to extract docume
} }
func TestExpr(t *testing.T) { func TestExpr(t *testing.T) {
qs := dORM.QueryTable("User") user := &User{}
qs := dORM.QueryTable(user)
qs = dORM.QueryTable("User")
qs = dORM.QueryTable("user") qs = dORM.QueryTable("user")
num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count()
throwFail(t, err) throwFail(t, err)
@ -369,6 +471,10 @@ func TestExpr(t *testing.T) {
num, err = qs.Filter("created", time.Now()).Count() num, err = qs.Filter("created", time.Now()).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3)) throwFail(t, AssertIs(num, T_Equal, 3))
num, err = qs.Filter("created", time.Now().Format(format_Date)).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
} }
func TestOperators(t *testing.T) { func TestOperators(t *testing.T) {
@ -820,9 +926,11 @@ func TestRaw(t *testing.T) {
res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec()
throwFail(t, err) throwFail(t, err)
num, err := res.RowsAffected() if err == nil {
throwFail(t, err) num, err := res.RowsAffected()
throwFail(t, AssertIs(num, T_Equal, 3)) throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
}
} }
} }
} }

View File

@ -3,6 +3,7 @@ package orm
import ( import (
"database/sql" "database/sql"
"reflect" "reflect"
"time"
) )
type Driver interface { type Driver interface {
@ -110,23 +111,25 @@ type txEnder interface {
} }
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error
Insert(dbQuerier, *modelInfo, reflect.Value) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error) ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error)
SupportUpdateJoin() bool SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error) UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string OperatorSql(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{}) GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool HasReturningID(*modelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
} }

View File

@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64) return strconv.ParseFloat(f.String(), 64)
} }
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
func (f StrTo) Int16() (int16, error) { func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16) v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err return int16(v), err
@ -53,6 +58,11 @@ func (f StrTo) Int64() (int64, error) {
return int64(v), err return int64(v), err
} }
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
func (f StrTo) Uint16() (uint16, error) { func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16) v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err return uint16(v), err
@ -85,6 +95,8 @@ func ToStr(value interface{}, args ...int) (s string) {
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int: case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16: case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32: case int32:
@ -93,6 +105,8 @@ func ToStr(value interface{}, args ...int) (s string) {
s = strconv.FormatInt(v, argInt(args).Get(0, 10)) s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint: case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16: case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32: case uint32: