mirror of
https://github.com/astaxie/beego.git
synced 2024-11-16 19:40:55 +00:00
906637ae8b
- Also add test covering the issue
477 lines
9.5 KiB
Go
477 lines
9.5 KiB
Go
// Copyright 2014 beego Author. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package orm
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// table info struct.
|
|
type dbTable struct {
|
|
id int
|
|
index string
|
|
name string
|
|
names []string
|
|
sel bool
|
|
inner bool
|
|
mi *modelInfo
|
|
fi *fieldInfo
|
|
jtl *dbTable
|
|
}
|
|
|
|
// tables collection struct, contains some tables.
|
|
type dbTables struct {
|
|
tablesM map[string]*dbTable
|
|
tables []*dbTable
|
|
mi *modelInfo
|
|
base dbBaser
|
|
skipEnd bool
|
|
}
|
|
|
|
// set table info to collection.
|
|
// if not exist, create new.
|
|
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
|
name := strings.Join(names, ExprSep)
|
|
if j, ok := t.tablesM[name]; ok {
|
|
j.name = name
|
|
j.mi = mi
|
|
j.fi = fi
|
|
j.inner = inner
|
|
} else {
|
|
i := len(t.tables) + 1
|
|
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
|
t.tablesM[name] = jt
|
|
t.tables = append(t.tables, jt)
|
|
}
|
|
return t.tablesM[name]
|
|
}
|
|
|
|
// add table info to collection.
|
|
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
|
name := strings.Join(names, ExprSep)
|
|
if _, ok := t.tablesM[name]; ok == false {
|
|
i := len(t.tables) + 1
|
|
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
|
t.tablesM[name] = jt
|
|
t.tables = append(t.tables, jt)
|
|
return jt, true
|
|
}
|
|
return t.tablesM[name], false
|
|
}
|
|
|
|
// get table info in collection.
|
|
func (t *dbTables) get(name string) (*dbTable, bool) {
|
|
j, ok := t.tablesM[name]
|
|
return j, ok
|
|
}
|
|
|
|
// get related fields info in recursive depth loop.
|
|
// loop once, depth decreases one.
|
|
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
|
if depth < 0 || fi.fieldType == RelManyToMany {
|
|
return related
|
|
}
|
|
|
|
if prefix == "" {
|
|
prefix = fi.name
|
|
} else {
|
|
prefix = prefix + ExprSep + fi.name
|
|
}
|
|
related = append(related, prefix)
|
|
|
|
depth--
|
|
for _, fi := range fi.relModelInfo.fields.fieldsRel {
|
|
related = t.loopDepth(depth, prefix, fi, related)
|
|
}
|
|
|
|
return related
|
|
}
|
|
|
|
// parse related fields.
|
|
func (t *dbTables) parseRelated(rels []string, depth int) {
|
|
|
|
relsNum := len(rels)
|
|
related := make([]string, relsNum)
|
|
copy(related, rels)
|
|
|
|
relDepth := depth
|
|
|
|
if relsNum != 0 {
|
|
relDepth = 0
|
|
}
|
|
|
|
relDepth--
|
|
for _, fi := range t.mi.fields.fieldsRel {
|
|
related = t.loopDepth(relDepth, "", fi, related)
|
|
}
|
|
|
|
for i, s := range related {
|
|
var (
|
|
exs = strings.Split(s, ExprSep)
|
|
names = make([]string, 0, len(exs))
|
|
mmi = t.mi
|
|
cancel = true
|
|
jtl *dbTable
|
|
)
|
|
|
|
inner := true
|
|
|
|
for _, ex := range exs {
|
|
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
|
|
names = append(names, fi.name)
|
|
mmi = fi.relModelInfo
|
|
|
|
if fi.null || t.skipEnd {
|
|
inner = false
|
|
}
|
|
|
|
jt := t.set(names, mmi, fi, inner)
|
|
jt.jtl = jtl
|
|
|
|
if fi.reverse {
|
|
cancel = false
|
|
}
|
|
|
|
if cancel {
|
|
jt.sel = depth > 0
|
|
|
|
if i < relsNum {
|
|
jt.sel = true
|
|
}
|
|
}
|
|
|
|
jtl = jt
|
|
|
|
} else {
|
|
panic(fmt.Errorf("unknown model/table name `%s`", ex))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// generate join string.
|
|
func (t *dbTables) getJoinSQL() (join string) {
|
|
Q := t.base.TableQuote()
|
|
|
|
for _, jt := range t.tables {
|
|
if jt.inner {
|
|
join += "INNER JOIN "
|
|
} else {
|
|
join += "LEFT OUTER JOIN "
|
|
}
|
|
var (
|
|
table string
|
|
t1, t2 string
|
|
c1, c2 string
|
|
)
|
|
t1 = "T0"
|
|
if jt.jtl != nil {
|
|
t1 = jt.jtl.index
|
|
}
|
|
t2 = jt.index
|
|
table = jt.mi.table
|
|
|
|
switch {
|
|
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
|
|
c1 = jt.fi.mi.fields.pk.column
|
|
for _, ffi := range jt.mi.fields.fieldsRel {
|
|
if jt.fi.mi == ffi.relModelInfo {
|
|
c2 = ffi.column
|
|
break
|
|
}
|
|
}
|
|
default:
|
|
c1 = jt.fi.column
|
|
c2 = jt.fi.relModelInfo.fields.pk.column
|
|
|
|
if jt.fi.reverse {
|
|
c1 = jt.mi.fields.pk.column
|
|
c2 = jt.fi.reverseFieldInfo.column
|
|
}
|
|
}
|
|
|
|
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
|
|
t2, Q, c2, Q, t1, Q, c1, Q)
|
|
}
|
|
return
|
|
}
|
|
|
|
// parse orm model struct field tag expression.
|
|
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
|
|
var (
|
|
jtl *dbTable
|
|
fi *fieldInfo
|
|
fiN *fieldInfo
|
|
mmi = mi
|
|
)
|
|
|
|
num := len(exprs) - 1
|
|
var names []string
|
|
|
|
inner := true
|
|
|
|
loopFor:
|
|
for i, ex := range exprs {
|
|
|
|
var ok, okN bool
|
|
|
|
if fiN != nil {
|
|
fi = fiN
|
|
ok = true
|
|
fiN = nil
|
|
}
|
|
|
|
if i == 0 {
|
|
fi, ok = mmi.fields.GetByAny(ex)
|
|
}
|
|
|
|
_ = okN
|
|
|
|
if ok {
|
|
|
|
isRel := fi.rel || fi.reverse
|
|
|
|
names = append(names, fi.name)
|
|
|
|
switch {
|
|
case fi.rel:
|
|
mmi = fi.relModelInfo
|
|
if fi.fieldType == RelManyToMany {
|
|
mmi = fi.relThroughModelInfo
|
|
}
|
|
case fi.reverse:
|
|
mmi = fi.reverseFieldInfo.mi
|
|
}
|
|
|
|
if i < num {
|
|
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
|
}
|
|
|
|
if isRel && (fi.mi.isThrough == false || num != i) {
|
|
if fi.null || t.skipEnd {
|
|
inner = false
|
|
}
|
|
|
|
if t.skipEnd && okN || !t.skipEnd {
|
|
if t.skipEnd && okN && fiN.pk {
|
|
goto loopEnd
|
|
}
|
|
|
|
jt, _ := t.add(names, mmi, fi, inner)
|
|
jt.jtl = jtl
|
|
jtl = jt
|
|
}
|
|
|
|
}
|
|
|
|
if num != i {
|
|
continue
|
|
}
|
|
|
|
loopEnd:
|
|
|
|
if i == 0 || jtl == nil {
|
|
index = "T0"
|
|
} else {
|
|
index = jtl.index
|
|
}
|
|
|
|
info = fi
|
|
|
|
if jtl == nil {
|
|
name = fi.name
|
|
} else {
|
|
name = jtl.name + ExprSep + fi.name
|
|
}
|
|
|
|
switch {
|
|
case fi.rel:
|
|
|
|
case fi.reverse:
|
|
switch fi.reverseFieldInfo.fieldType {
|
|
case RelOneToOne, RelForeignKey:
|
|
index = jtl.index
|
|
info = fi.reverseFieldInfo.mi.fields.pk
|
|
name = info.name
|
|
}
|
|
}
|
|
|
|
break loopFor
|
|
|
|
} else {
|
|
index = ""
|
|
name = ""
|
|
info = nil
|
|
success = false
|
|
return
|
|
}
|
|
}
|
|
|
|
success = index != "" && info != nil
|
|
return
|
|
}
|
|
|
|
// generate condition sql.
|
|
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
|
|
if cond == nil || cond.IsEmpty() {
|
|
return
|
|
}
|
|
|
|
Q := t.base.TableQuote()
|
|
|
|
mi := t.mi
|
|
|
|
for i, p := range cond.params {
|
|
if i > 0 {
|
|
if p.isOr {
|
|
where += "OR "
|
|
} else {
|
|
where += "AND "
|
|
}
|
|
}
|
|
if p.isNot {
|
|
where += "NOT "
|
|
}
|
|
if p.isCond {
|
|
w, ps := t.getCondSQL(p.cond, true, tz)
|
|
if w != "" {
|
|
w = fmt.Sprintf("( %s) ", w)
|
|
}
|
|
where += w
|
|
params = append(params, ps...)
|
|
} else {
|
|
exprs := p.exprs
|
|
|
|
num := len(exprs) - 1
|
|
operator := ""
|
|
if operators[exprs[num]] {
|
|
operator = exprs[num]
|
|
exprs = exprs[:num]
|
|
}
|
|
|
|
index, _, fi, suc := t.parseExprs(mi, exprs)
|
|
if suc == false {
|
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
|
|
}
|
|
|
|
if operator == "" {
|
|
operator = "exact"
|
|
}
|
|
|
|
operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
|
|
|
|
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
|
|
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
|
|
|
|
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
|
|
params = append(params, args...)
|
|
|
|
}
|
|
}
|
|
|
|
if sub == false && where != "" {
|
|
where = "WHERE " + where
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// generate group sql.
|
|
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
|
|
if len(groups) == 0 {
|
|
return
|
|
}
|
|
|
|
Q := t.base.TableQuote()
|
|
|
|
groupSqls := make([]string, 0, len(groups))
|
|
for _, group := range groups {
|
|
exprs := strings.Split(group, ExprSep)
|
|
|
|
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
|
if suc == false {
|
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
|
}
|
|
|
|
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
|
|
}
|
|
|
|
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
|
|
return
|
|
}
|
|
|
|
// generate order sql.
|
|
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
|
if len(orders) == 0 {
|
|
return
|
|
}
|
|
|
|
Q := t.base.TableQuote()
|
|
|
|
orderSqls := make([]string, 0, len(orders))
|
|
for _, order := range orders {
|
|
asc := "ASC"
|
|
if order[0] == '-' {
|
|
asc = "DESC"
|
|
order = order[1:]
|
|
}
|
|
exprs := strings.Split(order, ExprSep)
|
|
|
|
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
|
if suc == false {
|
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
|
}
|
|
|
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
|
|
}
|
|
|
|
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
|
return
|
|
}
|
|
|
|
// generate limit sql.
|
|
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
|
|
if limit == 0 {
|
|
limit = int64(DefaultRowsLimit)
|
|
}
|
|
if limit < 0 {
|
|
// no limit
|
|
if offset > 0 {
|
|
maxLimit := t.base.MaxLimit()
|
|
if maxLimit == 0 {
|
|
limits = fmt.Sprintf("OFFSET %d", offset)
|
|
} else {
|
|
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
|
|
}
|
|
}
|
|
} else if offset <= 0 {
|
|
limits = fmt.Sprintf("LIMIT %d", limit)
|
|
} else {
|
|
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
|
|
}
|
|
return
|
|
}
|
|
|
|
// crete new tables collection.
|
|
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
|
tables := &dbTables{}
|
|
tables.tablesM = make(map[string]*dbTable)
|
|
tables.mi = mi
|
|
tables.base = base
|
|
return tables
|
|
}
|