mirror of
https://github.com/astaxie/beego.git
synced 2025-07-02 20:40:17 +00:00
orm.Read support specify condition fields, orm.Update and QuerySeter All/One support omit fields.
This commit is contained in:
105
orm/db.go
105
orm/db.go
@ -49,10 +49,17 @@ type dbBase struct {
|
||||
ins dbBaser
|
||||
}
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
var _ dbBaser = new(dbBase)
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
_, pkValue, _ := getExistPk(mi, ind)
|
||||
for _, column := range mi.fields.orders {
|
||||
fi := mi.fields.columns[column]
|
||||
for _, column := range cols {
|
||||
var fi *fieldInfo
|
||||
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||
column = fi.column
|
||||
} else {
|
||||
panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
|
||||
}
|
||||
if fi.dbcol == false || fi.auto && skipAuto {
|
||||
continue
|
||||
}
|
||||
@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
_, values, err := d.collectValues(mi, ind, true, true, tz)
|
||||
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
|
||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
||||
var whereCols []string
|
||||
var args []interface{}
|
||||
|
||||
// if specify cols length > 0, then use it for where condition.
|
||||
if len(cols) > 0 {
|
||||
var err error
|
||||
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// default use pk value as where condtion.
|
||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
}
|
||||
whereCols = append(whereCols, pkColumn)
|
||||
args = append(args, pkValue)
|
||||
}
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
sels := strings.Join(mi.fields.dbcols, sep)
|
||||
colsNum := len(mi.fields.dbcols)
|
||||
|
||||
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
|
||||
sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
|
||||
wheres := strings.Join(whereCols, sep)
|
||||
|
||||
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
|
||||
|
||||
refs := make([]interface{}, colsNum)
|
||||
for i, _ := range refs {
|
||||
@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
row := q.QueryRow(query, pkValue)
|
||||
row := q.QueryRow(query, args...)
|
||||
if err := row.Scan(refs...); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
}
|
||||
|
||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
names, values, err := d.collectValues(mi, ind, true, true, tz)
|
||||
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
|
||||
|
||||
// if specify cols length is zero, then commit all columns.
|
||||
if len(cols) == 0 {
|
||||
cols = mi.fields.dbcols
|
||||
}
|
||||
|
||||
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
|
||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||
|
||||
val := reflect.ValueOf(container)
|
||||
ind := reflect.Indirect(val)
|
||||
@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
var tCols []string
|
||||
if len(cols) > 0 {
|
||||
hasRel := len(qs.related) > 0 || qs.relDepth > 0
|
||||
tCols = make([]string, 0, len(cols))
|
||||
var maps map[string]bool
|
||||
if hasRel {
|
||||
maps = make(map[string]bool)
|
||||
}
|
||||
for _, col := range cols {
|
||||
if fi, ok := mi.fields.GetByAny(col); ok {
|
||||
tCols = append(tCols, fi.column)
|
||||
if hasRel {
|
||||
maps[fi.column] = true
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Sprintf("wrong field/column name `%s`", col))
|
||||
}
|
||||
}
|
||||
if hasRel {
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
if fi.fieldType&IsRelField > 0 {
|
||||
if maps[fi.column] == false {
|
||||
tCols = append(tCols, fi.column)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tCols = mi.fields.dbcols
|
||||
}
|
||||
|
||||
colsNum := len(tCols)
|
||||
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
||||
sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
|
||||
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
|
||||
@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
limit := tables.getLimitSql(mi, offset, rlimit)
|
||||
join := tables.getJoinSql()
|
||||
|
||||
colsNum := len(mi.fields.dbcols)
|
||||
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
||||
cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
|
||||
for _, tbl := range tables.tables {
|
||||
if tbl.sel {
|
||||
colsNum += len(tbl.mi.fields.dbcols)
|
||||
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
|
||||
cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
||||
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
|
||||
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
cacheM := make(map[string]*modelInfo)
|
||||
trefs := refs
|
||||
|
||||
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
|
||||
trefs = refs[len(mi.fields.dbcols):]
|
||||
d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
|
||||
trefs = refs[len(tCols):]
|
||||
|
||||
for _, tbl := range tables.tables {
|
||||
if tbl.sel {
|
||||
|
Reference in New Issue
Block a user