1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-26 04:01:29 +00:00

golint orm

This commit is contained in:
astaxie 2015-09-12 21:46:43 +08:00
parent 542e143e55
commit 68ec133aa8
25 changed files with 574 additions and 501 deletions

View File

@ -46,7 +46,7 @@ func printHelp(errs ...string) {
os.Exit(2) os.Exit(2)
} }
// listen for orm command and then run it if command arguments passed. // RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() { func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" { if len(os.Args) < 2 || os.Args[1] != "orm" {
return return
@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) {
func (d *commandSyncDb) Run() error { func (d *commandSyncDb) Run() error {
var drops []string var drops []string
if d.force { if d.force {
drops = getDbDropSql(d.al) drops = getDbDropSQL(d.al)
} }
db := d.al.DB db := d.al.DB
@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error {
} }
} }
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db) tables, err := d.al.DbBaser.GetTables(db)
if err != nil { if err != nil {
@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
} }
query := idx.Sql query := idx.SQL
_, err := db.Exec(query) _, err := db.Exec(query)
if d.verbose { if d.verbose {
fmt.Printf(" %s\n", query) fmt.Printf(" %s\n", query)
@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error {
queries := []string{sqls[i]} queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql) queries = append(queries, idx.SQL)
} }
for _, query := range queries { for _, query := range queries {
@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error {
} }
// database creation commander interface implement. // database creation commander interface implement.
type commandSqlAll struct { type commandSQLAll struct {
al *alias al *alias
} }
// parse orm command line arguments. // parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) { func (d *commandSQLAll) Parse(args []string) {
var name string var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) {
} }
// run orm line command. // run orm line command.
func (d *commandSqlAll) Run() error { func (d *commandSQLAll) Run() error {
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSQL(d.al)
var all []string var all []string
for i, mi := range modelCache.allOrdered() { for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]} queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql) queries = append(queries, idx.SQL)
} }
sql := strings.Join(queries, "\n") sql := strings.Join(queries, "\n")
all = append(all, sql) all = append(all, sql)
@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error {
func init() { func init() {
commands["syncdb"] = new(commandSyncDb) commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll) commands["sqlall"] = new(commandSQLAll)
} }
// run syncdb command line. // RunSyncdb run syncdb command line.
// name means table's alias name. default is "default". // name means table's alias name. default is "default".
// force means run next sql if the current is error. // force means run next sql if the current is error.
// verbose means show all info when running command or not. // verbose means show all info when running command or not.

View File

