mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 08:20:55 +00:00
orm.Read support specify condition fields, orm.Update and QuerySeter All/One support omit fields.
This commit is contained in:
parent
55fe3ba52f
commit
3745bb7279
105
orm/db.go
105
orm/db.go
@ -49,10 +49,17 @@ type dbBase struct {
|
||||
ins dbBaser
|
||||
}
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
var _ dbBaser = new(dbBase)
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
_, pkValue, _ := getExistPk(mi, ind)
|
||||
for _, column := range mi.fields.orders {
|
||||
fi := mi.fields.columns[column]
|
||||
for _, column := range cols {
|
||||
var fi *fieldInfo
|
||||
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||
column = fi.column
|
||||
} else {
|
||||
panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
|
||||
}
|
||||
if fi.dbcol == false || fi.auto && skipAuto {
|
||||
continue
|
||||
}
|
||||
@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
_, values, err := d.collectValues(mi, ind, true, true, tz)
|
||||
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
|
||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
||||
var whereCols []string
|
||||
var args []interface{}
|
||||
|
||||
// if specify cols length > 0, then use it for where condition.
|
||||
if len(cols) > 0 {
|
||||
var err error
|
||||
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// default use pk value as where condtion.
|
||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
}
|
||||
whereCols = append(whereCols, pkColumn)
|
||||
args = append(args, pkValue)
|
||||
}
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
sels := strings.Join(mi.fields.dbcols, sep)
|
||||
colsNum := len(mi.fields.dbcols)
|
||||
|
||||
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
|
||||
sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
|
||||
wheres := strings.Join(whereCols, sep)
|
||||
|
||||
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
|
||||
|
||||
refs := make([]interface{}, colsNum)
|
||||
for i, _ := range refs {
|
||||
@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
row := q.QueryRow(query, pkValue)
|
||||
row := q.QueryRow(query, args...)
|
||||
if err := row.Scan(refs...); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
}
|
||||
|
||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
names, values, err := d.collectValues(mi, ind, true, true, tz)
|
||||
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
|
||||
|
||||
// if specify cols length is zero, then commit all columns.
|
||||
if len(cols) == 0 {
|
||||
cols = mi.fields.dbcols
|
||||
}
|
||||
|
||||
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
|
||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||
|
||||
val := reflect.ValueOf(container)
|
||||
ind := reflect.Indirect(val)
|
||||
@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
var tCols []string
|
||||
if len(cols) > 0 {
|
||||
hasRel := len(qs.related) > 0 || qs.relDepth > 0
|
||||
tCols = make([]string, 0, len(cols))
|
||||
var maps map[string]bool
|
||||
if hasRel {
|
||||
maps = make(map[string]bool)
|
||||
}
|
||||
for _, col := range cols {
|
||||
if fi, ok := mi.fields.GetByAny(col); ok {
|
||||
tCols = append(tCols, fi.column)
|
||||
if hasRel {
|
||||
maps[fi.column] = true
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Sprintf("wrong field/column name `%s`", col))
|
||||
}
|
||||
}
|
||||
if hasRel {
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
if fi.fieldType&IsRelField > 0 {
|
||||
if maps[fi.column] == false {
|
||||
tCols = append(tCols, fi.column)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tCols = mi.fields.dbcols
|
||||
}
|
||||
|
||||
colsNum := len(tCols)
|
||||
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
||||
sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
|
||||
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
|
||||
@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
limit := tables.getLimitSql(mi, offset, rlimit)
|
||||
join := tables.getJoinSql()
|
||||
|
||||
colsNum := len(mi.fields.dbcols)
|
||||
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
||||
cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
|
||||
for _, tbl := range tables.tables {
|
||||
if tbl.sel {
|
||||
colsNum += len(tbl.mi.fields.dbcols)
|
||||
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
|
||||
cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
||||
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
|
||||
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
cacheM := make(map[string]*modelInfo)
|
||||
trefs := refs
|
||||
|
||||
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
|
||||
trefs = refs[len(mi.fields.dbcols):]
|
||||
d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
|
||||
trefs = refs[len(tCols):]
|
||||
|
||||
for _, tbl := range tables.tables {
|
||||
if tbl.sel {
|
||||
|
@ -53,9 +53,9 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
||||
panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
||||
}
|
||||
|
||||
func (o *orm) Read(md interface{}) error {
|
||||
func (o *orm) Read(md interface{}, cols ...string) error {
|
||||
mi, ind := o.getMiInd(md)
|
||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ)
|
||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -80,9 +80,9 @@ func (o *orm) Insert(md interface{}) (int64, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (o *orm) Update(md interface{}) (int64, error) {
|
||||
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ)
|
||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return num, err
|
||||
}
|
||||
|
@ -105,12 +105,12 @@ func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||
return newInsertSet(o.orm, o.mi)
|
||||
}
|
||||
|
||||
func (o *querySet) All(container interface{}) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
|
||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
}
|
||||
|
||||
func (o *querySet) One(container interface{}) error {
|
||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
|
||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
496
orm/orm_test.go
496
orm/orm_test.go
File diff suppressed because it is too large
Load Diff
14
orm/types.go
14
orm/types.go
@ -20,9 +20,9 @@ type Fielder interface {
|
||||
}
|
||||
|
||||
type Ormer interface {
|
||||
Read(interface{}) error
|
||||
Read(interface{}, ...string) error
|
||||
Insert(interface{}) (int64, error)
|
||||
Update(interface{}) (int64, error)
|
||||
Update(interface{}, ...string) (int64, error)
|
||||
Delete(interface{}) (int64, error)
|
||||
M2mAdd(interface{}, string, ...interface{}) (int64, error)
|
||||
M2mDel(interface{}, string, ...interface{}) (int64, error)
|
||||
@ -53,8 +53,8 @@ type QuerySeter interface {
|
||||
Update(Params) (int64, error)
|
||||
Delete() (int64, error)
|
||||
PrepareInsert() (Inserter, error)
|
||||
All(interface{}) (int64, error)
|
||||
One(interface{}) error
|
||||
All(interface{}, ...string) (int64, error)
|
||||
One(interface{}, ...string) error
|
||||
Values(*[]Params, ...string) (int64, error)
|
||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||
ValuesFlat(*ParamsList, string) (int64, error)
|
||||
@ -111,12 +111,12 @@ type txEnder interface {
|
||||
}
|
||||
|
||||
type dbBaser interface {
|
||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error
|
||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error)
|
||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
||||
SupportUpdateJoin() bool
|
||||
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
||||
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||
|
10
orm/utils.go
10
orm/utils.go
@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) {
|
||||
return strconv.ParseFloat(f.String(), 64)
|
||||
}
|
||||
|
||||
func (f StrTo) Int() (int, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int(v), err
|
||||
}
|
||||
|
||||
func (f StrTo) Int8() (int8, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 8)
|
||||
return int8(v), err
|
||||
@ -58,6 +63,11 @@ func (f StrTo) Int64() (int64, error) {
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
func (f StrTo) Uint() (uint, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint(v), err
|
||||
}
|
||||
|
||||
func (f StrTo) Uint8() (uint8, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 8)
|
||||
return uint8(v), err
|
||||
|
Loading…
Reference in New Issue
Block a user