1
0
mirror of https://github.com/astaxie/beego.git synced 2025-07-03 00:10:20 +00:00

some fix / add test

This commit is contained in:
slene
2013-08-07 19:11:44 +08:00
parent 10f4e822c3
commit 46668b811f
15 changed files with 1082 additions and 222 deletions

164
orm/db.go
View File

@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) {
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk[0].column
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) {
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk[0].column
c2 = jt.fi.relModelInfo.fields.pk.column
if jt.fi.reverse {
c1 = jt.mi.fields.pk[0].column
c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column
}
}
@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo
}
default:
return
}
jt, _ := d.add(names, mmi, fi, fi.null == false)
@ -434,40 +436,36 @@ type dbBase struct {
ins dbBaser
}
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) {
exist := true
columns := make([]string, 0, len(mi.fields.pk))
values := make([]interface{}, 0, len(mi.fields.pk))
for _, fi := range mi.fields.pk {
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
if exist {
exist = vu > 0
}
values = append(values, vu)
} else {
vu := v.String()
if exist {
exist = vu != ""
}
values = append(values, vu)
}
columns = append(columns, fi.column)
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = vu > 0
value = vu
} else {
vu := v.String()
exist = vu != ""
value = vu
}
return columns, values, exist
column = fi.column
return
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
_, pkValues, _ := d.existPk(mi, ind)
_, pkValue, _ := d.existPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
if fi.dbcol == false || fi.auto && skipAuto {
continue
}
var value interface{}
if i, ok := mi.fields.pk.Exist(fi); ok {
value = pkValues[i]
if fi.pk {
value = pkValue
} else {
field := ind.Field(fi.fieldIndex)
if fi.isFielder {
@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
if field.IsNil() {
value = nil
} else {
_, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field))
if fok {
value = fvalues[0]
if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
value = vu
} else {
value = nil
}
@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
pkNames, pkValues, ok := d.existPk(mi, ind)
pkColumn, pkValue, ok := d.existPk(mi, ind)
if ok == false {
return ErrMissPK
}
pkColumns := strings.Join(pkNames, "` = ? AND `")
sels := strings.Join(mi.fields.dbcols, "`, `")
colsNum := len(mi.fields.dbcols)
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns)
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn)
refs := make([]interface{}, colsNum)
for i, _ := range refs {
@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
refs[i] = &ref
}
row := q.QueryRow(query, pkValues...)
row := q.QueryRow(query, pkValue)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
}
return err
} else {
elm := reflect.New(mi.addrField.Elem().Type())
@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkNames, pkValues, ok := d.existPk(mi, ind)
pkName, pkValue, ok := d.existPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
return 0, err
}
pkColumns := strings.Join(pkNames, "` = ? AND `")
setColumns := strings.Join(setNames, "` = ?, `")
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns)
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName)
setValues = append(setValues, pkValues...)
setValues = append(setValues, pkValue)
if res, err := q.Exec(query, setValues...); err == nil {
return res.RowsAffected()
@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
}
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, ok := d.existPk(mi, ind)
pkName, pkValue, ok := d.existPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
columns := strings.Join(names, "` = ? AND `")
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName)
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
if res, err := q.Exec(query, values...); err == nil {
if res, err := q.Exec(query, pkValue); err == nil {
num, err := res.RowsAffected()
if err != nil {
@ -660,15 +655,13 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
}
if num > 0 {
if mi.fields.auto != nil {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0)
if mi.fields.pk.auto {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
if len(names) == 1 {
err := d.deleteRels(q, mi, values)
if err != nil {
return num, err
}
err := d.deleteRels(q, mi, []interface{}{pkValue})
if err != nil {
return num, err
}
}
@ -683,12 +676,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
column := snakeString(col)
if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false {
panic(fmt.Sprintf("wrong field/column name `%s`", column))
if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
panic(fmt.Sprintf("wrong field/column name `%s`", col))
} else {
columns = append(columns, fi.column)
values = append(values, val)
}
columns = append(columns, column)
values = append(values, val)
}
if len(columns) == 0 {
@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
fi = fi.reverseFieldInfo
switch fi.onDelete {
case od_CASCADE:
cond := NewCondition()
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond)
if err != nil {
return err
}
case od_SET_DEFAULT, od_SET_NULL:
cond := NewCondition()
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT {
params[fi.column] = fi.initial.String()
@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
where, args := tables.getCondSql(cond, false)
join := tables.getJoinSql()
colsNum := len(mi.fields.pk)
cols := make([]string, colsNum)
for i, fi := range mi.fields.pk {
cols[i] = fi.column
}
colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column)
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
rs = r
}
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref interface{}
refs[i] = &ref
}
var ref interface{}
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
if err := rs.Scan(refs...); err != nil {
if err := rs.Scan(&ref); err != nil {
return 0, err
}
for _, ref := range refs {
args = append(args, reflect.ValueOf(ref).Elem().Interface())
}
args = append(args, reflect.ValueOf(ref).Interface())
cnt++
}
@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
if colsNum > 1 {
columns := strings.Join(cols, "` = ? AND `")
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
} else {
var sql string
sql, args = d.ins.GetOperatorSql(mi, "in", args)
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
}
sql, args := d.ins.GetOperatorSql(mi, "in", args)
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql)
if res, err := q.Exec(query, args...); err == nil {
num, err := res.RowsAffected()
@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, err
}
if colsNum == 1 && num > 0 {
if num > 0 {
err := d.deleteRels(q, mi, args)
if err != nil {
return num, err
@ -980,14 +954,12 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
copy(params, args)
sql := ""
for i, arg := range args {
if len(mi.fields.pk) == 1 {
if md, ok := arg.(Modeler); ok {
ind := reflect.Indirect(reflect.ValueOf(md))
if _, values, exist := d.existPk(mi, ind); exist {
arg = values[0]
} else {
panic(fmt.Sprintf("`%s` need a valid args value", operator))
}
if md, ok := arg.(Modeler); ok {
ind := reflect.Indirect(reflect.ValueOf(md))
if _, vu, exist := d.existPk(mi, ind); exist {
arg = vu
} else {
panic(fmt.Sprintf("`%s` need a valid args value", operator))
}
}
params[i] = arg
@ -1175,7 +1147,7 @@ setValue:
value = v
}
case fieldType&IsRelField > 0:
fieldType = fi.relModelInfo.fields.pk[0].fieldType
fieldType = fi.relModelInfo.fields.pk.fieldType
goto setValue
}
@ -1236,12 +1208,12 @@ setValue:
}
case fieldType&IsRelField > 0:
if value != nil {
fieldType = fi.relModelInfo.fields.pk[0].fieldType
fieldType = fi.relModelInfo.fields.pk.fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
md := mf.Interface().(Modeler)
md.Init(md)
field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
field = &f
goto setValue
}