@ -23,11 +23,11 @@ import (
type dbIndex struct { type dbIndex struct {
Table string Table string
Name string Name string
Sql string SQL string
} }
// create database drop sql. // create database drop sql.
func getDbDropSql(al *alias) (sqls []string) { func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
os.Exit(2) os.Exit(2)
@ -65,7 +65,7 @@ checkColumn:
case TypeIntegerField: case TypeIntegerField:
col = T["int32"] col = T["int32"]
case TypeBigIntegerField: case TypeBigIntegerField:
if al.Driver == DR_Sqlite { if al.Driver == DRSqlite {
fieldType = TypeIntegerField fieldType = TypeIntegerField
goto checkColumn goto checkColumn
} }
@ -112,7 +112,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
} }
// create database creation string. // create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
os.Exit(2) os.Exit(2)
@ -142,7 +142,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if fi.auto { if fi.auto {
switch al.Driver { switch al.Driver {
case DR_Sqlite, DR_Postgres: case DRSqlite, DRPostgres:
column += T["auto"] column += T["auto"]
default: default:
column += col + " " + T["auto"] column += col + " " + T["auto"]
@ -201,7 +201,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
sql += strings.Join(columns, ",\n") sql += strings.Join(columns, ",\n")
sql += "\n)" sql += "\n)"
if al.Driver == DR_MySQL { if al.Driver == DRMySQL {
var engine string var engine string
if mi.model != nil { if mi.model != nil {
engine = getTableEngine(mi.addrField) engine = getTableEngine(mi.addrField)
@ -237,7 +237,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
index := dbIndex{} index := dbIndex{}
index.Table = mi.table index.Table = mi.table
index.Name = name index.Name = name
index.Sql = sql index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index) tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
} }
@ -247,7 +247,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
return return
} }
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string { func getColumnDefault(fi *fieldInfo) string {
var ( var (
@ -264,7 +263,7 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on // These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
return v; return v
case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField, case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField,
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,

152
orm/db.go
View File

@ -24,12 +24,13 @@ import (
) )
const ( const (
format_Date = "2006-01-02" formatDate = "2006-01-02"
format_DateTime = "2006-01-02 15:04:05" formatDateTime = "2006-01-02 15:04:05"
) )
var ( var (
ErrMissPK = errors.New("missed pk value") // missing pk error // ErrMissPK missing pk error
ErrMissPK = errors.New("missed pk value")
) )
var ( var (
@ -216,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
} }
if fi.null == false && value == nil { if fi.null == false && value == nil {
return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName)) return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
} }
} }
} }
} }
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert { if fi.autoNow || fi.autoNowAdd && insert {
if insert { if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() { if t, ok := value.(time.Time); ok && !t.IsZero() {
break break
@ -282,13 +283,12 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
var id int64 var id int64
err := row.Scan(&id) err := row.Scan(&id)
return id, err return id, err
} else { }
if res, err := stmt.Exec(values...); err == nil { res, err := stmt.Exec(values...)
if err == nil {
return res.LastInsertId() return res.LastInsertId()
} else { }
return 0, err return 0, err
}
}
} }
// query sql ,read records and persist in dbBaser. // query sql ,read records and persist in dbBaser.
@ -339,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return ErrNoRows return ErrNoRows
} }
return err return err
} else { }
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm) mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind) ind.Set(mind)
}
return nil return nil
} }
@ -444,20 +440,19 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
if res, err := q.Exec(query, values...); err == nil { res, err := q.Exec(query, values...)
if err == nil {
if isMulti { if isMulti {
return res.RowsAffected() return res.RowsAffected()
} }
return res.LastInsertId() return res.LastInsertId()
} else { }
return 0, err return 0, err
} }
} else {
row := q.QueryRow(query, values...) row := q.QueryRow(query, values...)
var id int64 var id int64
err := row.Scan(&id) err := row.Scan(&id)
return id, err return id, err
}
} }
// execute update sql dbQuerier with given struct reflect.Value. // execute update sql dbQuerier with given struct reflect.Value.
@ -495,9 +490,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
if res, err := q.Exec(query, setValues...); err == nil { if res, err := q.Exec(query, setValues...); err == nil {
return res.RowsAffected() return res.RowsAffected()
} else {
return 0, err
} }
return 0, err
} }
// execute delete sql dbQuerier with given struct reflect.Value. // execute delete sql dbQuerier with given struct reflect.Value.
@ -513,14 +507,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q) query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, pkValue)
if res, err := q.Exec(query, pkValue); err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
@ -529,17 +521,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.Field(mi.fields.pk.fieldIndex).SetInt(0) ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
} }
} }
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz) err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
} }
return num, err return num, err
} else {
return 0, err
} }
return 0, err
} }
// update table-related record by querySet. // update table-related record by querySet.
@ -565,11 +554,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
} }
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
values = append(values, args...) values = append(values, args...)
join := tables.getJoinSql() join := tables.getJoinSQL()
var query, T string var query, T string
@ -585,13 +574,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q)
if c, ok := values[i].(colValue); ok { if c, ok := values[i].(colValue); ok {
switch c.opt { switch c.opt {
case Col_Add: case ColAdd:
cols = append(cols, col+" = "+col+" + ?") cols = append(cols, col+" = "+col+" + ?")
case Col_Minus: case ColMinus:
cols = append(cols, col+" = "+col+" - ?") cols = append(cols, col+" = "+col+" - ?")
case Col_Multiply: case ColMultiply:
cols = append(cols, col+" = "+col+" * ?") cols = append(cols, col+" = "+col+" * ?")
case Col_Except: case ColExcept:
cols = append(cols, col+" = "+col+" / ?") cols = append(cols, col+" = "+col+" / ?")
} }
values[i] = c.value values[i] = c.value
@ -610,12 +599,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, values...)
if res, err := q.Exec(query, values...); err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} else {
return 0, err
} }
return 0, err
} }
// delete related records. // delete related records.
@ -624,23 +612,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
switch fi.onDelete { switch fi.onDelete {
case od_CASCADE: case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil { if err != nil {
return err return err
} }
case od_SET_DEFAULT, od_SET_NULL: case odSetDefault, odSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil} params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT { if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String() params[fi.column] = fi.initial.String()
} }
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil { if err != nil {
return err return err
} }
case od_DO_NOTHING: case odDoNothing:
} }
} }
return nil return nil
@ -661,8 +649,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
join := tables.getJoinSql() join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
@ -670,16 +658,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil { r, err := q.Query(query, args...)
if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
rs = r
defer rs.Close() defer rs.Close()
var ref interface{} var ref interface{}
args = make([]interface{}, 0) args = make([]interface{}, 0)
cnt := 0 cnt := 0
for rs.Next() { for rs.Next() {
@ -702,24 +688,21 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...)
if res, err := q.Exec(query, args...); err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if num > 0 { if num > 0 {
err := d.deleteRels(q, mi, args, tz) err := d.deleteRels(q, mi, args, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
} }
return num, nil return num, nil
} else {
return 0, err
} }
return 0, err
} }
// read related records. // read related records.
@ -801,11 +784,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSql(qs.groups) groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSql(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSql(mi, offset, rlimit) limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSql() join := tables.getJoinSQL()
for _, tbl := range tables.tables { for _, tbl := range tables.tables {
if tbl.sel { if tbl.sel {
@ -824,11 +807,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil { r, err := q.Query(query, args...)
if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
rs = r
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
@ -942,9 +925,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
tables.getOrderSql(qs.orders) tables.getOrderSQL(qs.orders)
join := tables.getJoinSql() join := tables.getJoinSQL()
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
@ -959,7 +942,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
} }
// generate sql with replacing operator string placeholders and replaced values. // generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" sql := ""
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
@ -984,7 +967,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
if len(params) > 1 { if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
} }
sql = d.ins.OperatorSql(operator) sql = d.ins.OperatorSQL(operator)
switch operator { switch operator {
case "exact": case "exact":
if arg == nil { if arg == nil {
@ -1112,12 +1095,12 @@ setValue:
) )
if len(s) >= 19 { if len(s) >= 19 {
s = s[:19] s = s[:19]
t, err = time.ParseInLocation(format_DateTime, s, tz) t, err = time.ParseInLocation(formatDateTime, s, tz)
} else { } else {
if len(s) > 10 { if len(s) > 10 {
s = s[:10] s = s[:10]
} }
t, err = time.ParseInLocation(format_Date, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} }
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
@ -1448,25 +1431,22 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
} }
} }
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSql(qs.groups) groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSql(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSql(mi, qs.offset, qs.limit) limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSql() join := tables.getJoinSQL()
sels := strings.Join(cols, ", ") sels := strings.Join(cols, ", ")
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy,orderBy, limit) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows rs, err := q.Query(query, args...)
if r, err := q.Query(query, args...); err != nil { if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
refs := make([]interface{}, len(cols)) refs := make([]interface{}, len(cols))
for i := range refs { for i := range refs {
var ref interface{} var ref interface{}
@ -1481,11 +1461,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
) )
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if cols, err := rs.Columns(); err != nil { cols, err := rs.Columns()
if err != nil {
return 0, err return 0, err
} else {
columns = cols
} }
columns = cols
} }
if err := rs.Scan(refs...); err != nil { if err := rs.Scan(refs...); err != nil {
@ -1649,7 +1629,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
} }
// not implement. // not implement.
func (d *dbBase) OperatorSql(operator string) string { func (d *dbBase) OperatorSQL(operator string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }

View File

@ -22,15 +22,16 @@ import (
"time" "time"
) )
// database driver constant int. // DriverType database driver constant int.
type DriverType int type DriverType int
// Enum the Database driver
const ( const (
_ DriverType = iota // int enum type _ DriverType = iota // int enum type
DR_MySQL // mysql DRMySQL // mysql
DR_Sqlite // sqlite DRSqlite // sqlite
DR_Oracle // oracle DROracle // oracle
DR_Postgres // pgsql DRPostgres // pgsql
) )
// database driver string. // database driver string.
@ -53,15 +54,15 @@ var _ Driver = new(driver)
var ( var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)} dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{ drivers = map[string]DriverType{
"mysql": DR_MySQL, "mysql": DRMySQL,
"postgres": DR_Postgres, "postgres": DRPostgres,
"sqlite3": DR_Sqlite, "sqlite3": DRSqlite,
} }
dbBasers = map[DriverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(), DRMySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(), DRSqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(), DROracle: newdbBaseMysql(),
DR_Postgres: newdbBasePostgres(), DRPostgres: newdbBasePostgres(),
} }
) )
@ -119,7 +120,7 @@ func detectTZ(al *alias) {
} }
switch al.Driver { switch al.Driver {
case DR_MySQL: case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
@ -147,10 +148,10 @@ func detectTZ(al *alias) {
al.Engine = "INNODB" al.Engine = "INNODB"
} }
case DR_Sqlite: case DRSqlite:
al.TZ = time.UTC al.TZ = time.UTC
case DR_Postgres: case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
@ -188,12 +189,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return al, nil return al, nil
} }
// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db) _, err := addAliasWthDB(aliasName, driverName, db)
return err return err
} }
// Setting the database connect params. Use the database driver self dataSource args. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var ( var (
err error err error
@ -236,7 +238,7 @@ end:
return err return err
} }
// Register a database driver use specify driver name, this can be definition the driver is which database type. // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error { func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false { if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ drivers[driverName] = typ
@ -248,7 +250,7 @@ func RegisterDriver(driverName string, typ DriverType) error {
return nil return nil
} }
// Change the database default used timezone // SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error { func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz al.TZ = tz
@ -258,14 +260,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
return nil return nil
} }
// Change the max idle conns for *sql.DB, use specify database alias name // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) { func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns) al.DB.SetMaxIdleConns(maxIdleConns)
} }
// Change the max open conns for *sql.DB, use specify database alias name // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) { func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns al.MaxOpenConns = maxOpenConns
@ -275,7 +277,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
} }
} }
// Get *sql.DB from registered database by db alias name. // GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set. // Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) { func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string var name string
@ -284,9 +286,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} else { } else {
name = "default" name = "default"
} }
if al, ok := dataBaseCache.get(name); ok { al, ok := dataBaseCache.get(name)
if ok {
return al.DB, nil return al.DB, nil
} else {
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
} }
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
} }

View File

@ -67,7 +67,7 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql) var _ dbBaser = new(dbBaseMysql)
// get mysql operator. // get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string { func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator] return mysqlOperators[operator]
} }

View File

@ -66,7 +66,7 @@ type dbBasePostgres struct {
var _ dbBaser = new(dbBasePostgres) var _ dbBaser = new(dbBasePostgres)
// get postgresql operator. // get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string { func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator] return postgresOperators[operator]
} }
@ -101,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
num := 0 num := 0
for _, c := range q { for _, c := range q {
if c == '?' { if c == '?' {
num += 1 num++
} }
} }
if num == 0 { if num == 0 {
@ -114,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
if c == '?' { if c == '?' {
data = append(data, '$') data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...) data = append(data, []byte(strconv.Itoa(num))...)
num += 1 num++
} else { } else {
data = append(data, c) data = append(data, c)
} }

View File

@ -66,7 +66,7 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator. // get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string { func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
} }

View File

@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
} }
// generate join string. // generate join string.
func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote() Q := t.base.TableQuote()
for _, jt := range t.tables { for _, jt := range t.tables {
@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
) )
num := len(exprs) - 1 num := len(exprs) - 1
names := make([]string, 0) var names []string
inner := true inner := true
@ -326,7 +326,7 @@ loopFor:
} }
// generate condition sql. // generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
return return
} }
@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT " where += "NOT "
} }
if p.isCond { if p.isCond {
w, ps := t.getCondSql(p.cond, true, tz) w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" { if w != "" {
w = fmt.Sprintf("( %s) ", w) w = fmt.Sprintf("( %s) ", w)
} }
@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact" operator = "exact"
} }
operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz) operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql) where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...) params = append(params, args...)
} }
@ -391,7 +391,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
} }
// generate group sql. // generate group sql.
func (t *dbTables) getGroupSql(groups []string) (groupSql string) { func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
if len(groups) == 0 { if len(groups) == 0 {
return return
} }
@ -410,12 +410,12 @@ func (t *dbTables) getGroupSql(groups []string) (groupSql string) {
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
} }
groupSql = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
return return
} }
// generate order sql. // generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) { func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 { if len(orders) == 0 {
return return
} }
@ -439,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
} }
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return return
} }
// generate limit sql. // generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 { if limit == 0 {
limit = int64(DefaultRowsLimit) limit = int64(DefaultRowsLimit)
} }

View File

@ -24,9 +24,8 @@ import (
func getDbAlias(name string) *alias { func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
return al return al
} else {
panic(fmt.Errorf("unknown DataBase alias name %s", name))
} }
panic(fmt.Errorf("unknown DataBase alias name %s", name))
} }
// get pk column info. // get pk column info.
@ -80,19 +79,19 @@ outFor:
var err error var err error
if len(v) >= 19 { if len(v) >= 19 {
s := v[:19] s := v[:19]
t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc) t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else { } else {
s := v s := v
if len(v) > 10 { if len(v) > 10 {
s = v[:10] s = v[:10]
} }
t, err = time.ParseInLocation(format_Date, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} }
if err == nil { if err == nil {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
v = t.In(tz).Format(format_Date) v = t.In(tz).Format(formatDate)
} else { } else {
v = t.In(tz).Format(format_DateTime) v = t.In(tz).Format(formatDateTime)
} }
} }
} }
@ -137,9 +136,9 @@ outFor:
case reflect.Struct: case reflect.Struct:
if v, ok := arg.(time.Time); ok { if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField { if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date) arg = v.In(tz).Format(formatDate)
} else { } else {
arg = v.In(tz).Format(format_DateTime) arg = v.In(tz).Format(formatDateTime)
} }
} else { } else {
typ := val.Type() typ := val.Type()

View File

@ -19,10 +19,10 @@ import (
) )
const ( const (
od_CASCADE = "cascade" odCascade = "cascade"
od_SET_NULL = "set_null" odSetNULL = "set_null"
od_SET_DEFAULT = "set_default" odSetDefault = "set_default"
od_DO_NOTHING = "do_nothing" odDoNothing = "do_nothing"
defaultStructTagName = "orm" defaultStructTagName = "orm"
defaultStructTagDelim = ";" defaultStructTagDelim = ";"
) )
@ -113,7 +113,7 @@ func (mc *_modelCache) clean() {
mc.done = false mc.done = false
} }
// Clean model cache. Then you can re-RegisterModel. // ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case. // Common use this api for test case.
func ResetModelCache() { func ResetModelCache() {
modelCache.clean() modelCache.clean()

View File

@ -51,12 +51,10 @@ func registerModel(prefix string, model interface{}) {
} }
info := newModelInfo(val) info := newModelInfo(val)
if info.fields.pk == nil { if info.fields.pk == nil {
outFor: outFor:
for _, fi := range info.fields.fieldsDB { for _, fi := range info.fields.fieldsDB {
if fi.name == "Id" { if strings.ToLower(fi.name) == "id" {
if fi.sf.Tag.Get(defaultStructTagName) == "" {
switch fi.addrValue.Elem().Kind() { switch fi.addrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true fi.auto = true
@ -66,7 +64,6 @@ func registerModel(prefix string, model interface{}) {
} }
} }
} }
}
if info.fields.pk == nil { if info.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name) fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name)
@ -298,12 +295,12 @@ end:
} }
} }
// register models // RegisterModel register models
func RegisterModel(models ...interface{}) { func RegisterModel(models ...interface{}) {
RegisterModelWithPrefix("", models...) RegisterModelWithPrefix("", models...)
} }
// register models with a prefix // RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) { func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done { if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap")) panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -314,7 +311,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
} }
} }
// bootrap models. // BootStrap bootrap models.
// make all model parsed and can not add more models // make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {

View File

@ -15,49 +15,28 @@
package orm package orm
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
) )
// Define the Type enum
const ( const (
// bool
TypeBooleanField = 1 << iota TypeBooleanField = 1 << iota
// string
TypeCharField TypeCharField
// string
TypeTextField TypeTextField
// time.Time
TypeDateField TypeDateField
// time.Time
TypeDateTimeField TypeDateTimeField
// int8
TypeBitField TypeBitField
// int16
TypeSmallIntegerField TypeSmallIntegerField
// int32
TypeIntegerField TypeIntegerField
// int64
TypeBigIntegerField TypeBigIntegerField
// uint8
TypePositiveBitField TypePositiveBitField
// uint16
TypePositiveSmallIntegerField TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField TypePositiveBigIntegerField
// float64
TypeFloatField TypeFloatField
// float64
TypeDecimalField TypeDecimalField
RelForeignKey RelForeignKey
RelOneToOne RelOneToOne
RelManyToMany RelManyToMany
@ -65,6 +44,7 @@ const (
RelReverseMany RelReverseMany
) )
// Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
@ -72,25 +52,30 @@ const (
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )
// A true/false field. // BooleanField A true/false field.
type BooleanField bool type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool { func (e BooleanField) Value() bool {
return bool(e) return bool(e)
} }
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) { func (e *BooleanField) Set(d bool) {
*e = BooleanField(d) *e = BooleanField(d)
} }
// String format the Bool to string
func (e *BooleanField) String() string { func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value()) return strconv.FormatBool(e.Value())
} }
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int { func (e *BooleanField) FieldType() int {
return TypeBooleanField return TypeBooleanField
} }
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error { func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case bool: case bool:
@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error {
} }
return err return err
default: default:
return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} { func (e *BooleanField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField) var _ Fielder = new(BooleanField)
// A string field // CharField A string field
// required values tag: size // required values tag: size
// The size is enforced at the database level and in modelss validation. // The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"` // eg: `orm:"size(120)"`
type CharField string type CharField string
// Value return the CharField's Value
func (e CharField) Value() string { func (e CharField) Value() string {
return string(e) return string(e)
} }
// Set CharField value
func (e *CharField) Set(d string) { func (e *CharField) Set(d string) {
*e = CharField(d) *e = CharField(d)
} }
// String return the CharField
func (e *CharField) String() string { func (e *CharField) String() string {
return e.Value() return e.Value()
} }
// FieldType return the enum type
func (e *CharField) FieldType() int { func (e *CharField) FieldType() int {
return TypeCharField return TypeCharField
} }
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error { func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
e.Set(d) e.Set(d)
default: default:
return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} { func (e *CharField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify CharField implement Fielder
var _ Fielder = new(CharField) var _ Fielder = new(CharField)
// A date, represented in go by a time.Time instance. // DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02 // only date values like 2006-01-02
// Has a few extra, optional attr tag: // Has a few extra, optional attr tag:
// //
@ -166,106 +160,125 @@ var _ Fielder = new(CharField)
// eg: `orm:"auto_now"` or `orm:"auto_now_add"` // eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time { func (e DateField) Value() time.Time {
return time.Time(e) return time.Time(e)
} }
// Set set the DateField's value
func (e *DateField) Set(d time.Time) { func (e *DateField) Set(d time.Time) {
*e = DateField(d) *e = DateField(d)
} }
// String convert datatime to string
func (e *DateField) String() string { func (e *DateField) String() string {
return e.Value().String() return e.Value().String()
} }
// FieldType return enum type Date
func (e *DateField) FieldType() int { func (e *DateField) FieldType() int {
return TypeDateField return TypeDateField
} }
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error { func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case time.Time: case time.Time:
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, format_Date) v, err := timeParse(d, formatDate)
if err != nil { if err != nil {
e.Set(v) e.Set(v)
} }
return err return err
default: default:
return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return Date value
func (e *DateField) RawValue() interface{} { func (e *DateField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify DateField implement fielder interface
var _ Fielder = new(DateField) var _ Fielder = new(DateField)
// A date, represented in go by a time.Time instance. // DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05 // datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField. // Takes the same extra arguments as DateField.
type DateTimeField time.Time type DateTimeField time.Time
// Value return the datatime value
func (e DateTimeField) Value() time.Time { func (e DateTimeField) Value() time.Time {
return time.Time(e) return time.Time(e)
} }
// Set set the time.Time to datatime
func (e *DateTimeField) Set(d time.Time) { func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d) *e = DateTimeField(d)
} }
// String return the time's String
func (e *DateTimeField) String() string { func (e *DateTimeField) String() string {
return e.Value().String() return e.Value().String()
} }
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int { func (e *DateTimeField) FieldType() int {
return TypeDateTimeField return TypeDateTimeField
} }
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error { func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case time.Time: case time.Time:
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, format_DateTime) v, err := timeParse(d, formatDateTime)
if err != nil { if err != nil {
e.Set(v) e.Set(v)
} }
return err return err
default: default:
return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the datatime value
func (e *DateTimeField) RawValue() interface{} { func (e *DateTimeField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify datatime implement fielder
var _ Fielder = new(DateTimeField) var _ Fielder = new(DateTimeField)
// A floating-point number represented in go by a float32 value. // FloatField A floating-point number represented in go by a float32 value.
type FloatField float64 type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 { func (e FloatField) Value() float64 {
return float64(e) return float64(e)
} }
// Set the Float64
func (e *FloatField) Set(d float64) { func (e *FloatField) Set(d float64) {
*e = FloatField(d) *e = FloatField(d)
} }
// String return the string
func (e *FloatField) String() string { func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32) return ToStr(e.Value(), -1, 32)
} }
// FieldType return the enum type
func (e *FloatField) FieldType() int { func (e *FloatField) FieldType() int {
return TypeFloatField return TypeFloatField
} }
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error { func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case float32: case float32:
@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} { func (e *FloatField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify FloatField implement Fielder
var _ Fielder = new(FloatField) var _ Fielder = new(FloatField)
// -32768 to 32767 // SmallIntegerField -32768 to 32767
type SmallIntegerField int16 type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 { func (e SmallIntegerField) Value() int16 {
return int16(e) return int16(e)
} }
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) { func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d) *e = SmallIntegerField(d)
} }
// String convert smallint to string
func (e *SmallIntegerField) String() string { func (e *SmallIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int { func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField return TypeSmallIntegerField
} }
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error { func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case int16: case int16:
@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} { func (e *SmallIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField) var _ Fielder = new(SmallIntegerField)
// -2147483648 to 2147483647 // IntegerField -2147483648 to 2147483647
type IntegerField int32 type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 { func (e IntegerField) Value() int32 {
return int32(e) return int32(e)
} }
// Set IntegerField value
func (e *IntegerField) Set(d int32) { func (e *IntegerField) Set(d int32) {
*e = IntegerField(d) *e = IntegerField(d)
} }
// String convert Int32 to string
func (e *IntegerField) String() string { func (e *IntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return the enum type
func (e *IntegerField) FieldType() int { func (e *IntegerField) FieldType() int {
return TypeIntegerField return TypeIntegerField
} }
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error { func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case int32: case int32:
@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} { func (e *IntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField) var _ Fielder = new(IntegerField)
// -9223372036854775808 to 9223372036854775807. // BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64 type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 { func (e BigIntegerField) Value() int64 {
return int64(e) return int64(e)
} }
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) { func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d) *e = BigIntegerField(d)
} }
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string { func (e *BigIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *BigIntegerField) FieldType() int { func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField return TypeBigIntegerField
} }
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error { func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case int64: case int64:
@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} { func (e *BigIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField) var _ Fielder = new(BigIntegerField)
// 0 to 65535 // PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16 type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 { func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e) return uint16(e)
} }
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) { func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d) *e = PositiveSmallIntegerField(d)
} }
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string { func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int { func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField return TypePositiveSmallIntegerField
} }
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case uint16: case uint16:
@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} { func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField) var _ Fielder = new(PositiveSmallIntegerField)
// 0 to 4294967295 // PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32 type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 { func (e PositiveIntegerField) Value() uint32 {
return uint32(e) return uint32(e)
} }
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) { func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d) *e = PositiveIntegerField(d)
} }
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string { func (e *PositiveIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int { func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField return TypePositiveIntegerField
} }
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error { func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case uint32: case uint32:
@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} { func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField) var _ Fielder = new(PositiveIntegerField)
// 0 to 18446744073709551615 // PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64 type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 { func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e) return uint64(e)
} }
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) { func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d) *e = PositiveBigIntegerField(d)
} }
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string { func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int { func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField return TypePositiveIntegerField
} }
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case uint64: case uint64:
@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} { func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField) var _ Fielder = new(PositiveBigIntegerField)
// A large text field. // TextField A large text field.
type TextField string type TextField string
// Value return TextField value
func (e TextField) Value() string { func (e TextField) Value() string {
return string(e) return string(e)
} }
// Set the TextField value
func (e *TextField) Set(d string) { func (e *TextField) Set(d string) {
*e = TextField(d) *e = TextField(d)
} }
// String convert TextField to string
func (e *TextField) String() string { func (e *TextField) String() string {
return e.Value() return e.Value()
} }
// FieldType return enum type
func (e *TextField) FieldType() int { func (e *TextField) FieldType() int {
return TypeTextField return TypeTextField
} }
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error { func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
e.Set(d) e.Set(d)
default: default:
return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return TextField value
func (e *TextField) RawValue() interface{} { func (e *TextField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify TextField implement Fielder
var _ Fielder = new(TextField) var _ Fielder = new(TextField)

View File

@ -119,8 +119,8 @@ type fieldInfo struct {
colDefault bool colDefault bool
initial StrTo initial StrTo
size int size int
auto_now bool autoNow bool
auto_now_add bool autoNowAdd bool
rel bool rel bool
reverse bool reverse bool
reverseField string reverseField string
@ -309,20 +309,20 @@ checkType:
if fi.rel && fi.dbcol { if fi.rel && fi.dbcol {
switch onDelete { switch onDelete {
case od_CASCADE, od_DO_NOTHING: case odCascade, odDoNothing:
case od_SET_DEFAULT: case odSetDefault:
if initial.Exist() == false { if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value") err = errors.New("on_delete: set_default need set field a default value")
goto end goto end
} }
case od_SET_NULL: case odSetNULL:
if fi.null == false { if fi.null == false {
err = errors.New("on_delete: set_null need set field null") err = errors.New("on_delete: set_null need set field null")
goto end goto end
} }
default: default:
if onDelete == "" { if onDelete == "" {
onDelete = od_CASCADE onDelete = odCascade
} else { } else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end goto end
@ -350,9 +350,9 @@ checkType:
fi.unique = false fi.unique = false
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] { if attrs["auto_now"] {
fi.auto_now = true fi.autoNow = true
} else if attrs["auto_now_add"] { } else if attrs["auto_now_add"] {
fi.auto_now_add = true fi.autoNowAdd = true
} }
case TypeFloatField: case TypeFloatField:
case TypeDecimalField: case TypeDecimalField:

View File

@ -15,7 +15,6 @@
package orm package orm
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
added := info.fields.Add(fi) added := info.fields.Add(fi)
if added == false { if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) err = fmt.Errorf("duplicate column name: %s", fi.column)
break break
} }
if fi.pk { if fi.pk {
if info.fields.pk != nil { if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only")) err = fmt.Errorf("one model must have one pk field only")
break break
} else { } else {
info.fields.pk = fi info.fields.pk = fi

View File

@ -76,21 +76,21 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField) var _ Fielder = new(SliceStringField)
// A json field. // A json field.
type JsonField struct { type JSONField struct {
Name string Name string
Data string Data string
} }
func (e *JsonField) String() string { func (e *JSONField) String() string {
data, _ := json.Marshal(e) data, _ := json.Marshal(e)
return string(data) return string(data)
} }
func (e *JsonField) FieldType() int { func (e *JSONField) FieldType() int {
return TypeTextField return TypeTextField
} }
func (e *JsonField) SetRaw(value interface{}) error { func (e *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
return json.Unmarshal([]byte(d), e) return json.Unmarshal([]byte(d), e)
@ -99,14 +99,14 @@ func (e *JsonField) SetRaw(value interface{}) error {
} }
} }
func (e *JsonField) RawValue() interface{} { func (e *JSONField) RawValue() interface{} {
return e.String() return e.String()
} }
var _ Fielder = new(JsonField) var _ Fielder = new(JSONField)
type Data struct { type Data struct {
Id int ID int `orm:"column(id)"`
Boolean bool Boolean bool
Char string `orm:"size(50)"` Char string `orm:"size(50)"`
Text string `orm:"type(text)"` Text string `orm:"type(text)"`
@ -130,7 +130,7 @@ type Data struct {
} }
type DataNull struct { type DataNull struct {
Id int ID int `orm:"column(id)"`
Boolean bool `orm:"null"` Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"` Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"` Text string `orm:"null;type(text)"`
@ -193,7 +193,7 @@ type Float32 float64
type Float64 float64 type Float64 float64
type DataCustom struct { type DataCustom struct {
Id int ID int `orm:"column(id)"`
Boolean Boolean Boolean Boolean
Char string `orm:"size(50)"` Char string `orm:"size(50)"`
Text string `orm:"type(text)"` Text string `orm:"type(text)"`
@ -216,12 +216,12 @@ type DataCustom struct {
// only for mysql // only for mysql
type UserBig struct { type UserBig struct {
Id uint64 ID uint64 `orm:"column(id)"`
Name string Name string
} }
type User struct { type User struct {
Id int ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"` Email string `orm:"size(100)"`
Password string `orm:"size(100)"` Password string `orm:"size(100)"`
@ -235,9 +235,9 @@ type User struct {
ShouldSkip string `orm:"-"` ShouldSkip string `orm:"-"`
Nums int Nums int
Langs SliceStringField `orm:"size(100)"` Langs SliceStringField `orm:"size(100)"`
Extra JsonField `orm:"type(text)"` Extra JSONField `orm:"type(text)"`
unexport bool `orm:"-"` unexport bool `orm:"-"`
unexport_ bool unexportBool bool
} }
func (u *User) TableIndex() [][]string { func (u *User) TableIndex() [][]string {
@ -259,7 +259,7 @@ func NewUser() *User {
} }
type Profile struct { type Profile struct {
Id int ID int `orm:"column(id)"`
Age int16 Age int16
Money float64 Money float64
User *User `orm:"reverse(one)" json:"-"` User *User `orm:"reverse(one)" json:"-"`
@ -276,7 +276,7 @@ func NewProfile() *Profile {
} }
type Post struct { type Post struct {
Id int ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"` User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"` Title string `orm:"size(60)"`
Content string `orm:"type(text)"` Content string `orm:"type(text)"`
@ -297,7 +297,7 @@ func NewPost() *Post {
} }
type Tag struct { type Tag struct {
Id int ID int `orm:"column(id)"`
Name string `orm:"size(30)"` Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"` BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"` Posts []*Post `orm:"reverse(many)" json:"-"`
@ -309,7 +309,7 @@ func NewTag() *Tag {
} }
type PostTags struct { type PostTags struct {
Id int ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"` Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"` Tag *Tag `orm:"rel(fk)"`
} }
@ -319,7 +319,7 @@ func (m *PostTags) TableName() string {
} }
type Comment struct { type Comment struct {
Id int ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"` Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"` Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"` Parent *Comment `orm:"null;rel(fk)"`
@ -397,7 +397,7 @@ go test -v github.com/astaxie/beego/orm
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default") alias := getDbAlias("default")
if alias.Driver == DR_MySQL { if alias.Driver == DRMySQL {
alias.Engine = "INNODB" alias.Engine = "INNODB"
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage // Simple Usage
// //
// package main // package main
@ -59,12 +60,13 @@ import (
"time" "time"
) )
// DebugQueries define the debug
const ( const (
Debug_Queries = iota DebugQueries = iota
) )
// Define common vars
var ( var (
// DebugLevel = Debug_Queries
Debug = false Debug = false
DebugLog = NewLog(os.Stderr) DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000 DefaultRowsLimit = 1000
@ -79,7 +81,10 @@ var (
ErrNotImplement = errors.New("have not implement") ErrNotImplement = errors.New("have not implement")
) )
// Params stores the Params
type Params map[string]interface{} type Params map[string]interface{}
// ParamsList stores paramslist
type ParamsList []interface{} type ParamsList []interface{}
type orm struct { type orm struct {
@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
o.setPk(mi, ind, id) o.setPk(mi, ind, id)
cnt += 1 cnt++
} }
} else { } else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false) mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
@ -489,7 +494,7 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
// create new orm // NewOrm create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once
@ -501,7 +506,7 @@ func NewOrm() Ormer {
return o return o
} }
// create a new ormer object with specify *sql.DB for query // NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias var al *alias

View File

@ -19,6 +19,7 @@ import (
"strings" "strings"
) )
// ExprSep define the expression seperation
const ( const (
ExprSep = "__" ExprSep = "__"
) )
@ -32,19 +33,19 @@ type condValue struct {
isCond bool isCond bool
} }
// condition struct. // Condition struct.
// work for WHERE conditions. // work for WHERE conditions.
type Condition struct { type Condition struct {
params []condValue params []condValue
} }
// return new condition struct // NewCondition return new condition struct
func NewCondition() *Condition { func NewCondition() *Condition {
c := &Condition{} c := &Condition{}
return c return c
} }
// add expression to condition // And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty")) panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add NOT expression to condition // AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty")) panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a condition to current condition // AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
// add OR expression to condition // Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty")) panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add OR NOT expression to condition // OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty")) panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a OR condition to current condition // OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c return c
} }
// check the condition arguments are empty or not. // IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool { func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
// clone a condition // clone clone a condition
func (c Condition) clone() *Condition { func (c Condition) clone() *Condition {
return &c return &c
} }

View File

@ -23,11 +23,12 @@ import (
"time" "time"
) )
// Log implement the log.Logger
type Log struct { type Log struct {
*log.Logger *log.Logger
} }
// set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", 1e9)
@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil { if err != nil {
flag = "FAIL" flag = "FAIL"
} }
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query) con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args)) cons := make([]string, 0, len(args))
for _, arg := range args { for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg)) cons = append(cons, fmt.Sprintf("%v", arg))

View File

@ -25,11 +25,12 @@ type colValue struct {
type operator int type operator int
// define Col operations
const ( const (
Col_Add operator = iota ColAdd operator = iota
Col_Minus ColMinus
Col_Multiply ColMultiply
Col_Except ColExcept
) )
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: // ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
@ -38,7 +39,7 @@ const (
// } // }
func ColValue(opt operator, value interface{}) interface{} { func ColValue(opt operator, value interface{}) interface{} {
switch opt { switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except: case ColAdd, ColMinus, ColMultiply, ColExcept:
default: default:
panic(fmt.Errorf("orm.ColValue wrong operator")) panic(fmt.Errorf("orm.ColValue wrong operator"))
} }

View File

@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" { if str != "" {
if len(str) >= 19 { if len(str) >= 19 {
str = str[:19] str = str[:19]
t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ) t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil { if err == nil {
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
} else if len(str) >= 10 { } else if len(str) >= 10 {
str = str[:10] str = str[:10]
t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc) t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil { if err == nil {
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr
// query data and map to container // query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error { func (o *rawSet) QueryRow(containers ...interface{}) error {
refs := make([]interface{}, 0, len(containers)) var (
sInds := make([]reflect.Value, 0) refs = make([]interface{}, 0, len(containers))
eTyps := make([]reflect.Type, 0) sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// query data rows and map to container // query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0, len(containers)) var (
sInds := make([]reflect.Value, 0) refs = make([]interface{}, 0, len(containers))
eTyps := make([]reflect.Type, 0) sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
sInd := reflect.Indirect(val) sInd := reflect.Indirect(val)
@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil { rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
defer rs.Close() defer rs.Close()
@ -574,9 +575,10 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if columns, err := rs.Columns(); err != nil { columns, err := rs.Columns()
if err != nil {
return 0, err return 0, err
} else { }
if len(needCols) > 0 { if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols)) indexs = make([]int, 0, len(needCols))
} else { } else {
@ -600,7 +602,6 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
} }
} }
} }
}
if err := rs.Scan(refs...); err != nil { if err := rs.Scan(refs...); err != nil {
return 0, err return 0, err
@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows rs, err := o.orm.db.Query(query, args...)
if r, err := o.orm.db.Query(query, args...); err != nil { if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
defer rs.Close() defer rs.Close()
@ -706,16 +705,16 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if columns, err := rs.Columns(); err != nil { columns, err := rs.Columns()
if err != nil {
return 0, err return 0, err
} else { }
cols = columns cols = columns
refs = make([]interface{}, len(cols)) refs = make([]interface{}, len(cols))
for i := range refs { for i := range refs {
if keyCol == cols[i] { if keyCol == cols[i] {
keyIndex = i keyIndex = i
} }
if typ == 1 || keyIndex == i { if typ == 1 || keyIndex == i {
var ref sql.NullString var ref sql.NullString
refs[i] = &ref refs[i] = &ref
@ -723,17 +722,14 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
var ref interface{} var ref interface{}
refs[i] = &ref refs[i] = &ref
} }
if valueCol == cols[i] { if valueCol == cols[i] {
valueIndex = i valueIndex = i
} }
} }
if keyIndex == -1 || valueIndex == -1 { if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
} }
} }
}
if err := rs.Scan(refs...); err != nil { if err := rs.Scan(refs...); err != nil {
return 0, err return 0, err

View File

@ -31,13 +31,13 @@ import (
var _ = os.PathSeparator var _ = os.PathSeparator
var ( var (
test_Date = format_Date + " -0700" testDate = formatDate + " -0700"
test_DateTime = format_DateTime + " -0700" testDateTime = formatDateTime + " -0700"
) )
func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) { func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) {
if len(args) == 0 { if len(args) == 0 {
return fmt.Errorf("miss args"), false return false, fmt.Errorf("miss args")
} }
b := args[0] b := args[0]
arg := argAny(args) arg := argAny(args)
@ -71,21 +71,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
wrongArg: wrongArg:
if err != nil { if err != nil {
return err, false return false, err
} }
return nil, true return true, nil
} }
func AssertIs(a interface{}, args ...interface{}) error { func AssertIs(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(true, a, args...); ok == false { if ok, err := ValuesCompare(true, a, args...); ok == false {
return err return err
} }
return nil return nil
} }
func AssertNot(a interface{}, args ...interface{}) error { func AssertNot(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(false, a, args...); ok == false { if ok, err := ValuesCompare(false, a, args...); ok == false {
return err return err
} }
return nil return nil
@ -208,7 +208,7 @@ func TestModelSyntax(t *testing.T) {
} }
} }
var Data_Values = map[string]interface{}{ var DataValues = map[string]interface{}{
"Boolean": true, "Boolean": true,
"Char": "char", "Char": "char",
"Text": "text", "Text": "text",
@ -235,7 +235,7 @@ func TestDataTypes(t *testing.T) {
d := Data{} d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d)) ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value)) e.Set(reflect.ValueOf(value))
} }
@ -244,22 +244,22 @@ func TestDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
d = Data{Id: 1} d = Data{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d)) ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -278,7 +278,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
d = DataNull{Id: 1} d = DataNull{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
@ -309,7 +309,7 @@ func TestNullDataTypes(t *testing.T) {
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err) throwFail(t, err)
d = DataNull{Id: 2} d = DataNull{ID: 2}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
@ -362,7 +362,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 3)) throwFail(t, AssertIs(id, 3))
d = DataNull{Id: 3} d = DataNull{ID: 3}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
@ -402,7 +402,7 @@ func TestDataCustomTypes(t *testing.T) {
d := DataCustom{} d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d)) ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
if !e.IsValid() { if !e.IsValid() {
continue continue
@ -414,13 +414,13 @@ func TestDataCustomTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1} d = DataCustom{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d)) ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
if !e.IsValid() { if !e.IsValid() {
continue continue
@ -451,7 +451,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
u := &User{Id: user.Id} u := &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFail(t, err) throwFail(t, err)
@ -461,8 +461,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, 3)) throwFail(t, AssertIs(u.Status, 3))
throwFail(t, AssertIs(u.IsStaff, true)) throwFail(t, AssertIs(u.IsStaff, true))
throwFail(t, AssertIs(u.IsActive, true)) throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date)) throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime)) throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime))
user.UserName = "astaxie" user.UserName = "astaxie"
user.Profile = profile user.Profile = profile
@ -470,11 +470,11 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id} u = &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFailNow(t, err) throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "astaxie")) throwFail(t, AssertIs(u.UserName, "astaxie"))
throwFail(t, AssertIs(u.Profile.Id, profile.Id)) throwFail(t, AssertIs(u.Profile.ID, profile.ID))
u = &User{UserName: "astaxie", Password: "pass"} u = &User{UserName: "astaxie", Password: "pass"}
err = dORM.Read(u, "UserName") err = dORM.Read(u, "UserName")
@ -487,7 +487,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id} u = &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFailNow(t, err) throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "QQ")) throwFail(t, AssertIs(u.UserName, "QQ"))
@ -497,7 +497,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id} u = &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(true, u.Profile == nil)) throwFail(t, AssertIs(true, u.Profile == nil))
@ -506,7 +506,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: 100} u = &User{ID: 100}
err = dORM.Read(u) err = dORM.Read(u)
throwFail(t, AssertIs(err, ErrNoRows)) throwFail(t, AssertIs(err, ErrNoRows))
@ -516,7 +516,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
ub = UserBig{Id: 1} ub = UserBig{ID: 1}
err = dORM.Read(&ub) err = dORM.Read(&ub)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name")) throwFail(t, AssertIs(ub.Name, "name"))
@ -586,7 +586,7 @@ func TestInsertTestData(t *testing.T) {
throwFail(t, AssertIs(id, 4)) throwFail(t, AssertIs(id, 4))
tags := []*Tag{ tags := []*Tag{
{Name: "golang", BestPost: &Post{Id: 2}}, {Name: "golang", BestPost: &Post{ID: 2}},
{Name: "example"}, {Name: "example"},
{Name: "format"}, {Name: "format"},
{Name: "c++"}, {Name: "c++"},
@ -638,7 +638,7 @@ The program—and web server—godoc processes Go source files to extract docume
} }
func TestCustomField(t *testing.T) { func TestCustomField(t *testing.T) {
user := User{Id: 2} user := User{ID: 2}
err := dORM.Read(&user) err := dORM.Read(&user)
throwFailNow(t, err) throwFailNow(t, err)
@ -648,7 +648,7 @@ func TestCustomField(t *testing.T) {
_, err = dORM.Update(&user, "Langs", "Extra") _, err = dORM.Update(&user, "Langs", "Extra")
throwFailNow(t, err) throwFailNow(t, err)
user = User{Id: 2} user = User{ID: 2}
err = dORM.Read(&user) err = dORM.Read(&user)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(len(user.Langs), 2)) throwFailNow(t, AssertIs(len(user.Langs), 2))
@ -889,9 +889,9 @@ func TestAll(t *testing.T) {
throwFailNow(t, AssertIs(users2[0].UserName, "slene")) throwFailNow(t, AssertIs(users2[0].UserName, "slene"))
throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) throwFailNow(t, AssertIs(users2[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) throwFailNow(t, AssertIs(users2[2].UserName, "nobody"))
throwFailNow(t, AssertIs(users2[0].Id, 0)) throwFailNow(t, AssertIs(users2[0].ID, 0))
throwFailNow(t, AssertIs(users2[1].Id, 0)) throwFailNow(t, AssertIs(users2[1].ID, 0))
throwFailNow(t, AssertIs(users2[2].Id, 0)) throwFailNow(t, AssertIs(users2[2].ID, 0))
throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) throwFailNow(t, AssertIs(users2[0].Profile == nil, false))
throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) throwFailNow(t, AssertIs(users2[1].Profile == nil, false))
throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) throwFailNow(t, AssertIs(users2[2].Profile == nil, true))
@ -1112,7 +1112,7 @@ func TestReverseQuery(t *testing.T) {
func TestLoadRelated(t *testing.T) { func TestLoadRelated(t *testing.T) {
// load reverse foreign key // load reverse foreign key
user := User{Id: 3} user := User{ID: 3}
err := dORM.Read(&user) err := dORM.Read(&user)
throwFailNow(t, err) throwFailNow(t, err)
@ -1121,7 +1121,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true) num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err) throwFailNow(t, err)
@ -1143,8 +1143,8 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
// load reverse one to one // load reverse one to one
profile := Profile{Id: 3} profile := Profile{ID: 3}
profile.BestPost = &Post{Id: 2} profile.BestPost = &Post{ID: 2}
num, err = dORM.Update(&profile, "BestPost") num, err = dORM.Update(&profile, "BestPost")
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
@ -1183,7 +1183,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
post := Post{Id: 2} post := Post{ID: 2}
// load rel foreign key // load rel foreign key
err = dORM.Read(&post) err = dORM.Read(&post)
@ -1204,7 +1204,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
// load rel m2m // load rel m2m
post = Post{Id: 2} post = Post{ID: 2}
err = dORM.Read(&post) err = dORM.Read(&post)
throwFailNow(t, err) throwFailNow(t, err)
@ -1224,7 +1224,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
// load reverse m2m // load reverse m2m
tag := Tag{Id: 1} tag := Tag{ID: 1}
err = dORM.Read(&tag) err = dORM.Read(&tag)
throwFailNow(t, err) throwFailNow(t, err)
@ -1233,19 +1233,19 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true) num, err = dORM.LoadRelated(&tag, "Posts", true)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
} }
func TestQueryM2M(t *testing.T) { func TestQueryM2M(t *testing.T) {
post := Post{Id: 4} post := Post{ID: 4}
m2m := dORM.QueryM2M(&post, "Tags") m2m := dORM.QueryM2M(&post, "Tags")
tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}}
@ -1319,7 +1319,7 @@ func TestQueryM2M(t *testing.T) {
for _, post := range posts { for _, post := range posts {
p := post.(*Post) p := post.(*Post)
p.User = &User{Id: 1} p.User = &User{ID: 1}
_, err := dORM.Insert(post) _, err := dORM.Insert(post)
throwFailNow(t, err) throwFailNow(t, err)
} }
@ -1459,10 +1459,10 @@ func TestRawQueryRow(t *testing.T) {
Decimal float64 Decimal float64
) )
data_values := make(map[string]interface{}, len(Data_Values)) dataValues := make(map[string]interface{}, len(DataValues))
for k, v := range Data_Values { for k, v := range DataValues {
data_values[strings.ToLower(k)] = v dataValues[strings.ToLower(k)] = v
} }
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
@ -1488,14 +1488,14 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
case "date": case "date":
v = v.(time.Time).In(DefaultTimeLoc) v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date)) throwFail(t, AssertIs(v, value, testDate))
case "datetime": case "datetime":
v = v.(time.Time).In(DefaultTimeLoc) v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime)) throwFail(t, AssertIs(v, value, testDateTime))
default: default:
throwFail(t, AssertIs(v, data_values[col])) throwFail(t, AssertIs(v, dataValues[col]))
} }
} }
@ -1529,16 +1529,16 @@ func TestQueryRows(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(datas[0])) ind := reflect.Indirect(reflect.ValueOf(datas[0]))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -1553,16 +1553,16 @@ func TestQueryRows(t *testing.T) {
ind = reflect.Indirect(reflect.ValueOf(datas2[0])) ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -1699,25 +1699,25 @@ func TestUpdate(t *testing.T) {
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Add, 100), "Nums": ColValue(ColAdd, 100),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Minus, 50), "Nums": ColValue(ColMinus, 50),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Multiply, 3), "Nums": ColValue(ColMultiply, 3),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Except, 5), "Nums": ColValue(ColExcept, 5),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
@ -1838,15 +1838,15 @@ func TestReadOrCreate(t *testing.T) {
throwFail(t, AssertIs(u.Status, 7)) throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false)) throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true)) throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date)) throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime)) throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName") created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(created, false)) throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.Id, u.Id)) throwFail(t, AssertIs(nu.ID, u.ID))
throwFail(t, AssertIs(pk, u.Id)) throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.UserName, u.UserName)) throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password)) throwFail(t, AssertIs(nu.Password, u.Password))

View File

@ -16,6 +16,7 @@ package orm
import "errors" import "errors"
// QueryBuilder is the Query builder interface
type QueryBuilder interface { type QueryBuilder interface {
Select(fields ...string) QueryBuilder Select(fields ...string) QueryBuilder
From(tables ...string) QueryBuilder From(tables ...string) QueryBuilder
@ -43,15 +44,16 @@ type QueryBuilder interface {
String() string String() string
} }
// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" { if driver == "mysql" {
qb = new(MySQLQueryBuilder) qb = new(MySQLQueryBuilder)
} else if driver == "postgres" { } else if driver == "postgres" {
err = errors.New("postgres query builder is not supported yet!") err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" { } else if driver == "sqlite" {
err = errors.New("sqlite query builder is not supported yet!") err = errors.New("sqlite query builder is not supported yet")
} else { } else {
err = errors.New("unknown driver for query builder!") err = errors.New("unknown driver for query builder")
} }
return return
} }

View File

@ -20,134 +20,160 @@ import (
"strings" "strings"
) )
const COMMA_SPACE = ", " // CommaSpace is the seperation
const CommaSpace = ", "
// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct { type MySQLQueryBuilder struct {
Tokens []string Tokens []string
} }
// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb return qb
} }
// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb return qb
} }
// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table) qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb return qb
} }
// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb return qb
} }
// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb return qb
} }
// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond) qb.Tokens = append(qb.Tokens, "ON", cond)
return qb return qb
} }
// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond) qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb return qb
} }
// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond) qb.Tokens = append(qb.Tokens, "AND", cond)
return qb return qb
} }
// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond) qb.Tokens = append(qb.Tokens, "OR", cond)
return qb return qb
} }
// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")") qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb return qb
} }
// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb return qb
} }
// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder { func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC") qb.Tokens = append(qb.Tokens, "ASC")
return qb return qb
} }
// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder { func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC") qb.Tokens = append(qb.Tokens, "DESC")
return qb return qb
} }
// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb return qb
} }
// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb return qb
} }
// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb return qb
} }
// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond) qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb return qb
} }
// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb return qb
} }
// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb return qb
} }
// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE") qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 { if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
} }
return qb return qb
} }
// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table) qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 { if len(fields) != 0 {
fieldsStr := strings.Join(fields, COMMA_SPACE) fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
} }
return qb return qb
} }
// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, COMMA_SPACE) valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb return qb
} }
// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias) return fmt.Sprintf("(%s) AS %s", sub, alias)
} }
// String join all Tokens
func (qb *MySQLQueryBuilder) String() string { func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ") return strings.Join(qb.Tokens, " ")
} }

View File

@ -20,13 +20,13 @@ import (
"time" "time"
) )
// database driver // Driver define database driver
type Driver interface { type Driver interface {
Name() string Name() string
Type() DriverType Type() DriverType
} }
// field info // Fielder define field info
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
@ -34,7 +34,7 @@ type Fielder interface {
RawValue() interface{} RawValue() interface{}
} }
// orm struct // Ormer define the orm interface
type Ormer interface { type Ormer interface {
Read(interface{}, ...string) error Read(interface{}, ...string) error
ReadOrCreate(interface{}, string, ...string) (bool, int64, error) ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
@ -53,13 +53,13 @@ type Ormer interface {
Driver() Driver Driver() Driver
} }
// insert prepared statement // Inserter insert prepared statement
type Inserter interface { type Inserter interface {
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
Close() error Close() error
} }
// query seter // QuerySeter query seter
type QuerySeter interface { type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
@ -84,7 +84,7 @@ type QuerySeter interface {
RowsToStruct(interface{}, string, string) (int64, error) RowsToStruct(interface{}, string, string) (int64, error)
} }
// model to model query struct // QueryM2Mer model to model query struct
type QueryM2Mer interface { type QueryM2Mer interface {
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
@ -93,13 +93,13 @@ type QueryM2Mer interface {
Count() (int64, error) Count() (int64, error)
} }
// raw query statement // RawPreparer raw query statement
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (sql.Result, error) Exec(...interface{}) (sql.Result, error)
Close() error Close() error
} }
// raw query seter // RawSeter raw query seter
type RawSeter interface { type RawSeter interface {
Exec() (sql.Result, error) Exec() (sql.Result, error)
QueryRow(...interface{}) error QueryRow(...interface{}) error
@ -113,7 +113,7 @@ type RawSeter interface {
Prepare() (RawPreparer, error) Prepare() (RawPreparer, error)
} }
// statement querier // stmtQuerier statement querier
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
@ -162,8 +162,8 @@ type dbBaser interface {
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string OperatorSQL(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)

View File

@ -22,9 +22,10 @@ import (
"time" "time"
) )
// StrTo is the target string
type StrTo string type StrTo string
// set string // Set string
func (f *StrTo) Set(v string) { func (f *StrTo) Set(v string) {
if v != "" { if v != "" {
*f = StrTo(v) *f = StrTo(v)
@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) {
} }
} }
// clean string // Clear string
func (f *StrTo) Clear() { func (f *StrTo) Clear() {
*f = StrTo(0x1E) *f = StrTo(0x1E)
} }
// check string exist // Exist check string exist
func (f StrTo) Exist() bool { func (f StrTo) Exist() bool {
return string(f) != string(0x1E) return string(f) != string(0x1E)
} }
// string to bool // Bool string to bool
func (f StrTo) Bool() (bool, error) { func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String()) return strconv.ParseBool(f.String())
} }
// string to float32 // Float32 string to float32
func (f StrTo) Float32() (float32, error) { func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32) v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err return float32(v), err
} }
// string to float64 // Float64 string to float64
func (f StrTo) Float64() (float64, error) { func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64) return strconv.ParseFloat(f.String(), 64)
} }
// string to int // Int string to int
func (f StrTo) Int() (int, error) { func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err return int(v), err
} }
// string to int8 // Int8 string to int8
func (f StrTo) Int8() (int8, error) { func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8) v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err return int8(v), err
} }
// string to int16 // Int16 string to int16
func (f StrTo) Int16() (int16, error) { func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16) v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err return int16(v), err
} }
// string to int32 // Int32 string to int32
func (f StrTo) Int32() (int32, error) { func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err return int32(v), err
} }
// string to int64 // Int64 string to int64
func (f StrTo) Int64() (int64, error) { func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64) v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err return int64(v), err
} }
// string to uint // Uint string to uint
func (f StrTo) Uint() (uint, error) { func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err return uint(v), err
} }
// string to uint8 // Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) { func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8) v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err return uint8(v), err
} }
// string to uint16 // Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) { func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16) v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err return uint16(v), err
} }
// string to uint31 // Uint32 string to uint31
func (f StrTo) Uint32() (uint32, error) { func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err return uint32(v), err
} }
// string to uint64 // Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) { func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64) v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err return uint64(v), err
} }
// string to string // String string to string
func (f StrTo) String() string { func (f StrTo) String() string {
if f.Exist() { if f.Exist() {
return string(f) return string(f)
@ -127,7 +128,7 @@ func (f StrTo) String() string {
return "" return ""
} }
// interface to string // ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) { func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) { switch v := value.(type) {
case bool: case bool:
@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s return s
} }
// interface to int64 // ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) { func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value) val := reflect.ValueOf(value)
switch value.(type) { switch value.(type) {