1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 21:00:57 +00:00
Beego/orm/db.go
2013-08-01 15:51:53 +08:00

1409 lines
31 KiB
Go

package orm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
)
const (
format_Date = "2006-01-02"
format_DateTime = "2006-01-02 15:04:05"
)
var (
ErrMissPK = errors.New("missed pk value")
)
var (
operators = map[string]bool{
"exact": true,
"iexact": true,
"contains": true,
"icontains": true,
// "regex": true,
// "iregex": true,
"gt": true,
"gte": true,
"lt": true,
"lte": true,
"startswith": true,
"endswith": true,
"istartswith": true,
"iendswith": true,
"in": true,
// "range": true,
// "year": true,
// "month": true,
// "day": true,
// "week_day": true,
"isnull": true,
// "search": true,
}
operatorsSQL = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
"contains": "LIKE BINARY ?",
"icontains": "LIKE ?",
// "regex": "REGEXP BINARY ?",
// "iregex": "REGEXP ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
"iendswith": "LIKE ?",
}
)
type dbTable struct {
id int
index string
name string
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
jtl *dbTable
}
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
}
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
j.mi = mi
j.fi = fi
j.inner = inner
} else {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
}
return t.tablesM[name]
}
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
return jt, true
}
return t.tablesM[name], false
}
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
} else {
prefix = prefix + ExprSep + fi.name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
copy(related, rels)
relDepth := depth
if relsNum != 0 {
relDepth = 0
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
for i, s := range related {
var (
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cansel = true
jtl *dbTable
)
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
jt := t.set(names, mmi, fi, fi.null == false)
jt.jtl = jtl
if fi.reverse {
cansel = false
}
if cansel {
jt.sel = depth > 0
if i < relsNum {
jt.sel = true
}
}
jtl = jt
} else {
panic(fmt.Sprintf("unknown model/table name `%s`", ex))
}
}
}
}
func (t *dbTables) getJoinSql() (join string) {
for _, jt := range t.tables {
if jt.inner {
join += "INNER JOIN "
} else {
join += "LEFT OUTER JOIN "
}
var (
table string
t1, t2 string
c1, c2 string
)
t1 = "T0"
if jt.jtl != nil {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk[0].column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk[0].column
if jt.fi.reverse {
c1 = jt.mi.fields.pk[0].column
c2 = jt.fi.reverseFieldInfo.column
}
}
join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2,
t2, c2, t1, c1)
}
return
}
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
var (
ffi *fieldInfo
jtl *dbTable
mmi = mi
)
num := len(exprs) - 1
names := make([]string, 0)
for i, ex := range exprs {
exist := false
check:
fi, ok := mmi.fields.GetByAny(ex)
if ok {
if num != i {
names = append(names, fi.name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo
}
}
jt, _ := d.add(names, mmi, fi, fi.null == false)
jt.jtl = jtl
jtl = jt
if fi.rel && fi.fieldType == RelManyToMany {
ex = fi.relModelInfo.name
goto check
}
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
ex = fi.reverseFieldInfo.mi.name
goto check
}
exist = true
} else {
if ffi == nil {
index = "T0"
} else {
index = jtl.index
}
column = fi.column
info = fi
if jtl != nil {
name = jtl.name + ExprSep + fi.name
} else {
name = fi.name
}
switch fi.fieldType {
case RelManyToMany, RelReverseMany:
default:
exist = true
}
}
ffi = fi
}
if exist == false {
index = ""
column = ""
name = ""
success = false
return
}
}
success = index != "" && column != ""
return
}
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
mi := d.mi
// outFor:
for i, p := range cond.params {
if i > 0 {
if p.isOr {
where += "OR "
} else {
where += "AND "
}
}
if p.isNot {
where += "NOT "
}
if p.isCond {
w, ps := d.getCondSql(p.cond, true)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
where += w
params = append(params, ps...)
} else {
exprs := p.exprs
num := len(exprs) - 1
operator := ""
if operators[exprs[num]] {
operator = exprs[num]
exprs = exprs[:num]
}
index, column, _, _, suc := d.parseExprs(mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
if operator == "" {
operator = "exact"
}
operSql, args := d.base.GetOperatorSql(mi, operator, p.args)
where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql)
params = append(params, args...)
}
}
if sub == false && where != "" {
where = "WHERE " + where
}
return
}
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
}
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
index, column, _, _, suc := d.parseExprs(d.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) {
if limit == 0 {
limit = DefaultRowsLimit
}
if limit < 0 {
// no limit
if offset > 0 {
limits = fmt.Sprintf("LIMIT 18446744073709551615 OFFSET %d", offset)
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
return
}
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi
tables.base = base
return tables
}
type dbBase struct {
ins dbBaser
}
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) {
exist := true
columns := make([]string, 0, len(mi.fields.pk))
values := make([]interface{}, 0, len(mi.fields.pk))
for _, fi := range mi.fields.pk {
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
if exist {
exist = vu > 0
}
values = append(values, vu)
} else {
vu := v.String()
if exist {
exist = vu != ""
}
values = append(values, vu)
}
columns = append(columns, fi.column)
}
return columns, values, exist
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
_, pkValues, _ := d.existPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
if fi.dbcol == false || fi.auto && skipAuto {
continue
}
var value interface{}
if i, ok := mi.fields.pk.Exist(fi); ok {
value = pkValues[i]
} else {
field := ind.Field(fi.fieldIndex)
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
value = f.RawValue()
} else {
switch fi.fieldType {
case TypeBooleanField:
value = field.Bool()
case TypeCharField, TypeTextField:
value = field.String()
case TypeFloatField, TypeDecimalField:
value = field.Float()
case TypeDateField, TypeDateTimeField:
value = field.Interface()
default:
switch {
case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint()
case fi.fieldType&IsIntegerField > 0:
value = field.Int()
case fi.fieldType&IsRelField > 0:
if field.IsNil() {
value = nil
} else {
_, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field))
if fok {
value = fvalues[0]
} else {
value = nil
}
}
if fi.null == false && value == nil {
return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
}
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
tnow := time.Now()
if fi.fieldType == TypeDateField {
value = timeFormat(tnow, format_Date)
} else {
value = timeFormat(tnow, format_DateTime)
}
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow)
} else {
field.Set(reflect.ValueOf(tnow))
}
}
}
}
columns = append(columns, column)
values = append(values, value)
}
return
}
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (*sql.Stmt, error) {
dbcols := make([]string, 0, len(mi.fields.dbcols))
marks := make([]string, 0, len(mi.fields.dbcols))
for _, fi := range mi.fields.fieldsDB {
if fi.auto == false {
dbcols = append(dbcols, fi.column)
marks = append(marks, "?")
}
}
qmarks := strings.Join(marks, ", ")
columns := strings.Join(dbcols, "`,`")
query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
return q.Prepare(query)
}
func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true)
if err != nil {
return 0, err
}
if res, err := stmt.Exec(values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
pkNames, pkValues, ok := d.existPk(mi, ind)
if ok == false {
return ErrMissPK
}
pkColumns := strings.Join(pkNames, "` = ? AND `")
sels := strings.Join(mi.fields.dbcols, "`, `")
colsNum := len(mi.fields.dbcols)
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns)
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref interface{}
refs[i] = &ref
}
row := q.QueryRow(query, pkValues...)
if err := row.Scan(refs...); err != nil {
return err
} else {
elm := reflect.New(mi.addrField.Elem().Type())
md := elm.Interface().(Modeler)
md.Init(md)
mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs)
ind.Set(mind)
}
return nil
}
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true)
if err != nil {
return 0, err
}
marks := make([]string, len(names))
for i, _ := range marks {
marks[i] = "?"
}
qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, "`,`")
query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
if res, err := q.Exec(query, values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkNames, pkValues, ok := d.existPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
setNames, setValues, err := d.collectValues(mi, ind, true, false)
if err != nil {
return 0, err
}
pkColumns := strings.Join(pkNames, "` = ? AND `")
setColumns := strings.Join(setNames, "` = ?, `")
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns)
setValues = append(setValues, pkValues...)
if res, err := q.Exec(query, setValues...); err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, ok := d.existPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
columns := strings.Join(names, "` = ? AND `")
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
if res, err := q.Exec(query, values...); err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
if mi.fields.auto != nil {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0)
}
if len(names) == 1 {
err := d.deleteRels(q, mi, values)
if err != nil {
return num, err
}
}
}
return num, err
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
column := snakeString(col)
if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false {
panic(fmt.Sprintf("wrong field/column name `%s`", column))
}
columns = append(columns, column)
values = append(values, val)
}
if len(columns) == 0 {
panic("update params cannot empty")
}
tables := newDbTables(mi, d.ins)
if qs != nil {
tables.parseRelated(qs.related, qs.relDepth)
}
where, args := tables.getCondSql(cond, false)
join := tables.getJoinSql()
query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where)
values = append(values, args...)
if res, err := q.Exec(query, values...); err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case od_CASCADE:
cond := NewCondition()
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond)
if err != nil {
return err
}
case od_SET_DEFAULT, od_SET_NULL:
cond := NewCondition()
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params)
if err != nil {
return err
}
case od_DO_NOTHING:
}
}
return nil
}
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) {
tables := newDbTables(mi, d.ins)
if qs != nil {
tables.parseRelated(qs.related, qs.relDepth)
}
if cond == nil || cond.IsEmpty() {
panic("delete operation cannot execute without condition")
}
where, args := tables.getCondSql(cond, false)
join := tables.getJoinSql()
colsNum := len(mi.fields.pk)
cols := make([]string, colsNum)
for i, fi := range mi.fields.pk {
cols[i] = fi.column
}
colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref interface{}
refs[i] = &ref
}
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
if err := rs.Scan(refs...); err != nil {
return 0, err
}
for _, ref := range refs {
args = append(args, reflect.ValueOf(ref).Elem().Interface())
}
cnt++
}
if cnt == 0 {
return 0, nil
}
if colsNum > 1 {
columns := strings.Join(cols, "` = ? AND `")
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
} else {
var sql string
sql, args = d.ins.GetOperatorSql(mi, "in", args)
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
}
if res, err := q.Exec(query, args...); err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if colsNum == 1 && num > 0 {
err := d.deleteRels(q, mi, args)
if err != nil {
return num, err
}
}
return num, nil
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
typ := ind.Type()
errTyp := true
one := true
if val.Kind() == reflect.Ptr {
tp := typ
if ind.Kind() == reflect.Slice {
one = false
if ind.Type().Elem().Kind() == reflect.Ptr {
tp = ind.Type().Elem().Elem()
}
}
errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName
}
if errTyp {
panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName))
}
rlimit := qs.limit
offset := qs.offset
if one {
rlimit = 0
offset = 0
}
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(offset, rlimit)
join := tables.getJoinSql()
colsNum := len(mi.fields.dbcols)
cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`"))
for _, tbl := range tables.tables {
if tbl.sel {
colsNum += len(tbl.mi.fields.dbcols)
cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`"))
}
}
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref interface{}
refs[i] = &ref
}
slice := ind
var cnt int64
for rs.Next() {
if one && cnt == 0 || one == false {
if err := rs.Scan(refs...); err != nil {
return 0, err
}
elm := reflect.New(mi.addrField.Elem().Type())
md := elm.Interface().(Modeler)
md.Init(md)
mind := reflect.Indirect(elm)
cacheV := make(map[string]*reflect.Value)
cacheM := make(map[string]*modelInfo)
trefs := refs
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)])
trefs = refs[len(mi.fields.dbcols):]
for _, tbl := range tables.tables {
if tbl.sel {
last := mind
names := ""
mmi := mi
for _, name := range tbl.names {
names += name
if val, ok := cacheV[names]; ok {
last = *val
mmi = cacheM[names]
} else {
fi := mmi.fields.GetByName(name)
lastm := mmi
mmi := fi.relModelInfo
field := reflect.Indirect(last.Field(fi.fieldIndex))
if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)])
for _, fi := range mmi.fields.fieldsReverse {
if fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil {
field.Field(fi.fieldIndex).Set(last.Addr())
}
}
}
cacheV[names] = &field
cacheM[names] = mmi
last = field
}
trefs = trefs[len(mmi.fields.dbcols):]
}
}
}
}
if one {
ind.Set(mind)
} else {
slice = reflect.Append(slice, mind.Addr())
}
}
cnt++
}
if one == false {
ind.Set(slice)
}
return cnt, nil
}
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false)
tables.getOrderSql(qs.orders)
join := tables.getJoinSql()
query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where)
row := q.QueryRow(query, args...)
err = row.Scan(&cnt)
return
}
func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
params := make([]interface{}, len(args))
copy(params, args)
sql := ""
for i, arg := range args {
if len(mi.fields.pk) == 1 {
if md, ok := arg.(Modeler); ok {
ind := reflect.Indirect(reflect.ValueOf(md))
if _, values, exist := d.existPk(mi, ind); exist {
arg = values[0]
} else {
panic(fmt.Sprintf("`%s` need a valid args value", operator))
}
}
}
params[i] = arg
}
if operator == "in" {
marks := make([]string, len(params))
for i, _ := range marks {
marks[i] = "?"
}
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
} else {
if len(params) > 1 {
panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
}
sql = operatorsSQL[operator]
arg := params[0]
switch operator {
case "exact":
if arg == nil {
params[0] = "IS NULL"
}
case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith":
param := strings.Replace(ToStr(arg), `%`, `\%`, -1)
switch operator {
case "iexact":
case "contains", "icontains":
param = fmt.Sprintf("%%%s%%", param)
case "startswith", "istartswith":
param = fmt.Sprintf("%s%%", param)
case "endswith", "iendswith":
param = fmt.Sprintf("%%%s", param)
}
params[0] = param
case "isnull":
if b, ok := arg.(bool); ok {
if b {
sql = "IS NULL"
} else {
sql = "IS NOT NULL"
}
params = nil
} else {
panic(fmt.Sprintf("operator `%s` need a bool value not `%T`", operator, arg))
}
}
}
return sql, params
}
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
fi := mi.fields.GetByColumn(column)
field := ind.Field(fi.fieldIndex)
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
_, err = d.setValue(fi, value, &field)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
}
}
func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
if val == nil {
return nil, nil
}
var value interface{}
var str *StrTo
switch v := val.(type) {
case []byte:
s := StrTo(string(v))
str = &s
case string:
s := StrTo(v)
str = &s
}
fieldType := fi.fieldType
setValue:
switch {
case fieldType == TypeBooleanField:
if str == nil {
switch v := val.(type) {
case int64:
b := v == 1
value = b
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
b, err := str.Bool()
if err != nil {
return nil, err
}
value = b
}
case fieldType == TypeCharField || fieldType == TypeTextField:
s := str.String()
if str == nil {
s = ToStr(val)
}
value = s
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil {
switch v := val.(type) {
case time.Time:
value = v
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
format := format_DateTime
if fi.fieldType == TypeDateField {
format = format_Date
}
s := str.String()
t, err := timeParse(s, format)
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
return nil, err
}
value = t
}
case fieldType&IsIntegerField > 0:
if str == nil {
s := StrTo(ToStr(val))
str = &s
}
if str != nil {
var err error
switch fieldType {
case TypeSmallIntegerField:
_, err = str.Int16()
case TypeIntegerField:
_, err = str.Int32()
case TypeBigIntegerField:
_, err = str.Int64()
case TypePositiveSmallIntegerField:
_, err = str.Uint16()
case TypePositiveIntegerField:
_, err = str.Uint32()
case TypePositiveBigIntegerField:
_, err = str.Uint64()
}
if err != nil {
return nil, err
}
if fieldType&IsPostiveIntegerField > 0 {
v, _ := str.Uint64()
value = v
} else {
v, _ := str.Int64()
value = v
}
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if str == nil {
switch v := val.(type) {
case float64:
value = v
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
v, err := str.Float64()
if err != nil {
return nil, err
}
value = v
}
case fieldType&IsRelField > 0:
fieldType = fi.relModelInfo.fields.pk[0].fieldType
goto setValue
}
return value, nil
}
func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {
fieldType := fi.fieldType
isNative := fi.isFielder == false
setValue:
switch {
case fieldType == TypeBooleanField:
if isNative {
if value == nil {
value = false
}
field.SetBool(value.(bool))
}
case fieldType == TypeCharField || fieldType == TypeTextField:
if isNative {
if value == nil {
value = ""
}
field.SetString(value.(string))
}
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative {
if value == nil {
value = time.Time{}
}
field.Set(reflect.ValueOf(value))
}
case fieldType&IsIntegerField > 0:
if fieldType&IsPostiveIntegerField > 0 {
if isNative {
if value == nil {
value = uint64(0)
}
field.SetUint(value.(uint64))
}
} else {
if isNative {
if value == nil {
value = int64(0)
}
field.SetInt(value.(int64))
}
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if isNative {
if value == nil {
value = float64(0)
}
field.SetFloat(value.(float64))
}
case fieldType&IsRelField > 0:
if value != nil {
fieldType = fi.relModelInfo.fields.pk[0].fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
md := mf.Interface().(Modeler)
md.Init(md)
field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
field = &f
goto setValue
}
}
if isNative == false {
fd := field.Addr().Interface().(Fielder)
err := fd.SetRaw(value)
if err != nil {
return nil, err
}
}
return value, nil
}
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) {
var (
maps []Params
lists []ParamsList
list ParamsList
)
typ := 0
switch container.(type) {
case *[]Params:
typ = 1
case *[]ParamsList:
typ = 2
case *ParamsList:
typ = 3
default:
panic(fmt.Sprintf("unsupport read values type `%T`", container))
}
tables := newDbTables(mi, d.ins)
var (
cols []string
infos []*fieldInfo
)
hasExprs := len(exprs) > 0
if hasExprs {
cols = make([]string, 0, len(exprs))
infos = make([]*fieldInfo, 0, len(exprs))
for _, ex := range exprs {
index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", ex))
}
cols = append(cols, fmt.Sprintf("%s.`%s` `%s`", index, col, name))
infos = append(infos, fi)
}
} else {
cols = make([]string, 0, len(mi.fields.dbcols))
infos = make([]*fieldInfo, 0, len(exprs))
for _, fi := range mi.fields.fieldsDB {
cols = append(cols, fmt.Sprintf("T0.`%s` `%s`", fi.column, fi.name))
infos = append(infos, fi)
}
}
where, args := tables.getCondSql(cond, false)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(qs.offset, qs.limit)
join := tables.getJoinSql()
sels := strings.Join(cols, ", ")
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, len(cols))
for i, _ := range refs {
var ref interface{}
refs[i] = &ref
}
var (
cnt int64
columns []string
)
for rs.Next() {
if cnt == 0 {
if cols, err := rs.Columns(); err != nil {
return 0, err
} else {
columns = cols
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
switch typ {
case 1:
params := make(Params, len(cols))
for i, ref := range refs {
fi := infos[i]
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
params[columns[i]] = value
}
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for i, ref := range refs {
fi := infos[i]
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
params = append(params, value)
}
lists = append(lists, params)
case 3:
for i, ref := range refs {
fi := infos[i]
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
list = append(list, value)
}
}
cnt++
}
switch v := container.(type) {
case *[]Params:
*v = maps
case *[]ParamsList:
*v = lists
case *ParamsList:
*v = list
}
return cnt, nil
}