1
0
mirror of https://github.com/astaxie/beego.git synced 2024-12-22 08:40:50 +00:00

orm.Read support specify condition fields, orm.Update and QuerySeter All/One support omit fields.

This commit is contained in:
slene 2013-09-12 19:04:39 +08:00
parent 55fe3ba52f
commit 3745bb7279
6 changed files with 346 additions and 295 deletions

105
orm/db.go
View File

@ -49,10 +49,17 @@ type dbBase struct {
ins dbBaser
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
_, pkValue, _ := getExistPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
for _, column := range cols {
var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil {
column = fi.column
} else {
panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
}
if fi.dbcol == false || fi.auto && skipAuto {
continue
}
@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
}
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true, tz)
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
if err != nil {
return 0, err
}
@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
var whereCols []string
var args []interface{}
// if specify cols length > 0, then use it for where condition.
if len(cols) > 0 {
var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
if err != nil {
return err
}
} else {
// default use pk value as where condtion.
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
}
whereCols = append(whereCols, pkColumn)
args = append(args, pkValue)
}
Q := d.ins.TableQuote()
@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
sels := strings.Join(mi.fields.dbcols, sep)
colsNum := len(mi.fields.dbcols)
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
wheres := strings.Join(whereCols, sep)
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
refs := make([]interface{}, colsNum)
for i, _ := range refs {
@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, pkValue)
row := q.QueryRow(query, args...)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
}
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true, tz)
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
if err != nil {
return 0, err
}
@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
// if specify cols length is zero, then commit all columns.
if len(cols) == 0 {
cols = mi.fields.dbcols
}
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
if err != nil {
return 0, err
}
@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
Q := d.ins.TableQuote()
var tCols []string
if len(cols) > 0 {
hasRel := len(qs.related) > 0 || qs.relDepth > 0
tCols = make([]string, 0, len(cols))
var maps map[string]bool
if hasRel {
maps = make(map[string]bool)
}
for _, col := range cols {
if fi, ok := mi.fields.GetByAny(col); ok {
tCols = append(tCols, fi.column)
if hasRel {
maps[fi.column] = true
}
} else {
panic(fmt.Sprintf("wrong field/column name `%s`", col))
}
}
if hasRel {
for _, fi := range mi.fields.fieldsDB {
if fi.fieldType&IsRelField > 0 {
if maps[fi.column] == false {
tCols = append(tCols, fi.column)
}
}
}
}
} else {
tCols = mi.fields.dbcols
}
colsNum := len(tCols)
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
limit := tables.getLimitSql(mi, offset, rlimit)
join := tables.getJoinSql()
colsNum := len(mi.fields.dbcols)
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
for _, tbl := range tables.tables {
if tbl.sel {
colsNum += len(tbl.mi.fields.dbcols)
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
}
}
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
d.ins.ReplaceMarks(&query)
@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
cacheM := make(map[string]*modelInfo)
trefs := refs
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
trefs = refs[len(mi.fields.dbcols):]
d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
trefs = refs[len(tCols):]
for _, tbl := range tables.tables {
if tbl.sel {

View File

@ -53,9 +53,9 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
}
func (o *orm) Read(md interface{}) error {
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return err
}
@ -80,9 +80,9 @@ func (o *orm) Insert(md interface{}) (int64, error) {
return id, nil
}
func (o *orm) Update(md interface{}) (int64, error) {
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
}

View File

@ -105,12 +105,12 @@ func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
}
func (o *querySet) All(container interface{}) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
func (o *querySet) One(container interface{}) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
func (o *querySet) One(container interface{}, cols ...string) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
return err
}

File diff suppressed because it is too large Load Diff

View File

@ -20,9 +20,9 @@ type Fielder interface {
}
type Ormer interface {
Read(interface{}) error
Read(interface{}, ...string) error
Insert(interface{}) (int64, error)
Update(interface{}) (int64, error)
Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error)
M2mAdd(interface{}, string, ...interface{}) (int64, error)
M2mDel(interface{}, string, ...interface{}) (int64, error)
@ -53,8 +53,8 @@ type QuerySeter interface {
Update(Params) (int64, error)
Delete() (int64, error)
PrepareInsert() (Inserter, error)
All(interface{}) (int64, error)
One(interface{}) error
All(interface{}, ...string) (int64, error)
One(interface{}, ...string) error
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error)
@ -111,12 +111,12 @@ type txEnder interface {
}
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)

View File

@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
@ -58,6 +63,11 @@ func (f StrTo) Int64() (int64, error) {
return int64(v), err
}
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err