mirror of
https://github.com/astaxie/beego.git
synced 2025-07-03 00:10:20 +00:00
some fix / add test
This commit is contained in:
164
orm/db.go
164
orm/db.go
@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) {
|
||||
|
||||
switch {
|
||||
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
|
||||
c1 = jt.fi.mi.fields.pk[0].column
|
||||
c1 = jt.fi.mi.fields.pk.column
|
||||
for _, ffi := range jt.mi.fields.fieldsRel {
|
||||
if jt.fi.mi == ffi.relModelInfo {
|
||||
c2 = ffi.column
|
||||
@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) {
|
||||
}
|
||||
default:
|
||||
c1 = jt.fi.column
|
||||
c2 = jt.fi.relModelInfo.fields.pk[0].column
|
||||
c2 = jt.fi.relModelInfo.fields.pk.column
|
||||
|
||||
if jt.fi.reverse {
|
||||
c1 = jt.mi.fields.pk[0].column
|
||||
c1 = jt.mi.fields.pk.column
|
||||
c2 = jt.fi.reverseFieldInfo.column
|
||||
}
|
||||
}
|
||||
@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
|
||||
if fi.reverseFieldInfo.fieldType == RelManyToMany {
|
||||
mmi = fi.reverseFieldInfo.relThroughModelInfo
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
jt, _ := d.add(names, mmi, fi, fi.null == false)
|
||||
@ -434,40 +436,36 @@ type dbBase struct {
|
||||
ins dbBaser
|
||||
}
|
||||
|
||||
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) {
|
||||
exist := true
|
||||
columns := make([]string, 0, len(mi.fields.pk))
|
||||
values := make([]interface{}, 0, len(mi.fields.pk))
|
||||
for _, fi := range mi.fields.pk {
|
||||
v := ind.Field(fi.fieldIndex)
|
||||
if fi.fieldType&IsIntegerField > 0 {
|
||||
vu := v.Int()
|
||||
if exist {
|
||||
exist = vu > 0
|
||||
}
|
||||
values = append(values, vu)
|
||||
} else {
|
||||
vu := v.String()
|
||||
if exist {
|
||||
exist = vu != ""
|
||||
}
|
||||
values = append(values, vu)
|
||||
}
|
||||
columns = append(columns, fi.column)
|
||||
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
|
||||
|
||||
fi := mi.fields.pk
|
||||
|
||||
v := ind.Field(fi.fieldIndex)
|
||||
if fi.fieldType&IsIntegerField > 0 {
|
||||
vu := v.Int()
|
||||
exist = vu > 0
|
||||
value = vu
|
||||
} else {
|
||||
vu := v.String()
|
||||
exist = vu != ""
|
||||
value = vu
|
||||
}
|
||||
return columns, values, exist
|
||||
|
||||
column = fi.column
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
|
||||
_, pkValues, _ := d.existPk(mi, ind)
|
||||
_, pkValue, _ := d.existPk(mi, ind)
|
||||
for _, column := range mi.fields.orders {
|
||||
fi := mi.fields.columns[column]
|
||||
if fi.dbcol == false || fi.auto && skipAuto {
|
||||
continue
|
||||
}
|
||||
var value interface{}
|
||||
if i, ok := mi.fields.pk.Exist(fi); ok {
|
||||
value = pkValues[i]
|
||||
if fi.pk {
|
||||
value = pkValue
|
||||
} else {
|
||||
field := ind.Field(fi.fieldIndex)
|
||||
if fi.isFielder {
|
||||
@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
|
||||
if field.IsNil() {
|
||||
value = nil
|
||||
} else {
|
||||
_, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field))
|
||||
if fok {
|
||||
value = fvalues[0]
|
||||
if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
|
||||
value = vu
|
||||
} else {
|
||||
value = nil
|
||||
}
|
||||
@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i
|
||||
}
|
||||
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
|
||||
pkNames, pkValues, ok := d.existPk(mi, ind)
|
||||
pkColumn, pkValue, ok := d.existPk(mi, ind)
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
}
|
||||
|
||||
pkColumns := strings.Join(pkNames, "` = ? AND `")
|
||||
|
||||
sels := strings.Join(mi.fields.dbcols, "`, `")
|
||||
colsNum := len(mi.fields.dbcols)
|
||||
|
||||
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns)
|
||||
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn)
|
||||
|
||||
refs := make([]interface{}, colsNum)
|
||||
for i, _ := range refs {
|
||||
@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
row := q.QueryRow(query, pkValues...)
|
||||
row := q.QueryRow(query, pkValue)
|
||||
if err := row.Scan(refs...); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
} else {
|
||||
elm := reflect.New(mi.addrField.Elem().Type())
|
||||
@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
||||
}
|
||||
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
|
||||
pkNames, pkValues, ok := d.existPk(mi, ind)
|
||||
pkName, pkValue, ok := d.existPk(mi, ind)
|
||||
if ok == false {
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
||||
return 0, err
|
||||
}
|
||||
|
||||
pkColumns := strings.Join(pkNames, "` = ? AND `")
|
||||
setColumns := strings.Join(setNames, "` = ?, `")
|
||||
|
||||
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns)
|
||||
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName)
|
||||
|
||||
setValues = append(setValues, pkValues...)
|
||||
setValues = append(setValues, pkValue)
|
||||
|
||||
if res, err := q.Exec(query, setValues...); err == nil {
|
||||
return res.RowsAffected()
|
||||
@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
||||
}
|
||||
|
||||
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
|
||||
names, values, ok := d.existPk(mi, ind)
|
||||
pkName, pkValue, ok := d.existPk(mi, ind)
|
||||
if ok == false {
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
|
||||
columns := strings.Join(names, "` = ? AND `")
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName)
|
||||
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
|
||||
|
||||
if res, err := q.Exec(query, values...); err == nil {
|
||||
if res, err := q.Exec(query, pkValue); err == nil {
|
||||
|
||||
num, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
@ -660,15 +655,13 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
||||
}
|
||||
|
||||
if num > 0 {
|
||||
if mi.fields.auto != nil {
|
||||
ind.Field(mi.fields.auto.fieldIndex).SetInt(0)
|
||||
if mi.fields.pk.auto {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
|
||||
}
|
||||
|
||||
if len(names) == 1 {
|
||||
err := d.deleteRels(q, mi, values)
|
||||
if err != nil {
|
||||
return num, err
|
||||
}
|
||||
err := d.deleteRels(q, mi, []interface{}{pkValue})
|
||||
if err != nil {
|
||||
return num, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -683,12 +676,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
columns := make([]string, 0, len(params))
|
||||
values := make([]interface{}, 0, len(params))
|
||||
for col, val := range params {
|
||||
column := snakeString(col)
|
||||
if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false {
|
||||
panic(fmt.Sprintf("wrong field/column name `%s`", column))
|
||||
if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
|
||||
panic(fmt.Sprintf("wrong field/column name `%s`", col))
|
||||
} else {
|
||||
columns = append(columns, fi.column)
|
||||
values = append(values, val)
|
||||
}
|
||||
columns = append(columns, column)
|
||||
values = append(values, val)
|
||||
}
|
||||
|
||||
if len(columns) == 0 {
|
||||
@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
|
||||
fi = fi.reverseFieldInfo
|
||||
switch fi.onDelete {
|
||||
case od_CASCADE:
|
||||
cond := NewCondition()
|
||||
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
|
||||
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
|
||||
_, err := d.DeleteBatch(q, nil, fi.mi, cond)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case od_SET_DEFAULT, od_SET_NULL:
|
||||
cond := NewCondition()
|
||||
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
|
||||
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
|
||||
params := Params{fi.column: nil}
|
||||
if fi.onDelete == od_SET_DEFAULT {
|
||||
params[fi.column] = fi.initial.String()
|
||||
@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
where, args := tables.getCondSql(cond, false)
|
||||
join := tables.getJoinSql()
|
||||
|
||||
colsNum := len(mi.fields.pk)
|
||||
cols := make([]string, colsNum)
|
||||
for i, fi := range mi.fields.pk {
|
||||
cols[i] = fi.column
|
||||
}
|
||||
colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
|
||||
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
|
||||
cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column)
|
||||
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where)
|
||||
|
||||
var rs *sql.Rows
|
||||
if r, err := q.Query(query, args...); err != nil {
|
||||
@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
rs = r
|
||||
}
|
||||
|
||||
refs := make([]interface{}, colsNum)
|
||||
for i, _ := range refs {
|
||||
var ref interface{}
|
||||
refs[i] = &ref
|
||||
}
|
||||
var ref interface{}
|
||||
|
||||
args = make([]interface{}, 0)
|
||||
cnt := 0
|
||||
for rs.Next() {
|
||||
if err := rs.Scan(refs...); err != nil {
|
||||
if err := rs.Scan(&ref); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, ref := range refs {
|
||||
args = append(args, reflect.ValueOf(ref).Elem().Interface())
|
||||
}
|
||||
args = append(args, reflect.ValueOf(ref).Interface())
|
||||
cnt++
|
||||
}
|
||||
|
||||
@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if colsNum > 1 {
|
||||
columns := strings.Join(cols, "` = ? AND `")
|
||||
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
|
||||
} else {
|
||||
var sql string
|
||||
sql, args = d.ins.GetOperatorSql(mi, "in", args)
|
||||
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
|
||||
}
|
||||
sql, args := d.ins.GetOperatorSql(mi, "in", args)
|
||||
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql)
|
||||
|
||||
if res, err := q.Exec(query, args...); err == nil {
|
||||
num, err := res.RowsAffected()
|
||||
@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if colsNum == 1 && num > 0 {
|
||||
if num > 0 {
|
||||
err := d.deleteRels(q, mi, args)
|
||||
if err != nil {
|
||||
return num, err
|
||||
@ -980,14 +954,12 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
|
||||
copy(params, args)
|
||||
sql := ""
|
||||
for i, arg := range args {
|
||||
if len(mi.fields.pk) == 1 {
|
||||
if md, ok := arg.(Modeler); ok {
|
||||
ind := reflect.Indirect(reflect.ValueOf(md))
|
||||
if _, values, exist := d.existPk(mi, ind); exist {
|
||||
arg = values[0]
|
||||
} else {
|
||||
panic(fmt.Sprintf("`%s` need a valid args value", operator))
|
||||
}
|
||||
if md, ok := arg.(Modeler); ok {
|
||||
ind := reflect.Indirect(reflect.ValueOf(md))
|
||||
if _, vu, exist := d.existPk(mi, ind); exist {
|
||||
arg = vu
|
||||
} else {
|
||||
panic(fmt.Sprintf("`%s` need a valid args value", operator))
|
||||
}
|
||||
}
|
||||
params[i] = arg
|
||||
@ -1175,7 +1147,7 @@ setValue:
|
||||
value = v
|
||||
}
|
||||
case fieldType&IsRelField > 0:
|
||||
fieldType = fi.relModelInfo.fields.pk[0].fieldType
|
||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||
goto setValue
|
||||
}
|
||||
|
||||
@ -1236,12 +1208,12 @@ setValue:
|
||||
}
|
||||
case fieldType&IsRelField > 0:
|
||||
if value != nil {
|
||||
fieldType = fi.relModelInfo.fields.pk[0].fieldType
|
||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
|
||||
md := mf.Interface().(Modeler)
|
||||
md.Init(md)
|
||||
field.Set(mf)
|
||||
f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
|
||||
f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
|
||||
field = &f
|
||||
goto setValue
|
||||
}
|
||||
|
Reference in New Issue
Block a user