1
0
mirror of https://github.com/astaxie/beego.git synced 2025-07-03 01:10:20 +00:00

orm.Read support specify condition fields, orm.Update and QuerySeter All/One support omit fields.

This commit is contained in:
slene
2013-09-12 19:04:39 +08:00
parent 55fe3ba52f
commit 3745bb7279
6 changed files with 346 additions and 295 deletions

105
orm/db.go
View File

@ -49,10 +49,17 @@ type dbBase struct {
ins dbBaser
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
_, pkValue, _ := getExistPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
for _, column := range cols {
var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil {
column = fi.column
} else {
panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
}
if fi.dbcol == false || fi.auto && skipAuto {
continue
}
@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
}
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true, tz)
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
if err != nil {
return 0, err
}
@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
var whereCols []string
var args []interface{}
// if specify cols length > 0, then use it for where condition.
if len(cols) > 0 {
var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
if err != nil {
return err
}
} else {
// default use pk value as where condtion.
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
}
whereCols = append(whereCols, pkColumn)
args = append(args, pkValue)
}
Q := d.ins.TableQuote()
@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
sels := strings.Join(mi.fields.dbcols, sep)
colsNum := len(mi.fields.dbcols)
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
wheres := strings.Join(whereCols, sep)
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
refs := make([]interface{}, colsNum)
for i, _ := range refs {
@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, pkValue)
row := q.QueryRow(query, args...)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
}
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true, tz)
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
if err != nil {
return 0, err
}
@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
// if specify cols length is zero, then commit all columns.
if len(cols) == 0 {
cols = mi.fields.dbcols
}
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
if err != nil {
return 0, err
}
@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
Q := d.ins.TableQuote()
var tCols []string
if len(cols) > 0 {
hasRel := len(qs.related) > 0 || qs.relDepth > 0
tCols = make([]string, 0, len(cols))
var maps map[string]bool
if hasRel {
maps = make(map[string]bool)
}
for _, col := range cols {
if fi, ok := mi.fields.GetByAny(col); ok {
tCols = append(tCols, fi.column)
if hasRel {
maps[fi.column] = true
}
} else {
panic(fmt.Sprintf("wrong field/column name `%s`", col))
}
}
if hasRel {
for _, fi := range mi.fields.fieldsDB {
if fi.fieldType&IsRelField > 0 {
if maps[fi.column] == false {
tCols = append(tCols, fi.column)
}
}
}
}
} else {
tCols = mi.fields.dbcols
}
colsNum := len(tCols)
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
limit := tables.getLimitSql(mi, offset, rlimit)
join := tables.getJoinSql()
colsNum := len(mi.fields.dbcols)
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
for _, tbl := range tables.tables {
if tbl.sel {
colsNum += len(tbl.mi.fields.dbcols)
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
}
}
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
d.ins.ReplaceMarks(&query)
@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
cacheM := make(map[string]*modelInfo)
trefs := refs
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
trefs = refs[len(mi.fields.dbcols):]
d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
trefs = refs[len(tCols):]
for _, tbl := range tables.tables {
if tbl.sel {