mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 13:50:54 +00:00
orm.Read support specify condition fields, orm.Update and QuerySeter All/One support omit fields.
This commit is contained in:
parent
55fe3ba52f
commit
3745bb7279
99
orm/db.go
99
orm/db.go
@ -49,10 +49,17 @@ type dbBase struct {
|
|||||||
ins dbBaser
|
ins dbBaser
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
var _ dbBaser = new(dbBase)
|
||||||
|
|
||||||
|
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||||
_, pkValue, _ := getExistPk(mi, ind)
|
_, pkValue, _ := getExistPk(mi, ind)
|
||||||
for _, column := range mi.fields.orders {
|
for _, column := range cols {
|
||||||
fi := mi.fields.columns[column]
|
var fi *fieldInfo
|
||||||
|
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||||
|
column = fi.column
|
||||||
|
} else {
|
||||||
|
panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
|
||||||
|
}
|
||||||
if fi.dbcol == false || fi.auto && skipAuto {
|
if fi.dbcol == false || fi.auto && skipAuto {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
_, values, err := d.collectValues(mi, ind, true, true, tz)
|
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -179,11 +186,26 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error {
|
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
||||||
|
var whereCols []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
// if specify cols length > 0, then use it for where condition.
|
||||||
|
if len(cols) > 0 {
|
||||||
|
var err error
|
||||||
|
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// default use pk value as where condtion.
|
||||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if ok == false {
|
||||||
return ErrMissPK
|
return ErrMissPK
|
||||||
}
|
}
|
||||||
|
whereCols = append(whereCols, pkColumn)
|
||||||
|
args = append(args, pkValue)
|
||||||
|
}
|
||||||
|
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
sels := strings.Join(mi.fields.dbcols, sep)
|
sels := strings.Join(mi.fields.dbcols, sep)
|
||||||
colsNum := len(mi.fields.dbcols)
|
colsNum := len(mi.fields.dbcols)
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
|
sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
|
||||||
|
wheres := strings.Join(whereCols, sep)
|
||||||
|
|
||||||
|
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
|
||||||
|
|
||||||
refs := make([]interface{}, colsNum)
|
refs := make([]interface{}, colsNum)
|
||||||
for i, _ := range refs {
|
for i, _ := range refs {
|
||||||
@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
row := q.QueryRow(query, pkValue)
|
row := q.QueryRow(query, args...)
|
||||||
if err := row.Scan(refs...); err != nil {
|
if err := row.Scan(refs...); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return ErrNoRows
|
return ErrNoRows
|
||||||
@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
names, values, err := d.collectValues(mi, ind, true, true, tz)
|
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if ok == false {
|
||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
}
|
}
|
||||||
setNames, setValues, err := d.collectValues(mi, ind, true, false, tz)
|
|
||||||
|
// if specify cols length is zero, then commit all columns.
|
||||||
|
if len(cols) == 0 {
|
||||||
|
cols = mi.fields.dbcols
|
||||||
|
}
|
||||||
|
|
||||||
|
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) {
|
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||||
|
|
||||||
val := reflect.ValueOf(container)
|
val := reflect.ValueOf(container)
|
||||||
ind := reflect.Indirect(val)
|
ind := reflect.Indirect(val)
|
||||||
@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
|
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
var tCols []string
|
||||||
|
if len(cols) > 0 {
|
||||||
|
hasRel := len(qs.related) > 0 || qs.relDepth > 0
|
||||||
|
tCols = make([]string, 0, len(cols))
|
||||||
|
var maps map[string]bool
|
||||||
|
if hasRel {
|
||||||
|
maps = make(map[string]bool)
|
||||||
|
}
|
||||||
|
for _, col := range cols {
|
||||||
|
if fi, ok := mi.fields.GetByAny(col); ok {
|
||||||
|
tCols = append(tCols, fi.column)
|
||||||
|
if hasRel {
|
||||||
|
maps[fi.column] = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic(fmt.Sprintf("wrong field/column name `%s`", col))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasRel {
|
||||||
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
|
if fi.fieldType&IsRelField > 0 {
|
||||||
|
if maps[fi.column] == false {
|
||||||
|
tCols = append(tCols, fi.column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tCols = mi.fields.dbcols
|
||||||
|
}
|
||||||
|
|
||||||
|
colsNum := len(tCols)
|
||||||
|
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
||||||
|
sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
|
||||||
|
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
|
|
||||||
@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
limit := tables.getLimitSql(mi, offset, rlimit)
|
limit := tables.getLimitSql(mi, offset, rlimit)
|
||||||
join := tables.getJoinSql()
|
join := tables.getJoinSql()
|
||||||
|
|
||||||
colsNum := len(mi.fields.dbcols)
|
|
||||||
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
|
||||||
cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
|
|
||||||
for _, tbl := range tables.tables {
|
for _, tbl := range tables.tables {
|
||||||
if tbl.sel {
|
if tbl.sel {
|
||||||
colsNum += len(tbl.mi.fields.dbcols)
|
colsNum += len(tbl.mi.fields.dbcols)
|
||||||
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
|
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
|
||||||
cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
|
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
cacheM := make(map[string]*modelInfo)
|
cacheM := make(map[string]*modelInfo)
|
||||||
trefs := refs
|
trefs := refs
|
||||||
|
|
||||||
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz)
|
d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
|
||||||
trefs = refs[len(mi.fields.dbcols):]
|
trefs = refs[len(tCols):]
|
||||||
|
|
||||||
for _, tbl := range tables.tables {
|
for _, tbl := range tables.tables {
|
||||||
if tbl.sel {
|
if tbl.sel {
|
||||||
|
@ -53,9 +53,9 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
|||||||
panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *orm) Read(md interface{}) error {
|
func (o *orm) Read(md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md)
|
||||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ)
|
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -80,9 +80,9 @@ func (o *orm) Insert(md interface{}) (int64, error) {
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *orm) Update(md interface{}) (int64, error) {
|
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md)
|
||||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ)
|
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
|
@ -105,12 +105,12 @@ func (o *querySet) PrepareInsert() (Inserter, error) {
|
|||||||
return newInsertSet(o.orm, o.mi)
|
return newInsertSet(o.orm, o.mi)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *querySet) All(container interface{}) (int64, error) {
|
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *querySet) One(container interface{}) error {
|
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
|
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
456
orm/orm_test.go
456
orm/orm_test.go
File diff suppressed because it is too large
Load Diff
14
orm/types.go
14
orm/types.go
@ -20,9 +20,9 @@ type Fielder interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Ormer interface {
|
type Ormer interface {
|
||||||
Read(interface{}) error
|
Read(interface{}, ...string) error
|
||||||
Insert(interface{}) (int64, error)
|
Insert(interface{}) (int64, error)
|
||||||
Update(interface{}) (int64, error)
|
Update(interface{}, ...string) (int64, error)
|
||||||
Delete(interface{}) (int64, error)
|
Delete(interface{}) (int64, error)
|
||||||
M2mAdd(interface{}, string, ...interface{}) (int64, error)
|
M2mAdd(interface{}, string, ...interface{}) (int64, error)
|
||||||
M2mDel(interface{}, string, ...interface{}) (int64, error)
|
M2mDel(interface{}, string, ...interface{}) (int64, error)
|
||||||
@ -53,8 +53,8 @@ type QuerySeter interface {
|
|||||||
Update(Params) (int64, error)
|
Update(Params) (int64, error)
|
||||||
Delete() (int64, error)
|
Delete() (int64, error)
|
||||||
PrepareInsert() (Inserter, error)
|
PrepareInsert() (Inserter, error)
|
||||||
All(interface{}) (int64, error)
|
All(interface{}, ...string) (int64, error)
|
||||||
One(interface{}) error
|
One(interface{}, ...string) error
|
||||||
Values(*[]Params, ...string) (int64, error)
|
Values(*[]Params, ...string) (int64, error)
|
||||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||||
ValuesFlat(*ParamsList, string) (int64, error)
|
ValuesFlat(*ParamsList, string) (int64, error)
|
||||||
@ -111,12 +111,12 @@ type txEnder interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type dbBaser interface {
|
type dbBaser interface {
|
||||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error
|
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
||||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error)
|
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
||||||
SupportUpdateJoin() bool
|
SupportUpdateJoin() bool
|
||||||
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
||||||
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||||
|
10
orm/utils.go
10
orm/utils.go
@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) {
|
|||||||
return strconv.ParseFloat(f.String(), 64)
|
return strconv.ParseFloat(f.String(), 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f StrTo) Int() (int, error) {
|
||||||
|
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||||
|
return int(v), err
|
||||||
|
}
|
||||||
|
|
||||||
func (f StrTo) Int8() (int8, error) {
|
func (f StrTo) Int8() (int8, error) {
|
||||||
v, err := strconv.ParseInt(f.String(), 10, 8)
|
v, err := strconv.ParseInt(f.String(), 10, 8)
|
||||||
return int8(v), err
|
return int8(v), err
|
||||||
@ -58,6 +63,11 @@ func (f StrTo) Int64() (int64, error) {
|
|||||||
return int64(v), err
|
return int64(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f StrTo) Uint() (uint, error) {
|
||||||
|
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||||
|
return uint(v), err
|
||||||
|
}
|
||||||
|
|
||||||
func (f StrTo) Uint8() (uint8, error) {
|
func (f StrTo) Uint8() (uint8, error) {
|
||||||
v, err := strconv.ParseUint(f.String(), 10, 8)
|
v, err := strconv.ParseUint(f.String(), 10, 8)
|
||||||
return uint8(v), err
|
return uint8(v), err
|
||||||
|
Loading…
Reference in New Issue
Block a user