From 27b84841a7ba80236256b2090337e98feed52fd3 Mon Sep 17 00:00:00 2001 From: slene Date: Tue, 13 Aug 2013 17:16:12 +0800 Subject: [PATCH] orm add full regular go type support, such as int8, uint8, byte, rune. add date/datetime timezone support very well. --- orm/db.go | 141 +++++++++++++++++++------------ orm/db_alias.go | 46 +++++++++- orm/db_postgres.go | 2 +- orm/db_sqlite.go | 10 +++ orm/db_tables.go | 9 +- orm/db_utils.go | 8 +- orm/models_fields.go | 8 +- orm/models_info_f.go | 8 +- orm/models_test.go | 197 +++++++++++++++++++++++++++++++++++++++++++ orm/models_utils.go | 6 +- orm/orm.go | 8 +- orm/orm_object.go | 2 +- orm/orm_queryset.go | 16 ++-- orm/orm_raw.go | 4 +- orm/orm_test.go | 128 +++++++++++++++++++++++++--- orm/types.go | 27 +++--- orm/utils.go | 14 +++ 17 files changed, 527 insertions(+), 107 deletions(-) diff --git a/orm/db.go b/orm/db.go index fbee123d..ac28f500 100644 --- a/orm/db.go +++ b/orm/db.go @@ -49,7 +49,7 @@ type dbBase struct { ins dbBaser } -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { _, pkValue, _ := getExistPk(mi, ind) for _, column := range mi.fields.orders { fi := mi.fields.columns[column] @@ -71,9 +71,22 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, case TypeCharField, TypeTextField: value = field.String() case TypeFloatField, TypeDecimalField: - value = field.Float() + vu := field.Interface() + if _, ok := vu.(float32); ok { + value, _ = StrTo(ToStr(vu)).Float64() + } else { + value = field.Float() + } case TypeDateField, TypeDateTimeField: value = field.Interface() + if t, ok := value.(time.Time); ok { + if fi.fieldType == TypeDateField { + d.ins.TimeToDB(&t, DefaultTimeLoc) + } else { + d.ins.TimeToDB(&t, tz) + } + value = t + } default: switch { case fi.fieldType&IsPostiveIntegerField > 0: @@ -101,15 +114,16 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, if fi.auto_now || fi.auto_now_add && insert { tnow := time.Now() if fi.fieldType == TypeDateField { - value = timeFormat(tnow, format_Date) + d.ins.TimeToDB(&tnow, DefaultTimeLoc) } else { - value = timeFormat(tnow, format_DateTime) + d.ins.TimeToDB(&tnow, tz) } + value = tnow if fi.isFielder { f := field.Addr().Interface().(Fielder) - f.SetRaw(tnow) + f.SetRaw(tnow.In(DefaultTimeLoc)) } else { - field.Set(reflect.ValueOf(tnow)) + field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) } } } @@ -145,8 +159,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, return stmt, query, err } -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { - _, values, err := d.collectValues(mi, ind, true, true) +func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + _, values, err := d.collectValues(mi, ind, true, true, tz) if err != nil { return 0, err } @@ -165,7 +179,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) } } -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { +func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error { pkColumn, pkValue, ok := getExistPk(mi, ind) if ok == false { return ErrMissPK @@ -187,6 +201,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { d.ins.ReplaceMarks(&query) + if len(refs) == 21 { + fmt.Println(query, pkValue) + } + row := q.QueryRow(query, pkValue) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { @@ -197,7 +215,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { elm := reflect.New(mi.addrField.Elem().Type()) mind := reflect.Indirect(elm) - d.setColsValues(mi, &mind, mi.fields.dbcols, refs) + d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) ind.Set(mind) } @@ -205,8 +223,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { return nil } -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { - names, values, err := d.collectValues(mi, ind, true, true) +func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + names, values, err := d.collectValues(mi, ind, true, true, tz) if err != nil { return 0, err } @@ -240,12 +258,12 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e } } -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { +func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { return 0, ErrMissPK } - setNames, setValues, err := d.collectValues(mi, ind, true, false) + setNames, setValues, err := d.collectValues(mi, ind, true, false, tz) if err != nil { return 0, err } @@ -269,7 +287,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e return 0, nil } -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { +func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { return 0, ErrMissPK @@ -293,7 +311,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ind.Field(mi.fields.pk.fieldIndex).SetInt(0) } - err := d.deleteRels(q, mi, []interface{}{pkValue}) + err := d.deleteRels(q, mi, []interface{}{pkValue}, tz) if err != nil { return num, err } @@ -306,7 +324,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e return 0, nil } -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) { +func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { @@ -327,7 +345,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con tables.parseRelated(qs.related, qs.relDepth) } - where, args := tables.getCondSql(cond, false) + where, args := tables.getCondSql(cond, false, tz) values = append(values, args...) @@ -356,13 +374,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error { +func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case od_CASCADE: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond) + _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) if err != nil { return err } @@ -372,7 +390,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro if fi.onDelete == od_SET_DEFAULT { params[fi.column] = fi.initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params) + _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) if err != nil { return err } @@ -382,7 +400,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro return nil } -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) { +func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) if qs != nil { tables.parseRelated(qs.related, qs.relDepth) @@ -394,7 +412,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con Q := d.ins.TableQuote() - where, args := tables.getCondSql(cond, false) + where, args := tables.getCondSql(cond, false, tz) join := tables.getJoinSql() cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) @@ -425,7 +443,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } - sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args) + marks := make([]string, len(args)) + for i, _ := range marks { + marks[i] = "?" + } + sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) d.ins.ReplaceMarks(&query) @@ -437,7 +459,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } if num > 0 { - err := d.deleteRels(q, mi, args) + err := d.deleteRels(q, mi, args, tz) if err != nil { return num, err } @@ -451,7 +473,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) { +func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -490,7 +512,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) - where, args := tables.getCondSql(cond, false) + where, args := tables.getCondSql(cond, false, tz) orderBy := tables.getOrderSql(qs.orders) limit := tables.getLimitSql(mi, offset, rlimit) join := tables.getJoinSql() @@ -539,7 +561,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi cacheM := make(map[string]*modelInfo) trefs := refs - d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)]) + d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz) trefs = refs[len(mi.fields.dbcols):] for _, tbl := range tables.tables { @@ -558,7 +580,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi mmi := fi.relModelInfo field := reflect.Indirect(last.Field(fi.fieldIndex)) if field.IsValid() { - d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)]) + d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) for _, fi := range mmi.fields.fieldsReverse { if fi.reverseFieldInfo.mi == lastm { if fi.reverseFieldInfo != nil { @@ -592,11 +614,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi return cnt, nil } -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) { +func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) - where, args := tables.getCondSql(cond, false) + where, args := tables.getCondSql(cond, false, tz) tables.getOrderSql(qs.orders) join := tables.getJoinSql() @@ -612,9 +634,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition return } -func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) { +func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { sql := "" - params := getFlatParams(fi, args) + params := getFlatParams(fi, args, tz) if len(params) == 0 { panic(fmt.Sprintf("operator `%s` need at least one args", operator)) @@ -665,11 +687,11 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri return sql, params } -func (d *dbBase) GenerateOperatorLeftCol(string, *string) { - +func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { + // default not use } -func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { +func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() @@ -677,12 +699,12 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, field := ind.Field(fi.fieldIndex) - value, err := d.getValue(fi, val) + value, err := d.convertValueFromDB(fi, val, tz) if err != nil { panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error())) } - _, err = d.setValue(fi, value, &field) + _, err = d.setFieldValue(fi, value, &field) if err != nil { panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error())) @@ -690,7 +712,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, } } -func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) { +func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { if val == nil { return nil, nil } @@ -739,29 +761,32 @@ setValue: } case fieldType == TypeDateField || fieldType == TypeDateTimeField: if str == nil { - switch v := val.(type) { + switch t := val.(type) { case time.Time: - value = v + d.ins.TimeFromDB(&t, tz) + value = t default: - s := StrTo(ToStr(v)) + s := StrTo(ToStr(t)) str = &s } } if str != nil { s := str.String() - var format string + var ( + t time.Time + err error + ) if fi.fieldType == TypeDateField { - format = format_Date if len(s) > 10 { s = s[:10] } + t, err = time.ParseInLocation(format_Date, s, DefaultTimeLoc) } else { - format = format_DateTime if len(s) > 19 { s = s[:19] } + t, err = time.ParseInLocation(format_DateTime, s, tz) } - t, err := timeParse(s, format) if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { tErr = err goto end @@ -776,12 +801,16 @@ setValue: if str != nil { var err error switch fieldType { + case TypeBitField: + _, err = str.Int8() case TypeSmallIntegerField: _, err = str.Int16() case TypeIntegerField: _, err = str.Int32() case TypeBigIntegerField: _, err = str.Int64() + case TypePostiveBitField: + _, err = str.Uint8() case TypePositiveSmallIntegerField: _, err = str.Uint16() case TypePositiveIntegerField: @@ -835,7 +864,7 @@ end: } -func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) { +func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) { fieldType := fi.fieldType isNative := fi.isFielder == false @@ -909,7 +938,7 @@ setValue: return value, nil } -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) { +func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params @@ -960,7 +989,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond } } - where, args := tables.getCondSql(cond, false) + where, args := tables.getCondSql(cond, false, tz) orderBy := tables.getOrderSql(qs.orders) limit := tables.getLimitSql(mi, qs.offset, qs.limit) join := tables.getJoinSql() @@ -1007,7 +1036,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - value, err := d.getValue(fi, val) + value, err := d.convertValueFromDB(fi, val, tz) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } @@ -1022,7 +1051,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - value, err := d.getValue(fi, val) + value, err := d.convertValueFromDB(fi, val, tz) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } @@ -1036,7 +1065,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - value, err := d.getValue(fi, val) + value, err := d.convertValueFromDB(fi, val, tz) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } @@ -1079,3 +1108,11 @@ func (d *dbBase) ReplaceMarks(query *string) { func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } + +func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} diff --git a/orm/db_alias.go b/orm/db_alias.go index 2670d1f8..a4615129 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "sync" + "time" ) const defaultMaxIdle = 30 @@ -82,6 +83,7 @@ type alias struct { MaxIdle int DB *sql.DB DbBaser dbBaser + TZ *time.Location } func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { @@ -120,6 +122,33 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { al.DB.SetMaxIdleConns(al.MaxIdle) + // orm timezone system match database + // default use Local + al.TZ = time.Local + + switch al.Driver { + case DR_MySQL: + row := al.DB.QueryRow("SELECT @@session.time_zone") + var tz string + row.Scan(&tz) + if tz != "SYSTEM" { + t, err := time.Parse("-07:00", tz) + if err == nil { + al.TZ = t.Location() + } + } + case DR_Sqlite: + al.TZ = time.UTC + case DR_Postgres: + row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") + var tz string + row.Scan(&tz) + loc, err := time.LoadLocation(tz) + if err == nil { + al.TZ = loc + } + } + err = al.DB.Ping() if err != nil { err = fmt.Errorf("register db `%s`, %s", name, err.Error()) @@ -133,13 +162,22 @@ end: } } -func RegisterDriver(name string, typ DriverType) { - if t, ok := drivers[name]; ok == false { - drivers[name] = typ +func RegisterDriver(driverName string, typ DriverType) { + if t, ok := drivers[driverName]; ok == false { + drivers[driverName] = typ } else { if t != typ { - fmt.Println("name `%s` db driver already registered and is other type") + fmt.Println("driverName `%s` db driver already registered and is other type") os.Exit(2) } } } + +func SetDataBaseTZ(name string, tz *time.Location) { + if al, ok := dataBaseCache.get(name); ok { + al.TZ = tz + } else { + err := fmt.Errorf("DataBase name `%s` not registered", name) + fmt.Println(err) + } +} diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 4bfc5a83..729e6aab 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -30,7 +30,7 @@ func (d *dbBasePostgres) OperatorSql(operator string) string { return postgresOperators[operator] } -func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) { +func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { switch operator { case "contains", "startswith", "endswith": *leftCol = fmt.Sprintf("%s::text", *leftCol) diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index e5a234e5..efea1967 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -1,5 +1,9 @@ package orm +import ( + "fmt" +) + var sqliteOperators = map[string]string{ "exact": "= ?", "iexact": "LIKE ? ESCAPE '\\'", @@ -25,6 +29,12 @@ func (d *dbBaseSqlite) OperatorSql(operator string) string { return sqliteOperators[operator] } +func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + if fi.fieldType == TypeDateField { + *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) + } +} + func (d *dbBaseSqlite) SupportUpdateJoin() bool { return false } diff --git a/orm/db_tables.go b/orm/db_tables.go index 5cf567cc..afed3c96 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -3,6 +3,7 @@ package orm import ( "fmt" "strings" + "time" ) type dbTable struct { @@ -266,7 +267,7 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string return } -func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) { +func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { if cond == nil || cond.IsEmpty() { return } @@ -288,7 +289,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ where += "NOT " } if p.isCond { - w, ps := d.getCondSql(p.cond, true) + w, ps := d.getCondSql(p.cond, true, tz) if w != "" { w = fmt.Sprintf("( %s) ", w) } @@ -313,10 +314,10 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ operator = "exact" } - operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args) + operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args, tz) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) - d.base.GenerateOperatorLeftCol(operator, &leftCol) + d.base.GenerateOperatorLeftCol(fi, operator, &leftCol) where += fmt.Sprintf("%s %s ", leftCol, operSql) params = append(params, args...) diff --git a/orm/db_utils.go b/orm/db_utils.go index a31c5221..0773f457 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -24,7 +24,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac return } -func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) { +func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { outFor: for _, arg := range args { @@ -39,9 +39,9 @@ outFor: case []byte: case time.Time: if fi != nil && fi.fieldType == TypeDateField { - arg = v.Format(format_Date) + arg = v.In(DefaultTimeLoc).Format(format_Date) } else { - arg = v.Format(format_DateTime) + arg = v.In(tz).Format(format_DateTime) } default: kind := val.Kind() @@ -65,7 +65,7 @@ outFor: } if len(args) > 0 { - p := getFlatParams(fi, args) + p := getFlatParams(fi, args, tz) params = append(params, p...) } continue outFor diff --git a/orm/models_fields.go b/orm/models_fields.go index 0f09c1e9..269e8580 100644 --- a/orm/models_fields.go +++ b/orm/models_fields.go @@ -22,12 +22,16 @@ const ( // time.Time TypeDateTimeField + // int8 + TypeBitField // int16 TypeSmallIntegerField // int32 TypeIntegerField // int64 TypeBigIntegerField + // uint8 + TypePostiveBitField // uint16 TypePositiveSmallIntegerField // uint32 @@ -49,8 +53,8 @@ const ( const ( IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 - IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 7 << 8 - IsRelField = ^-RelReverseMany >> 12 << 13 + IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 + IsRelField = ^-RelReverseMany >> 14 << 15 IsFieldType = ^-RelReverseMany<<1 + 1 ) diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 0cbbdf65..517179f4 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -327,8 +327,8 @@ checkType: case TypeDecimalField: d1 := digits d2 := decimals - v1, er1 := StrTo(d1).Int16() - v2, er2 := StrTo(d2).Int16() + v1, er1 := StrTo(d1).Int8() + v2, er2 := StrTo(d2).Int8() if er1 != nil || er2 != nil { err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) goto end @@ -383,12 +383,16 @@ checkType: _, err = v.Bool() case TypeFloatField, TypeDecimalField: _, err = v.Float64() + case TypeBitField: + _, err = v.Int8() case TypeSmallIntegerField: _, err = v.Int16() case TypeIntegerField: _, err = v.Int32() case TypeBigIntegerField: _, err = v.Int64() + case TypePostiveBitField: + _, err = v.Uint8() case TypePositiveSmallIntegerField: _, err = v.Uint16() case TypePositiveIntegerField: diff --git a/orm/models_test.go b/orm/models_test.go index eca35553..80923603 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -6,11 +6,60 @@ import ( "strings" "time" + // _ "github.com/bylevel/pq" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) +type Data struct { + Id int `orm:"auto"` + Boolean bool + Char string `orm:"size(50)"` + Text string + Date time.Time `orm:"type(date)"` + DateTime time.Time + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 `orm:"digits(8);decimals(4)"` +} + +type DataNull struct { + Id int `orm:"auto"` + Boolean bool `orm:"null"` + Char string `orm:"size(50);null"` + Text string `orm:"null"` + Date time.Time `orm:"type(date);null"` + DateTime time.Time `orm:"null"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` +} + type User struct { Id int `orm:"auto"` UserName string `orm:"size(30);unique"` @@ -111,6 +160,8 @@ var initSQLs = map[string]string{ "DROP TABLE IF EXISTS `tag`;\n" + "DROP TABLE IF EXISTS `post_tags`;\n" + "DROP TABLE IF EXISTS `comment`;\n" + + "DROP TABLE IF EXISTS `data`;\n" + + "DROP TABLE IF EXISTS `data_null`;\n" + "CREATE TABLE `user_profile` (\n" + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + " `age` smallint NOT NULL,\n" + @@ -153,6 +204,52 @@ var initSQLs = map[string]string{ " `parent_id` integer,\n" + " `created` datetime NOT NULL\n" + ") ENGINE=INNODB;\n" + + "CREATE TABLE `data` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `boolean` bool NOT NULL,\n" + + " `char` varchar(50) NOT NULL,\n" + + " `text` longtext NOT NULL,\n" + + " `date` date NOT NULL,\n" + + " `date_time` datetime NOT NULL,\n" + + " `byte` tinyint unsigned NOT NULL,\n" + + " `rune` integer NOT NULL,\n" + + " `int` integer NOT NULL,\n" + + " `int8` tinyint NOT NULL,\n" + + " `int16` smallint NOT NULL,\n" + + " `int32` integer NOT NULL,\n" + + " `int64` bigint NOT NULL,\n" + + " `uint` integer unsigned NOT NULL,\n" + + " `uint8` tinyint unsigned NULL,\n" + + " `uint16` smallint unsigned NOT NULL,\n" + + " `uint32` integer unsigned NOT NULL,\n" + + " `uint64` bigint unsigned NOT NULL,\n" + + " `float32` double precision NOT NULL,\n" + + " `float64` double precision NOT NULL,\n" + + " `decimal` numeric(8,4) NOT NULL\n" + + ") ENGINE=INNODB;\n" + + "CREATE TABLE `data_null` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `boolean` bool,\n" + + " `char` varchar(50),\n" + + " `text` longtext,\n" + + " `date` date,\n" + + " `date_time` datetime,\n" + + " `byte` tinyint unsigned,\n" + + " `rune` integer,\n" + + " `int` integer,\n" + + " `int8` tinyint,\n" + + " `int16` smallint,\n" + + " `int32` integer,\n" + + " `int64` bigint,\n" + + " `uint` integer unsigned,\n" + + " `uint8` tinyint unsigned,\n" + + " `uint16` smallint unsigned,\n" + + " `uint32` integer unsigned,\n" + + " `uint64` bigint unsigned,\n" + + " `float32` double precision,\n" + + " `float64` double precision,\n" + + " `decimal` numeric(8,4)\n" + + ") ENGINE=INNODB;\n" + "CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" + "CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" + "CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" + @@ -165,6 +262,8 @@ DROP TABLE IF EXISTS "post"; DROP TABLE IF EXISTS "tag"; DROP TABLE IF EXISTS "post_tags"; DROP TABLE IF EXISTS "comment"; +DROP TABLE IF EXISTS "data"; +DROP TABLE IF EXISTS "data_null"; CREATE TABLE "user_profile" ( "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "age" smallint NOT NULL, @@ -207,6 +306,52 @@ CREATE TABLE "comment" ( "parent_id" integer, "created" datetime NOT NULL ); +CREATE TABLE "data" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "boolean" bool NOT NULL, + "char" varchar(50) NOT NULL, + "text" text NOT NULL, + "date" date NOT NULL, + "date_time" datetime NOT NULL, + "byte" tinyint unsigned NOT NULL, + "rune" integer NOT NULL, + "int" integer NOT NULL, + "int8" tinyint NOT NULL, + "int16" smallint NOT NULL, + "int32" integer NOT NULL, + "int64" bigint NOT NULL, + "uint" integer unsigned NOT NULL, + "uint8" tinyint unsigned NOT NULL, + "uint16" smallint unsigned NOT NULL, + "uint32" integer unsigned NOT NULL, + "uint64" bigint unsigned NOT NULL, + "float32" real NOT NULL, + "float64" real NOT NULL, + "decimal" decimal +); +CREATE TABLE "data_null" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "boolean" bool, + "char" varchar(50), + "text" text, + "date" date, + "date_time" datetime, + "byte" tinyint unsigned, + "rune" integer, + "int" integer, + "int8" tinyint, + "int16" smallint, + "int32" integer, + "int64" bigint, + "uint" integer unsigned, + "uint8" tinyint unsigned, + "uint16" smallint unsigned, + "uint32" integer unsigned, + "uint64" bigint unsigned, + "float32" real, + "float64" real, + "decimal" decimal +); CREATE INDEX "user_141c6eec" ON "user" ("profile_id"); CREATE INDEX "post_fbfc09f1" ON "post" ("user_id"); CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id"); @@ -220,6 +365,8 @@ DROP TABLE IF EXISTS "post"; DROP TABLE IF EXISTS "tag"; DROP TABLE IF EXISTS "post_tags"; DROP TABLE IF EXISTS "comment"; +DROP TABLE IF EXISTS "data"; +DROP TABLE IF EXISTS "data_null"; CREATE TABLE "user_profile" ( "id" serial NOT NULL PRIMARY KEY, "age" smallint NOT NULL, @@ -262,6 +409,52 @@ CREATE TABLE "comment" ( "parent_id" integer, "created" timestamp with time zone NOT NULL ); +CREATE TABLE "data" ( + "id" serial NOT NULL PRIMARY KEY, + "boolean" bool NOT NULL, + "char" varchar(50) NOT NULL, + "text" text NOT NULL, + "date" date NOT NULL, + "date_time" timestamp with time zone NOT NULL, + "byte" smallint CHECK("byte" >= 0 AND "byte" <= 255) NOT NULL, + "rune" integer NOT NULL, + "int" integer NOT NULL, + "int8" smallint CHECK("int8" >= -127 AND "int8" <= 128) NOT NULL, + "int16" smallint NOT NULL, + "int32" integer NOT NULL, + "int64" bigint NOT NULL, + "uint" bigint CHECK("uint" >= 0) NOT NULL, + "uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255) NOT NULL, + "uint16" integer CHECK("uint16" >= 0) NOT NULL, + "uint32" bigint CHECK("uint32" >= 0) NOT NULL, + "uint64" bigint CHECK("uint64" >= 0) NOT NULL, + "float32" double precision NOT NULL, + "float64" double precision NOT NULL, + "decimal" numeric(8, 4) +); +CREATE TABLE "data_null" ( + "id" serial NOT NULL PRIMARY KEY, + "boolean" bool, + "char" varchar(50), + "text" text, + "date" date, + "date_time" timestamp with time zone, + "byte" smallint CHECK("byte" >= 0 AND "byte" <= 255), + "rune" integer, + "int" integer, + "int8" smallint CHECK("int8" >= -127 AND "int8" <= 128), + "int16" smallint, + "int32" integer, + "int64" bigint, + "uint" bigint CHECK("uint" >= 0), + "uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255), + "uint16" integer CHECK("uint16" >= 0), + "uint32" bigint CHECK("uint32" >= 0), + "uint64" bigint CHECK("uint64" >= 0), + "float32" double precision, + "float64" double precision, + "decimal" numeric(8, 4) +); CREATE INDEX "user_profile_id" ON "user" ("profile_id"); CREATE INDEX "post_user_id" ON "post" ("user_id"); CREATE INDEX "comment_post_id" ON "comment" ("post_id"); @@ -269,6 +462,10 @@ CREATE INDEX "comment_parent_id" ON "comment" ("parent_id"); `} func init() { + // err := os.Setenv("TZ", "+00:00") + // fmt.Println(err) + + RegisterModel(new(Data), new(DataNull)) RegisterModel(new(User)) RegisterModel(new(Profile)) RegisterModel(new(Post)) diff --git a/orm/models_utils.go b/orm/models_utils.go index d15447a1..1efaad97 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -43,15 +43,19 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col func getFieldType(val reflect.Value) (ft int, err error) { elm := reflect.Indirect(val) switch elm.Kind() { + case reflect.Int8: + ft = TypeBitField case reflect.Int16: ft = TypeSmallIntegerField case reflect.Int32, reflect.Int: ft = TypeIntegerField case reflect.Int64: ft = TypeBigIntegerField + case reflect.Uint8: + ft = TypePostiveBitField case reflect.Uint16: ft = TypePositiveSmallIntegerField - case reflect.Uint32: + case reflect.Uint32, reflect.Uint: ft = TypePositiveIntegerField case reflect.Uint64: ft = TypePositiveBigIntegerField diff --git a/orm/orm.go b/orm/orm.go index 868dde58..6be35ae9 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -55,7 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { func (o *orm) Read(md interface{}) error { mi, ind := o.getMiInd(md) - err := o.alias.DbBaser.Read(o.db, mi, ind) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ) if err != nil { return err } @@ -64,7 +64,7 @@ func (o *orm) Read(md interface{}) error { func (o *orm) Insert(md interface{}) (int64, error) { mi, ind := o.getMiInd(md) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } @@ -78,7 +78,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { func (o *orm) Update(md interface{}) (int64, error) { mi, ind := o.getMiInd(md) - num, err := o.alias.DbBaser.Update(o.db, mi, ind) + num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ) if err != nil { return num, err } @@ -87,7 +87,7 @@ func (o *orm) Update(md interface{}) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) { mi, ind := o.getMiInd(md) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind) + num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) if err != nil { return num, err } diff --git a/orm/orm_object.go b/orm/orm_object.go index 819a18bc..ee6566fc 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -28,7 +28,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Sprintf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind) + id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 2f8c270f..9fe62ef3 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -77,15 +77,15 @@ func (o querySet) SetCond(cond *Condition) QuerySeter { } func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond) + return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values) + return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond) + return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } func (o *querySet) PrepareInsert() (Inserter, error) { @@ -93,11 +93,11 @@ func (o *querySet) PrepareInsert() (Inserter, error) { } func (o *querySet) All(container interface{}) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ) } func (o *querySet) One(container interface{}) error { - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ) if err != nil { return err } @@ -111,15 +111,15 @@ func (o *querySet) One(container interface{}) error { } func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results) + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results) + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result) + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 669a354f..abf39b0f 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -60,7 +60,7 @@ func (o *rawSet) Exec() (sql.Result, error) { query := o.query o.orm.alias.DbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args) + args := getFlatParams(nil, o.args, o.orm.alias.TZ) return o.orm.db.Exec(query, args...) } @@ -96,7 +96,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { query := o.query o.orm.alias.DbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args) + args := getFlatParams(nil, o.args, o.orm.alias.TZ) var rs *sql.Rows if r, err := o.orm.db.Query(query, args...); err != nil { diff --git a/orm/orm_test.go b/orm/orm_test.go index ad804114..cd3289e3 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -15,6 +15,11 @@ import ( var _ = os.PathSeparator +var ( + test_Date = format_Date + " -0700" + test_DateTime = format_DateTime + " -0700" +) + type T_Code int const ( @@ -141,7 +146,7 @@ func getCaller(skip int) string { if cur == line { flag = ">>" } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.TrimSpace(string(lines[o+i]))) + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) if code != "" { codes = append(codes, code) } @@ -158,7 +163,11 @@ func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) if len(args) > 0 { - con += fmt.Sprint(args...) + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") } t.Error(con) t.Fail() @@ -169,7 +178,11 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) if len(args) > 0 { - con += fmt.Sprint(args...) + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") } t.Error(con) t.FailNow() @@ -177,13 +190,100 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) { } func TestModelSyntax(t *testing.T) { - mi, ok := modelCache.get("user") + user := &User{} + ind := reflect.ValueOf(user).Elem() + fn := getFullName(ind.Type()) + mi, ok := modelCache.getByFN(fn) + throwFail(t, AssertIs(ok, T_Equal, true)) + + mi, ok = modelCache.get("user") throwFail(t, AssertIs(ok, T_Equal, true)) if ok { throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true)) } } +func TestDataTypes(t *testing.T) { + values := map[string]interface{}{ + "Boolean": true, + "Char": "char", + "Text": "text", + "Date": time.Now(), + "DateTime": time.Now(), + "Byte": byte(1<<8 - 1), + "Rune": rune(1<<31 - 1), + "Int": int(1<<31 - 1), + "Int8": int8(1<<7 - 1), + "Int16": int16(1<<15 - 1), + "Int32": int32(1<<31 - 1), + "Int64": int64(1<<63 - 1), + "Uint": uint(1<<32 - 1), + "Uint8": uint8(1<<8 - 1), + "Uint16": uint16(1<<16 - 1), + "Uint32": uint32(1<<32 - 1), + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Float32": float32(100.1234), + "Float64": float64(100.1234), + "Decimal": float64(100.1234), + } + d := Data{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range values { + e := ind.FieldByName(name) + e.Set(reflect.ValueOf(value)) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 1)) + + d = Data{Id: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range values { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) + value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + } + throwFail(t, AssertIs(vu == value, T_Equal, true), value, vu) + } +} + +func TestNullDataTypes(t *testing.T) { + d := DataNull{} + + if IsPostgres { + // can removed when this fixed + // https://github.com/lib/pq/pull/125 + d.DateTime = time.Now() + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 1)) + + d = DataNull{Id: 1} + err = dORM.Read(&d) + throwFail(t, err) + + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() + throwFail(t, err) + + d = DataNull{Id: 2} + err = dORM.Read(&d) + throwFail(t, err) +} + func TestCRUD(t *testing.T) { profile := NewProfile() profile.Age = 30 @@ -214,8 +314,8 @@ func TestCRUD(t *testing.T) { throwFail(t, AssertIs(u.Status, T_Equal, 3)) throwFail(t, AssertIs(u.IsStaff, T_Equal, true)) throwFail(t, AssertIs(u.IsActive, T_Equal, true)) - throwFail(t, AssertIs(u.Created, T_Equal, user.Created, format_Date)) - throwFail(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), T_Equal, user.Created.In(DefaultTimeLoc), test_Date)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), T_Equal, user.Updated.In(DefaultTimeLoc), test_DateTime)) user.UserName = "astaxie" user.Profile = profile @@ -360,7 +460,9 @@ The program—and web server—godoc processes Go source files to extract docume } func TestExpr(t *testing.T) { - qs := dORM.QueryTable("User") + user := &User{} + qs := dORM.QueryTable(user) + qs = dORM.QueryTable("User") qs = dORM.QueryTable("user") num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() throwFail(t, err) @@ -369,6 +471,10 @@ func TestExpr(t *testing.T) { num, err = qs.Filter("created", time.Now()).Count() throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 3)) + + num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 3)) } func TestOperators(t *testing.T) { @@ -820,9 +926,11 @@ func TestRaw(t *testing.T) { res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() throwFail(t, err) - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 3)) + if err == nil { + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 3)) + } } } } diff --git a/orm/types.go b/orm/types.go index 5c6538f5..44194bd4 100644 --- a/orm/types.go +++ b/orm/types.go @@ -3,6 +3,7 @@ package orm import ( "database/sql" "reflect" + "time" ) type Driver interface { @@ -110,23 +111,25 @@ type txEnder interface { } type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value) error - Insert(dbQuerier, *modelInfo, reflect.Value) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error) + Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error + Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error) SupportUpdateJoin() bool - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) + UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) OperatorSql(string) string - GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{}) - GenerateOperatorLeftCol(string, *string) + GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + GenerateOperatorLeftCol(*fieldInfo, string, *string) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) HasReturningID(*modelInfo, *string) bool + TimeFromDB(*time.Time, *time.Location) + TimeToDB(*time.Time, *time.Location) } diff --git a/orm/utils.go b/orm/utils.go index 7f6373a4..b3e827e5 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) { return strconv.ParseFloat(f.String(), 64) } +func (f StrTo) Int8() (int8, error) { + v, err := strconv.ParseInt(f.String(), 10, 8) + return int8(v), err +} + func (f StrTo) Int16() (int16, error) { v, err := strconv.ParseInt(f.String(), 10, 16) return int16(v), err @@ -53,6 +58,11 @@ func (f StrTo) Int64() (int64, error) { return int64(v), err } +func (f StrTo) Uint8() (uint8, error) { + v, err := strconv.ParseUint(f.String(), 10, 8) + return uint8(v), err +} + func (f StrTo) Uint16() (uint16, error) { v, err := strconv.ParseUint(f.String(), 10, 16) return uint16(v), err @@ -85,6 +95,8 @@ func ToStr(value interface{}, args ...int) (s string) { s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) case int: s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) case int16: s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) case int32: @@ -93,6 +105,8 @@ func ToStr(value interface{}, args ...int) (s string) { s = strconv.FormatInt(v, argInt(args).Get(0, 10)) case uint: s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) case uint16: s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) case uint32: