1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-26 05:21:30 +00:00

some fix / add test

This commit is contained in:
slene 2013-08-07 19:11:44 +08:00
parent 10f4e822c3
commit 46668b811f
15 changed files with 1082 additions and 222 deletions

136
orm/db.go
View File

@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) {
switch { switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk[0].column c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel { for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo { if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column c2 = ffi.column
@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) {
} }
default: default:
c1 = jt.fi.column c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk[0].column c2 = jt.fi.relModelInfo.fields.pk.column
if jt.fi.reverse { if jt.fi.reverse {
c1 = jt.mi.fields.pk[0].column c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column c2 = jt.fi.reverseFieldInfo.column
} }
} }
@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
if fi.reverseFieldInfo.fieldType == RelManyToMany { if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo mmi = fi.reverseFieldInfo.relThroughModelInfo
} }
default:
return
} }
jt, _ := d.add(names, mmi, fi, fi.null == false) jt, _ := d.add(names, mmi, fi, fi.null == false)
@ -434,40 +436,36 @@ type dbBase struct {
ins dbBaser ins dbBaser
} }
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) { func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
exist := true
columns := make([]string, 0, len(mi.fields.pk)) fi := mi.fields.pk
values := make([]interface{}, 0, len(mi.fields.pk))
for _, fi := range mi.fields.pk {
v := ind.Field(fi.fieldIndex) v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 { if fi.fieldType&IsIntegerField > 0 {
vu := v.Int() vu := v.Int()
if exist {
exist = vu > 0 exist = vu > 0
} value = vu
values = append(values, vu)
} else { } else {
vu := v.String() vu := v.String()
if exist {
exist = vu != "" exist = vu != ""
value = vu
} }
values = append(values, vu)
} column = fi.column
columns = append(columns, fi.column)
} return
return columns, values, exist
} }
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
_, pkValues, _ := d.existPk(mi, ind) _, pkValue, _ := d.existPk(mi, ind)
for _, column := range mi.fields.orders { for _, column := range mi.fields.orders {
fi := mi.fields.columns[column] fi := mi.fields.columns[column]
if fi.dbcol == false || fi.auto && skipAuto { if fi.dbcol == false || fi.auto && skipAuto {
continue continue
} }
var value interface{} var value interface{}
if i, ok := mi.fields.pk.Exist(fi); ok { if fi.pk {
value = pkValues[i] value = pkValue
} else { } else {
field := ind.Field(fi.fieldIndex) field := ind.Field(fi.fieldIndex)
if fi.isFielder { if fi.isFielder {
@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
if field.IsNil() { if field.IsNil() {
value = nil value = nil
} else { } else {
_, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field)) if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
if fok { value = vu
value = fvalues[0]
} else { } else {
value = nil value = nil
} }
@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i
} }
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
pkNames, pkValues, ok := d.existPk(mi, ind) pkColumn, pkValue, ok := d.existPk(mi, ind)
if ok == false { if ok == false {
return ErrMissPK return ErrMissPK
} }
pkColumns := strings.Join(pkNames, "` = ? AND `")
sels := strings.Join(mi.fields.dbcols, "`, `") sels := strings.Join(mi.fields.dbcols, "`, `")
colsNum := len(mi.fields.dbcols) colsNum := len(mi.fields.dbcols)
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns) query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn)
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i, _ := range refs { for i, _ := range refs {
@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
refs[i] = &ref refs[i] = &ref
} }
row := q.QueryRow(query, pkValues...) row := q.QueryRow(query, pkValue)
if err := row.Scan(refs...); err != nil { if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
}
return err return err
} else { } else {
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.addrField.Elem().Type())
@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
} }
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkNames, pkValues, ok := d.existPk(mi, ind) pkName, pkValue, ok := d.existPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
} }
@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
return 0, err return 0, err
} }
pkColumns := strings.Join(pkNames, "` = ? AND `")
setColumns := strings.Join(setNames, "` = ?, `") setColumns := strings.Join(setNames, "` = ?, `")
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns) query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName)
setValues = append(setValues, pkValues...) setValues = append(setValues, pkValue)
if res, err := q.Exec(query, setValues...); err == nil { if res, err := q.Exec(query, setValues...); err == nil {
return res.RowsAffected() return res.RowsAffected()
@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
} }
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, ok := d.existPk(mi, ind) pkName, pkValue, ok := d.existPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
} }
columns := strings.Join(names, "` = ? AND `") query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName)
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns) if res, err := q.Exec(query, pkValue); err == nil {
if res, err := q.Exec(query, values...); err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
@ -660,17 +655,15 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
} }
if num > 0 { if num > 0 {
if mi.fields.auto != nil { if mi.fields.pk.auto {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0) ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
} }
if len(names) == 1 { err := d.deleteRels(q, mi, []interface{}{pkValue})
err := d.deleteRels(q, mi, values)
if err != nil { if err != nil {
return num, err return num, err
} }
} }
}
return num, err return num, err
} else { } else {
@ -683,13 +676,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
for col, val := range params { for col, val := range params {
column := snakeString(col) if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false { panic(fmt.Sprintf("wrong field/column name `%s`", col))
panic(fmt.Sprintf("wrong field/column name `%s`", column)) } else {
} columns = append(columns, fi.column)
columns = append(columns, column)
values = append(values, val) values = append(values, val)
} }
}
if len(columns) == 0 { if len(columns) == 0 {
panic("update params cannot empty") panic("update params cannot empty")
@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
switch fi.onDelete { switch fi.onDelete {
case od_CASCADE: case od_CASCADE:
cond := NewCondition() cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond) _, err := d.DeleteBatch(q, nil, fi.mi, cond)
if err != nil { if err != nil {
return err return err
} }
case od_SET_DEFAULT, od_SET_NULL: case od_SET_DEFAULT, od_SET_NULL:
cond := NewCondition() cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil} params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT { if fi.onDelete == od_SET_DEFAULT {
params[fi.column] = fi.initial.String() params[fi.column] = fi.initial.String()
@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
where, args := tables.getCondSql(cond, false) where, args := tables.getCondSql(cond, false)
join := tables.getJoinSql() join := tables.getJoinSql()
colsNum := len(mi.fields.pk) cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column)
cols := make([]string, colsNum) query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where)
for i, fi := range mi.fields.pk {
cols[i] = fi.column
}
colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
var rs *sql.Rows var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil { if r, err := q.Query(query, args...); err != nil {
@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
rs = r rs = r
} }
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref interface{} var ref interface{}
refs[i] = &ref
}
args = make([]interface{}, 0) args = make([]interface{}, 0)
cnt := 0 cnt := 0
for rs.Next() { for rs.Next() {
if err := rs.Scan(refs...); err != nil { if err := rs.Scan(&ref); err != nil {
return 0, err return 0, err
} }
for _, ref := range refs { args = append(args, reflect.ValueOf(ref).Interface())
args = append(args, reflect.ValueOf(ref).Elem().Interface())
}
cnt++ cnt++
} }
@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
if colsNum > 1 { sql, args := d.ins.GetOperatorSql(mi, "in", args)
columns := strings.Join(cols, "` = ? AND `") query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql)
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
} else {
var sql string
sql, args = d.ins.GetOperatorSql(mi, "in", args)
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
}
if res, err := q.Exec(query, args...); err == nil { if res, err := q.Exec(query, args...); err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, err return 0, err
} }
if colsNum == 1 && num > 0 { if num > 0 {
err := d.deleteRels(q, mi, args) err := d.deleteRels(q, mi, args)
if err != nil { if err != nil {
return num, err return num, err
@ -980,16 +954,14 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
copy(params, args) copy(params, args)
sql := "" sql := ""
for i, arg := range args { for i, arg := range args {
if len(mi.fields.pk) == 1 {
if md, ok := arg.(Modeler); ok { if md, ok := arg.(Modeler); ok {
ind := reflect.Indirect(reflect.ValueOf(md)) ind := reflect.Indirect(reflect.ValueOf(md))
if _, values, exist := d.existPk(mi, ind); exist { if _, vu, exist := d.existPk(mi, ind); exist {
arg = values[0] arg = vu
} else { } else {
panic(fmt.Sprintf("`%s` need a valid args value", operator)) panic(fmt.Sprintf("`%s` need a valid args value", operator))
} }
} }
}
params[i] = arg params[i] = arg
} }
if operator == "in" { if operator == "in" {
@ -1175,7 +1147,7 @@ setValue:
value = v value = v
} }
case fieldType&IsRelField > 0: case fieldType&IsRelField > 0:
fieldType = fi.relModelInfo.fields.pk[0].fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
goto setValue goto setValue
} }
@ -1236,12 +1208,12 @@ setValue:
} }
case fieldType&IsRelField > 0: case fieldType&IsRelField > 0:
if value != nil { if value != nil {
fieldType = fi.relModelInfo.fields.pk[0].fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
md := mf.Interface().(Modeler) md := mf.Interface().(Modeler)
md.Init(md) md.Init(md)
field.Set(mf) field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex) f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
field = &f field = &f
goto setValue goto setValue
} }

View File

@ -9,24 +9,37 @@ import (
const defaultMaxIdle = 30 const defaultMaxIdle = 30
type driverType int type DriverType int
const ( const (
_ driverType = iota _ DriverType = iota
DR_MySQL DR_MySQL
DR_Sqlite DR_Sqlite
DR_Oracle DR_Oracle
DR_Postgres DR_Postgres
) )
type driver string
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
func (d driver) Name() string {
return string(d)
}
var _ Driver = new(driver)
var ( var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)} dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]driverType{ drivers = map[string]DriverType{
"mysql": DR_MySQL, "mysql": DR_MySQL,
"postgres": DR_Postgres, "postgres": DR_Postgres,
"sqlite3": DR_Sqlite, "sqlite3": DR_Sqlite,
} }
dbBasers = map[driverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(), DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(), DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(), DR_Oracle: newdbBaseMysql(),
@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) {
type alias struct { type alias struct {
Name string Name string
Driver DriverType
DriverName string DriverName string
DataSource string DataSource string
MaxIdle int MaxIdle int
@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
if dr, ok := drivers[driverName]; ok { if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr] al.DbBaser = dbBasers[dr]
al.Driver = dr
} else { } else {
err = fmt.Errorf("driver name `%s` have not registered", driverName) err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end goto end
@ -116,7 +131,7 @@ end:
} }
} }
func RegisterDriver(name string, typ driverType) { func RegisterDriver(name string, typ DriverType) {
if t, ok := drivers[name]; ok == false { if t, ok := drivers[name]; ok == false {
drivers[name] = typ drivers[name] = typ
} else { } else {

View File

@ -49,6 +49,7 @@ type _modelCache struct {
sync.RWMutex sync.RWMutex
orders []string orders []string
cache map[string]*modelInfo cache map[string]*modelInfo
done bool
} }
func (mc *_modelCache) all() map[string]*modelInfo { func (mc *_modelCache) all() map[string]*modelInfo {

View File

@ -8,7 +8,7 @@ import (
"strings" "strings"
) )
func RegisterModel(model Modeler) { func registerModel(model Modeler) {
info := newModelInfo(model) info := newModelInfo(model)
model.Init(model) model.Init(model)
table := model.GetTableName() table := model.GetTableName()
@ -27,9 +27,10 @@ func RegisterModel(model Modeler) {
modelCache.set(table, info) modelCache.set(table, info)
} }
func BootStrap() { func bootStrap() {
modelCache.Lock() if modelCache.done {
defer modelCache.Unlock() return
}
var ( var (
err error err error
@ -59,14 +60,6 @@ func BootStrap() {
} }
fi.relModelInfo = mii fi.relModelInfo = mii
if fi.rel {
if mii.fields.pk.IsMulti() {
err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
goto end
}
}
switch fi.fieldType { switch fi.fieldType {
case RelManyToMany: case RelManyToMany:
if fi.relThrough != "" { if fi.relThrough != "" {
@ -207,6 +200,25 @@ end:
fmt.Println(err) fmt.Println(err)
os.Exit(2) os.Exit(2)
} }
}
runCommand()
func RegisterModel(models ...Modeler) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
}
for _, model := range models {
registerModel(model)
}
}
func BootStrap() {
if modelCache.done {
return
}
modelCache.Lock()
defer modelCache.Unlock()
bootStrap()
modelCache.done = true
} }

View File

@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices {
return *f return *f
} }
type primaryKeys []*fieldInfo
func (p *primaryKeys) Add(fi *fieldInfo) {
*p = append(*p, fi)
}
func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
for i, v := range p {
if v == fi {
return i, true
}
}
return -1, false
}
func (p primaryKeys) IsMulti() bool {
return len(p) > 1
}
func (p primaryKeys) IsEmpty() bool {
return len(p) == 0
}
type fields struct { type fields struct {
pk primaryKeys pk *fieldInfo
auto *fieldInfo
columns map[string]*fieldInfo columns map[string]*fieldInfo
fields map[string]*fieldInfo fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo fieldsLow map[string]*fieldInfo

View File

@ -50,29 +50,24 @@ func newModelInfo(model Modeler) (info *modelInfo) {
if err != nil { if err != nil {
break break
} }
added := info.fields.Add(fi) added := info.fields.Add(fi)
if added == false { if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
break break
} }
if fi.pk { if fi.pk {
if info.fields.pk != nil { if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only")) err = errors.New(fmt.Sprintf("one model must have one pk field only"))
break break
} else { } else {
info.fields.pk.Add(fi) info.fields.pk = fi
} }
} }
if fi.auto {
info.fields.auto = fi
}
fi.fieldIndex = i
fi.mi = info
}
if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false { fi.fieldIndex = i
err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field")) fi.mi = info
goto end
} }
if err != nil { if err != nil {
@ -80,11 +75,6 @@ func newModelInfo(model Modeler) (info *modelInfo) {
os.Exit(2) os.Exit(2)
} }
end:
if err != nil {
fmt.Println(err)
os.Exit(2)
}
return return
} }
@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info.fields.Add(fa) info.fields.Add(fa)
info.fields.Add(f1) info.fields.Add(f1)
info.fields.Add(f2) info.fields.Add(f2)
info.fields.pk.Add(fa) info.fields.pk = fa
return return
} }

152
orm/models_test.go Normal file
View File

@ -0,0 +1,152 @@
package orm
import (
"fmt"
"os"
"time"
_ "github.com/bmizerany/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
)
type User struct {
Id int `orm:"auto"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16 `orm:"choices(0,1,2,3);defalut(0)"`
IsStaff bool `orm:"default(false)"`
IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
}
func NewUser() *User {
obj := new(User)
obj.Manager.Init(obj)
return obj
}
type Profile struct {
Id int `orm:"auto"`
Age int16 ``
Money float64 ``
User *User `orm:"reverse(one)" json:"-"`
Manager `json:"-"`
}
func (u *Profile) TableName() string {
return "user_profile"
}
func NewProfile() *Profile {
obj := new(Profile)
obj.Manager.Init(obj)
return obj
}
type Post struct {
Id int `orm:"auto"`
User *User `orm:"rel(fk)"` //
Title string `orm:"size(60)"`
Content string ``
Created time.Time `orm:"auto_now_add"`
Updated time.Time `orm:"auto_now"`
Tags []*Tag `orm:"rel(m2m)"`
Manager `json:"-"`
}
func NewPost() *Post {
obj := new(Post)
obj.Manager.Init(obj)
return obj
}
type Tag struct {
Id int `orm:"auto"`
Name string `orm:"size(30)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
}
func NewTag() *Tag {
obj := new(Tag)
obj.Manager.Init(obj)
return obj
}
type Comment struct {
Id int `orm:"auto"`
Post *Post `orm:"rel(fk)"`
Content string ``
Parent *Comment `orm:"null;rel(fk)"`
Created time.Time `orm:"auto_now_add"`
Manager `json:"-"`
}
func NewComment() *Comment {
obj := new(Comment)
obj.Manager.Init(obj)
return obj
}
var DBARGS = struct {
Driver string
Source string
}{
os.Getenv("ORM_DRIVER"),
os.Getenv("ORM_SOURCE"),
}
var dORM Ormer
func init() {
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
RegisterModel(new(Tag))
RegisterModel(new(Comment))
if DBARGS.Driver == "" || DBARGS.Source == "" {
fmt.Println(`need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/bmizerany/pq
eg: mysql
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
`)
os.Exit(2)
}
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
BootStrap()
truncateTables()
dORM = NewOrm()
}
func truncateTables() {
logs := "truncate tables for test\n"
o := NewOrm()
for _, m := range modelCache.allOrdered() {
query := fmt.Sprintf("truncate table `%s`", m.table)
_, err := o.Raw(query).Exec()
logs += query + "\n"
if err != nil {
fmt.Println(logs)
fmt.Println(err)
os.Exit(2)
}
}
}

View File

@ -9,13 +9,15 @@ import (
) )
var ( var (
ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
DefaultRowsLimit = 1000 DefaultRowsLimit = 1000
DefaultRelsDepth = 5 DefaultRelsDepth = 5
DefaultTimeLoc = time.Local DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> not row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrNotImplement = errors.New("have not implement")
) )
type Params map[string]interface{} type Params map[string]interface{}
@ -27,13 +29,15 @@ type orm struct {
isTx bool isTx bool
} }
var _ Ormer = new(orm)
func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) { func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) {
md.Init(md, true) md.Init(md, true)
name := md.GetTableName() name := md.GetTableName()
if mi, ok := modelCache.get(name); ok { if mi, ok := modelCache.get(name); ok {
return mi, reflect.Indirect(reflect.ValueOf(md)) return mi, reflect.Indirect(reflect.ValueOf(md))
} }
panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name)) panic(fmt.Sprintf("<orm> table name: `%s` not exists", name))
} }
func (o *orm) Read(md Modeler) error { func (o *orm) Read(md Modeler) error {
@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) {
return id, err return id, err
} }
if id > 0 { if id > 0 {
if mi.fields.auto != nil { if mi.fields.pk.auto {
ind.Field(mi.fields.auto.fieldIndex).SetInt(id) ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
} }
} }
return id, nil return id, nil
@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) {
return num, err return num, err
} }
if num > 0 { if num > 0 {
if mi.fields.auto != nil { if mi.fields.pk.auto {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0) ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
} }
} }
return num, nil return num, nil
} }
func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) LoadRel(md Modeler, name string) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
name := "" name := ""
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
@ -111,7 +133,7 @@ func (o *orm) Using(name string) error {
func (o *orm) Begin() error { func (o *orm) Begin() error {
if o.isTx { if o.isTx {
return ErrTXHasBegin return ErrTxHasBegan
} }
tx, err := o.alias.DB.Begin() tx, err := o.alias.DB.Begin()
if err != nil { if err != nil {
@ -124,24 +146,28 @@ func (o *orm) Begin() error {
func (o *orm) Commit() error { func (o *orm) Commit() error {
if o.isTx == false { if o.isTx == false {
return ErrTXNotBegin return ErrTxDone
} }
err := o.db.(*sql.Tx).Commit() err := o.db.(*sql.Tx).Commit()
if err == nil { if err == nil {
o.isTx = false o.isTx = false
o.db = o.alias.DB o.db = o.alias.DB
} else if err == sql.ErrTxDone {
return ErrTxDone
} }
return err return err
} }
func (o *orm) Rollback() error { func (o *orm) Rollback() error {
if o.isTx == false { if o.isTx == false {
return ErrTXNotBegin return ErrTxDone
} }
err := o.db.(*sql.Tx).Rollback() err := o.db.(*sql.Tx).Rollback()
if err == nil { if err == nil {
o.isTx = false o.isTx = false
o.db = o.alias.DB o.db = o.alias.DB
} else if err == sql.ErrTxDone {
return ErrTxDone
} }
return err return err
} }
@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args) return newRawSet(o, query, args)
} }
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once
o := new(orm) o := new(orm)
err := o.Using("default") err := o.Using("default")
if err != nil { if err != nil {

View File

@ -26,23 +26,24 @@ func NewCondition() *Condition {
return c return c
} }
func (c *Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.And> args cannot empty") panic("<Condition.And> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
return c return &c
} }
func (c *Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.AndNot> args cannot empty") panic("<Condition.AndNot> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
return c return &c
} }
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond { if c == cond {
panic("cannot use self as sub cond") panic("cannot use self as sub cond")
} }
@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
func (c *Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.Or> args cannot empty") panic("<Condition.Or> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
return c return &c
} }
func (c *Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.OrNot> args cannot empty") panic("<Condition.OrNot> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
return c return &c
} }
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond { if c == cond {
panic("cannot use self as sub cond") panic("cannot use self as sub cond")
} }
@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
func (c Condition) Clone() *Condition { func (c Condition) clone() *Condition {
params := c.params
c.params = make([]condValue, len(params))
copy(c.params, params)
return &c return &c
} }
func (c *Condition) Merge() (expr string, args []interface{}) {
return expr, args
}

View File

@ -13,6 +13,8 @@ type insertSet struct {
closed bool closed bool
} }
var _ Inserter = new(insertSet)
func (o *insertSet) Insert(md Modeler) (int64, error) { func (o *insertSet) Insert(md Modeler) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) {
return id, err return id, err
} }
if id > 0 { if id > 0 {
if o.mi.fields.auto != nil { if o.mi.fields.pk.auto {
ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id) ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
} }
} }
return id, nil return id, nil
} }
func (o *insertSet) Close() error { func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
}
o.closed = true o.closed = true
return o.stmt.Close() return o.stmt.Close()
} }

View File

@ -15,47 +15,43 @@ type querySet struct {
orm *orm orm *orm
} }
func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter { var _ QuerySeter = new(querySet)
o = o.clone()
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
} }
o.cond.And(expr, args...) o.cond = o.cond.And(expr, args...)
return o return &o
} }
func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter { func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
o = o.clone()
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
} }
o.cond.AndNot(expr, args...) o.cond = o.cond.AndNot(expr, args...)
return o return &o
} }
func (o *querySet) Limit(limit int, args ...int64) QuerySeter { func (o querySet) Limit(limit int, args ...int64) QuerySeter {
o = o.clone()
o.limit = limit o.limit = limit
if len(args) > 0 { if len(args) > 0 {
o.offset = args[0] o.offset = args[0]
} }
return o return &o
} }
func (o *querySet) Offset(offset int64) QuerySeter { func (o querySet) Offset(offset int64) QuerySeter {
o = o.clone()
o.offset = offset o.offset = offset
return o return &o
} }
func (o *querySet) OrderBy(exprs ...string) QuerySeter { func (o querySet) OrderBy(exprs ...string) QuerySeter {
o = o.clone()
o.orders = exprs o.orders = exprs
return o return &o
} }
func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
o = o.clone()
var related []string var related []string
if len(params) == 0 { if len(params) == 0 {
o.relDepth = DefaultRelsDepth o.relDepth = DefaultRelsDepth
@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
} }
} }
o.related = related o.related = related
return o
}
func (o querySet) clone() *querySet {
if o.cond != nil {
o.cond = o.cond.Clone()
}
return &o return &o
} }
@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error {
if num > 1 { if num > 1 {
return ErrMultiRows return ErrMultiRows
} }
if num == 0 {
return ErrNoRows
}
return nil return nil
} }

View File

@ -63,6 +63,8 @@ type rawSet struct {
orm *orm orm *orm
} }
var _ RawSeter = new(rawSet)
func (o rawSet) SetArgs(args ...interface{}) RawSeter { func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args o.args = args
return &o return &o
@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) {
return getResult(res) return getResult(res)
} }
func (o *rawSet) Mapper(...interface{}) (int64, error) { func (o *rawSet) QueryRow(...interface{}) error {
//TODO
return nil
}
func (o *rawSet) QueryRows(...interface{}) (int64, error) {
//TODO //TODO
return 0, nil return 0, nil
} }
@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
cols = columns cols = columns
refs = make([]interface{}, len(cols)) refs = make([]interface{}, len(cols))
for i, _ := range refs { for i, _ := range refs {
var ref string var ref sql.NullString
refs[i] = &ref refs[i] = &ref
} }
} }
@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
case 1: case 1:
params := make(Params, len(cols)) params := make(Params, len(cols))
for i, ref := range refs { for i, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface() value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params[cols[i]] = value params[cols[i]] = value.String
} }
maps = append(maps, params) maps = append(maps, params)
case 2: case 2:
params := make(ParamsList, 0, len(cols)) params := make(ParamsList, 0, len(cols))
for _, ref := range refs { for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface() value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params = append(params, value) params = append(params, value.String)
} }
lists = append(lists, params) lists = append(lists, params)
case 3: case 3:
for _, ref := range refs { for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface() value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
list = append(list, value) list = append(list, value.String)
} }
} }

688
orm/orm_test.go Normal file
View File

@ -0,0 +1,688 @@
package orm
import (
"bytes"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"time"
)
type T_Code int
const (
// =
T_Equal T_Code = iota
// <
T_Less
// >
T_Large
// elment in slice/array
// T_In
// key exists in map
// T_KeyExist
// index != -1
// T_Contain
// index == 0
// T_StartWith
// index == len(x) - 1
// T_EndWith
)
func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err error, ok bool) {
if len(args) == 0 {
return fmt.Errorf("miss args"), false
}
b := args[0]
arg := argAny(args)
switch o {
case T_Equal:
switch v := a.(type) {
case reflect.Kind:
ok = reflect.ValueOf(b).Kind() == v
case time.Time:
if v2, vo := b.(time.Time); vo {
if arg.Get(1) != nil {
format := ToStr(arg.Get(1))
ok = v.Format(format) == v2.Format(format)
} else {
err = fmt.Errorf("compare datetime miss format")
goto wrongArg
}
}
default:
ok = ToStr(a) == ToStr(b)
}
ok = is && ok || !is && !ok
if !ok {
if is {
err = fmt.Errorf("should: a == b, a = `%v`, b = `%v`", a, b)
} else {
err = fmt.Errorf("should: a != b, a = `%v`, b = `%v`", a, b)
}
}
case T_Less, T_Large:
as := ToStr(a)
bs := ToStr(b)
f1, er := StrTo(as).Float64()
if er != nil {
err = fmt.Errorf("wrong type need numeric: `%v`", a)
goto wrongArg
}
f2, er := StrTo(bs).Float64()
if er != nil {
err = fmt.Errorf("wrong type need numeric: `%v`", b)
goto wrongArg
}
var opts []string
if o == T_Less {
opts = []string{"<", ">="}
ok = f1 < f2
} else {
opts = []string{">", "<="}
ok = f1 > f2
}
ok = is && ok || !is && !ok
if !ok {
if is {
err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[0], f1, f2)
} else {
err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[1], f1, f2)
}
}
}
wrongArg:
if err != nil {
return err, false
}
return nil, true
}
func AssertIs(a interface{}, o T_Code, args ...interface{}) error {
if err, ok := ValuesCompare(true, a, o, args...); ok == false {
return err
}
return nil
}
func AssertNot(a interface{}, o T_Code, args ...interface{}) error {
if err, ok := ValuesCompare(false, a, o, args...); ok == false {
return err
}
return nil
}
func getCaller(skip int) string {
pc, file, line, _ := runtime.Caller(skip)
fun := runtime.FuncForPC(pc)
_, fn := filepath.Split(file)
data, err := ioutil.ReadFile(file)
code := ""
if err == nil {
lines := bytes.Split(data, []byte{'\n'})
code = strings.TrimSpace(string(lines[line-1]))
}
funName := fun.Name()
if i := strings.LastIndex(funName, "."); i > -1 {
funName = funName[i+1:]
}
return fmt.Sprintf("%s:%d: %s: %s", fn, line, funName, code)
}
func throwFail(t *testing.T, err error, args ...interface{}) {
if err != nil {
params := []interface{}{"\n", getCaller(2), "\n", err, "\n"}
params = append(params, args...)
t.Error(params...)
t.Fail()
}
}
func throwFailNow(t *testing.T, err error, args ...interface{}) {
if err != nil {
params := []interface{}{"\n", getCaller(2), "\n", err, "\n"}
params = append(params, args...)
t.Error(params...)
t.FailNow()
}
}
func TestCRUD(t *testing.T) {
profile := NewProfile()
profile.Age = 30
profile.Money = 1234.12
id, err := dORM.Insert(profile)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
user := NewUser()
user.UserName = "slene"
user.Email = "vslene@gmail.com"
user.Password = "pass"
user.Status = 3
user.IsStaff = true
user.IsActive = true
id, err = dORM.Insert(user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
u := &User{Id: user.Id}
err = dORM.Read(u)
throwFailNow(t, err)
throwFailNow(t, AssertIs(u.UserName, T_Equal, "slene"))
throwFailNow(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com"))
throwFailNow(t, AssertIs(u.Password, T_Equal, "pass"))
throwFailNow(t, AssertIs(u.Status, T_Equal, 3))
throwFailNow(t, AssertIs(u.IsStaff, T_Equal, true))
throwFailNow(t, AssertIs(u.IsActive, T_Equal, true))
throwFailNow(t, AssertIs(u.Created, T_Equal, user.Created, format_Date))
throwFailNow(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime))
user.UserName = "astaxie"
user.Profile = profile
num, err := dORM.Update(user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, T_Equal, 1))
u = &User{Id: user.Id}
err = dORM.Read(u)
throwFailNow(t, err)
throwFailNow(t, AssertIs(u.UserName, T_Equal, "astaxie"))
throwFailNow(t, AssertIs(u.Profile.Id, T_Equal, profile.Id))
num, err = dORM.Delete(profile)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, T_Equal, 1))
u = &User{Id: user.Id}
err = dORM.Read(u)
throwFailNow(t, err)
throwFailNow(t, AssertIs(true, T_Equal, u.Profile == nil))
num, err = dORM.Delete(user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, T_Equal, 1))
u = &User{Id: 100}
err = dORM.Read(u)
throwFailNow(t, AssertIs(err, T_Equal, ErrNoRows))
}
func TestInsertTestData(t *testing.T) {
var users []*User
profile := NewProfile()
profile.Age = 28
profile.Money = 1234.12
id, err := dORM.Insert(profile)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
user := NewUser()
user.UserName = "slene"
user.Email = "vslene@gmail.com"
user.Password = "pass"
user.Status = 1
user.IsStaff = false
user.IsActive = true
user.Profile = profile
users = append(users, user)
id, err = dORM.Insert(user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
profile = NewProfile()
profile.Age = 30
profile.Money = 4321.09
id, err = dORM.Insert(profile)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
user = NewUser()
user.UserName = "astaxie"
user.Email = "astaxie@gmail.com"
user.Password = "password"
user.Status = 2
user.IsStaff = true
user.IsActive = false
user.Profile = profile
users = append(users, user)
id, err = dORM.Insert(user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
user = NewUser()
user.UserName = "nobody"
user.Email = "nobody@gmail.com"
user.Password = "nobody"
user.Status = 3
user.IsStaff = false
user.IsActive = false
users = append(users, user)
id, err = dORM.Insert(user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
tags := []*Tag{
&Tag{Name: "golang"},
&Tag{Name: "example"},
&Tag{Name: "format"},
&Tag{Name: "c++"},
}
posts := []*Post{
&Post{User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory resultJava programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand.
This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`},
&Post{User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`},
&Post{User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide.
With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`},
&Post{User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code.
The programand web servergodoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`},
}
comments := []*Comment{
&Comment{Post: posts[0], Content: "a comment"},
&Comment{Post: posts[1], Content: "yes"},
&Comment{Post: posts[1]},
&Comment{Post: posts[1]},
&Comment{Post: posts[2]},
&Comment{Post: posts[2]},
}
for _, tag := range tags {
id, err := dORM.Insert(tag)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
}
for _, post := range posts {
id, err := dORM.Insert(post)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
// dORM.M2mAdd(post, "tags", post.Tags)
}
for _, comment := range comments {
id, err := dORM.Insert(comment)
throwFailNow(t, err)
throwFailNow(t, AssertIs(id, T_Large, 0))
}
}
func TestExpr(t *testing.T) {
qs := dORM.QueryTable("User")
qs = dORM.QueryTable("user")
num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
}
func TestOperators(t *testing.T) {
qs := dORM.QueryTable("user")
num, err := qs.Filter("user_name", "slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("user_name__exact", "slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("user_name__iexact", "Slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("user_name__contains", "e").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("user_name__contains", "E").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 0))
num, err = qs.Filter("user_name__icontains", "E").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("user_name__icontains", "E").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("status__gt", 1).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("status__gte", 1).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
num, err = qs.Filter("status__lt", 3).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("status__lte", 3).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
num, err = qs.Filter("user_name__startswith", "s").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("user_name__startswith", "S").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 0))
num, err = qs.Filter("user_name__istartswith", "S").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("user_name__endswith", "e").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("user_name__endswith", "E").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 0))
num, err = qs.Filter("user_name__iendswith", "E").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Filter("profile__isnull", true).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("status__in", 1, 2).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
}
func TestAll(t *testing.T) {
var users []*User
qs := dORM.QueryTable("user")
num, err := qs.All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
qs = dORM.QueryTable("user")
num, err = qs.Filter("user_name", "nothing").All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 0))
}
func TestOne(t *testing.T) {
var user User
qs := dORM.QueryTable("user")
err := qs.One(&user)
throwFail(t, AssertIs(err, T_Equal, ErrMultiRows))
err = qs.Filter("user_name", "nothing").One(&user)
throwFail(t, AssertIs(err, T_Equal, ErrNoRows))
}
func TestValues(t *testing.T) {
var maps []Params
qs := dORM.QueryTable("user")
num, err := qs.Values(&maps)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(maps[0]["UserName"], T_Equal, "slene"))
throwFail(t, AssertIs(maps[2]["Profile"], T_Equal, nil))
}
num, err = qs.Values(&maps, "UserName", "Profile__Age")
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(maps[0]["UserName"], T_Equal, "slene"))
throwFail(t, AssertIs(maps[0]["Profile__Age"], T_Equal, 28))
throwFail(t, AssertIs(maps[2]["Profile__Age"], T_Equal, nil))
}
}
func TestValuesList(t *testing.T) {
var list []ParamsList
qs := dORM.QueryTable("user")
num, err := qs.ValuesList(&list)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(list[0][1], T_Equal, "slene"))
throwFail(t, AssertIs(list[2][9], T_Equal, nil))
}
num, err = qs.ValuesList(&list, "UserName", "Profile__Age")
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(list[0][0], T_Equal, "slene"))
throwFail(t, AssertIs(list[0][1], T_Equal, 28))
throwFail(t, AssertIs(list[2][1], T_Equal, nil))
}
}
func TestValuesFlat(t *testing.T) {
var list ParamsList
qs := dORM.QueryTable("user")
num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(list[0], T_Equal, "slene"))
throwFail(t, AssertIs(list[1], T_Equal, "astaxie"))
throwFail(t, AssertIs(list[2], T_Equal, "nobody"))
}
}
func TestRelatedSel(t *testing.T) {
qs := dORM.QueryTable("user")
num, err := qs.Filter("profile__age", 28).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("profile__age__gt", 28).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("profile__user__profile__age__gt", 28).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
var user User
err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
throwFail(t, AssertNot(user.Profile, T_Equal, nil))
if user.Profile != nil {
throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
}
err = qs.Filter("user_name", "slene").RelatedSel().One(&user)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
throwFail(t, AssertNot(user.Profile, T_Equal, nil))
throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
if user.Profile != nil {
throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
}
err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user)
throwFail(t, AssertIs(num, T_Equal, 1))
throwFail(t, AssertIs(user.Profile, T_Equal, nil))
qs = dORM.QueryTable("user_profile")
num, err = qs.Filter("user__username", "slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
}
func TestSetCond(t *testing.T) {
cond := NewCondition()
cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
qs := dORM.QueryTable("user")
num, err := qs.SetCond(cond1).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene"))
num, err = qs.SetCond(cond2).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
}
func TestLimit(t *testing.T) {
var posts []*Post
qs := dORM.QueryTable("post")
num, err := qs.Limit(1).All(&posts)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Limit(-1).All(&posts)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 4))
num, err = qs.Limit(-1, 2).All(&posts)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
num, err = qs.Limit(0, 2).All(&posts)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
}
func TestOffset(t *testing.T) {
var posts []*Post
qs := dORM.QueryTable("post")
num, err := qs.Limit(1).Offset(2).All(&posts)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Offset(2).All(&posts)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
}
func TestOrderBy(t *testing.T) {
qs := dORM.QueryTable("user")
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.OrderBy("status").Filter("user_name", "slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
}
func TestPrepareInsert(t *testing.T) {
qs := dORM.QueryTable("user")
i, err := qs.PrepareInsert()
throwFail(t, err)
var user User
user.UserName = "testing1"
num, err := i.Insert(&user)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Large, 0))
user.UserName = "testing2"
num, err = i.Insert(&user)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Large, 0))
num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 2))
err = i.Close()
throwFail(t, err)
err = i.Close()
throwFail(t, AssertIs(err, T_Equal, ErrStmtClosed))
}
func TestRaw(t *testing.T) {
switch dORM.Driver().Type() {
case DR_MySQL:
num, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
var maps []Params
num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
if num == 1 {
throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene"))
}
var lists []ParamsList
num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).ValuesList(&lists)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
if num == 1 {
throwFail(t, AssertIs(lists[0][0], T_Equal, "slene"))
}
var list ParamsList
num, err = dORM.Raw("SELECT profile_id FROM user ORDER BY id ASC").ValuesFlat(&list)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(list[0], T_Equal, "2"))
throwFail(t, AssertIs(list[1], T_Equal, "3"))
throwFail(t, AssertIs(list[2], T_Equal, ""))
}
}
}
func TestUpdate(t *testing.T) {
qs := dORM.QueryTable("user")
num, err := qs.Filter("user_name", "slene").Update(Params{
"is_staff": true,
})
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
}
func TestDelete(t *testing.T) {
qs := dORM.QueryTable("user_profile")
num, err := qs.Filter("user__user_name", "slene").Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
qs = dORM.QueryTable("user")
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
}
func TestTransaction(t *testing.T) {
}

View File

@ -5,6 +5,11 @@ import (
"reflect" "reflect"
) )
type Driver interface {
Name() string
Type() DriverType
}
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
@ -26,12 +31,16 @@ type Ormer interface {
Insert(Modeler) (int64, error) Insert(Modeler) (int64, error)
Update(Modeler) (int64, error) Update(Modeler) (int64, error)
Delete(Modeler) (int64, error) Delete(Modeler) (int64, error)
M2mAdd(Modeler, string, ...interface{}) (int64, error)
M2mDel(Modeler, string, ...interface{}) (int64, error)
LoadRel(Modeler, string) (int64, error)
QueryTable(interface{}) QuerySeter QueryTable(interface{}) QuerySeter
Using(string) error Using(string) error
Begin() error Begin() error
Commit() error Commit() error
Rollback() error Rollback() error
Raw(string, ...interface{}) RawSeter Raw(string, ...interface{}) RawSeter
Driver() Driver
} }
type Inserter interface { type Inserter interface {
@ -42,16 +51,15 @@ type Inserter interface {
type QuerySeter interface { type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Limit(int, ...int64) QuerySeter Limit(int, ...int64) QuerySeter
Offset(int64) QuerySeter Offset(int64) QuerySeter
OrderBy(...string) QuerySeter OrderBy(...string) QuerySeter
RelatedSel(...interface{}) QuerySeter RelatedSel(...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Count() (int64, error) Count() (int64, error)
Update(Params) (int64, error) Update(Params) (int64, error)
Delete() (int64, error) Delete() (int64, error)
PrepareInsert() (Inserter, error) PrepareInsert() (Inserter, error)
All(interface{}) (int64, error) All(interface{}) (int64, error)
One(Modeler) error One(Modeler) error
Values(*[]Params, ...string) (int64, error) Values(*[]Params, ...string) (int64, error)
@ -60,12 +68,15 @@ type QuerySeter interface {
} }
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (int64, error)
Close() error Close() error
} }
type RawSeter interface { type RawSeter interface {
Exec() (int64, error) Exec() (int64, error)
Mapper(...interface{}) (int64, error) QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error) Values(*[]Params) (int64, error)
ValuesList(*[]ParamsList) (int64, error) ValuesList(*[]ParamsList) (int64, error)
ValuesFlat(*ParamsList) (int64, error) ValuesFlat(*ParamsList) (int64, error)

View File

@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) {
return return
} }
type argAny []interface{}
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func timeParse(dateString, format string) (time.Time, error) { func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err return tp, err