mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 18:30:56 +00:00
orm add sqlite3 support, may be support postgres in next commit
This commit is contained in:
parent
9631c663d5
commit
6c41e6dd78
557
orm/db.go
557
orm/db.go
@ -43,395 +43,8 @@ var (
|
|||||||
"isnull": true,
|
"isnull": true,
|
||||||
// "search": true,
|
// "search": true,
|
||||||
}
|
}
|
||||||
operatorsSQL = map[string]string{
|
|
||||||
"exact": "= ?",
|
|
||||||
"iexact": "LIKE ?",
|
|
||||||
"contains": "LIKE BINARY ?",
|
|
||||||
"icontains": "LIKE ?",
|
|
||||||
// "regex": "REGEXP BINARY ?",
|
|
||||||
// "iregex": "REGEXP ?",
|
|
||||||
"gt": "> ?",
|
|
||||||
"gte": ">= ?",
|
|
||||||
"lt": "< ?",
|
|
||||||
"lte": "<= ?",
|
|
||||||
"startswith": "LIKE BINARY ?",
|
|
||||||
"endswith": "LIKE BINARY ?",
|
|
||||||
"istartswith": "LIKE ?",
|
|
||||||
"iendswith": "LIKE ?",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type dbTable struct {
|
|
||||||
id int
|
|
||||||
index string
|
|
||||||
name string
|
|
||||||
names []string
|
|
||||||
sel bool
|
|
||||||
inner bool
|
|
||||||
mi *modelInfo
|
|
||||||
fi *fieldInfo
|
|
||||||
jtl *dbTable
|
|
||||||
}
|
|
||||||
|
|
||||||
type dbTables struct {
|
|
||||||
tablesM map[string]*dbTable
|
|
||||||
tables []*dbTable
|
|
||||||
mi *modelInfo
|
|
||||||
base dbBaser
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
|
||||||
name := strings.Join(names, ExprSep)
|
|
||||||
if j, ok := t.tablesM[name]; ok {
|
|
||||||
j.name = name
|
|
||||||
j.mi = mi
|
|
||||||
j.fi = fi
|
|
||||||
j.inner = inner
|
|
||||||
} else {
|
|
||||||
i := len(t.tables) + 1
|
|
||||||
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
|
||||||
t.tablesM[name] = jt
|
|
||||||
t.tables = append(t.tables, jt)
|
|
||||||
}
|
|
||||||
return t.tablesM[name]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
|
||||||
name := strings.Join(names, ExprSep)
|
|
||||||
if _, ok := t.tablesM[name]; ok == false {
|
|
||||||
i := len(t.tables) + 1
|
|
||||||
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
|
||||||
t.tablesM[name] = jt
|
|
||||||
t.tables = append(t.tables, jt)
|
|
||||||
return jt, true
|
|
||||||
}
|
|
||||||
return t.tablesM[name], false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *dbTables) get(name string) (*dbTable, bool) {
|
|
||||||
j, ok := t.tablesM[name]
|
|
||||||
return j, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
|
||||||
if depth < 0 || fi.fieldType == RelManyToMany {
|
|
||||||
return related
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix == "" {
|
|
||||||
prefix = fi.name
|
|
||||||
} else {
|
|
||||||
prefix = prefix + ExprSep + fi.name
|
|
||||||
}
|
|
||||||
related = append(related, prefix)
|
|
||||||
|
|
||||||
depth--
|
|
||||||
for _, fi := range fi.relModelInfo.fields.fieldsRel {
|
|
||||||
related = t.loopDepth(depth, prefix, fi, related)
|
|
||||||
}
|
|
||||||
|
|
||||||
return related
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *dbTables) parseRelated(rels []string, depth int) {
|
|
||||||
|
|
||||||
relsNum := len(rels)
|
|
||||||
related := make([]string, relsNum)
|
|
||||||
copy(related, rels)
|
|
||||||
|
|
||||||
relDepth := depth
|
|
||||||
|
|
||||||
if relsNum != 0 {
|
|
||||||
relDepth = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
relDepth--
|
|
||||||
for _, fi := range t.mi.fields.fieldsRel {
|
|
||||||
related = t.loopDepth(relDepth, "", fi, related)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, s := range related {
|
|
||||||
var (
|
|
||||||
exs = strings.Split(s, ExprSep)
|
|
||||||
names = make([]string, 0, len(exs))
|
|
||||||
mmi = t.mi
|
|
||||||
cansel = true
|
|
||||||
jtl *dbTable
|
|
||||||
)
|
|
||||||
for _, ex := range exs {
|
|
||||||
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
|
|
||||||
names = append(names, fi.name)
|
|
||||||
mmi = fi.relModelInfo
|
|
||||||
|
|
||||||
jt := t.set(names, mmi, fi, fi.null == false)
|
|
||||||
jt.jtl = jtl
|
|
||||||
|
|
||||||
if fi.reverse {
|
|
||||||
cansel = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if cansel {
|
|
||||||
jt.sel = depth > 0
|
|
||||||
|
|
||||||
if i < relsNum {
|
|
||||||
jt.sel = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
jtl = jt
|
|
||||||
|
|
||||||
} else {
|
|
||||||
panic(fmt.Sprintf("unknown model/table name `%s`", ex))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *dbTables) getJoinSql() (join string) {
|
|
||||||
for _, jt := range t.tables {
|
|
||||||
if jt.inner {
|
|
||||||
join += "INNER JOIN "
|
|
||||||
} else {
|
|
||||||
join += "LEFT OUTER JOIN "
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
table string
|
|
||||||
t1, t2 string
|
|
||||||
c1, c2 string
|
|
||||||
)
|
|
||||||
t1 = "T0"
|
|
||||||
if jt.jtl != nil {
|
|
||||||
t1 = jt.jtl.index
|
|
||||||
}
|
|
||||||
t2 = jt.index
|
|
||||||
table = jt.mi.table
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
|
|
||||||
c1 = jt.fi.mi.fields.pk.column
|
|
||||||
for _, ffi := range jt.mi.fields.fieldsRel {
|
|
||||||
if jt.fi.mi == ffi.relModelInfo {
|
|
||||||
c2 = ffi.column
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
c1 = jt.fi.column
|
|
||||||
c2 = jt.fi.relModelInfo.fields.pk.column
|
|
||||||
|
|
||||||
if jt.fi.reverse {
|
|
||||||
c1 = jt.mi.fields.pk.column
|
|
||||||
c2 = jt.fi.reverseFieldInfo.column
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2,
|
|
||||||
t2, c2, t1, c1)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
|
|
||||||
var (
|
|
||||||
ffi *fieldInfo
|
|
||||||
jtl *dbTable
|
|
||||||
mmi = mi
|
|
||||||
)
|
|
||||||
|
|
||||||
num := len(exprs) - 1
|
|
||||||
names := make([]string, 0)
|
|
||||||
|
|
||||||
for i, ex := range exprs {
|
|
||||||
exist := false
|
|
||||||
|
|
||||||
check:
|
|
||||||
fi, ok := mmi.fields.GetByAny(ex)
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
|
|
||||||
if num != i {
|
|
||||||
names = append(names, fi.name)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case fi.rel:
|
|
||||||
mmi = fi.relModelInfo
|
|
||||||
if fi.fieldType == RelManyToMany {
|
|
||||||
mmi = fi.relThroughModelInfo
|
|
||||||
}
|
|
||||||
case fi.reverse:
|
|
||||||
mmi = fi.reverseFieldInfo.mi
|
|
||||||
if fi.reverseFieldInfo.fieldType == RelManyToMany {
|
|
||||||
mmi = fi.reverseFieldInfo.relThroughModelInfo
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jt, _ := d.add(names, mmi, fi, fi.null == false)
|
|
||||||
jt.jtl = jtl
|
|
||||||
jtl = jt
|
|
||||||
|
|
||||||
if fi.rel && fi.fieldType == RelManyToMany {
|
|
||||||
ex = fi.relModelInfo.name
|
|
||||||
goto check
|
|
||||||
}
|
|
||||||
|
|
||||||
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
|
|
||||||
ex = fi.reverseFieldInfo.mi.name
|
|
||||||
goto check
|
|
||||||
}
|
|
||||||
|
|
||||||
exist = true
|
|
||||||
|
|
||||||
} else {
|
|
||||||
|
|
||||||
if ffi == nil {
|
|
||||||
index = "T0"
|
|
||||||
} else {
|
|
||||||
index = jtl.index
|
|
||||||
}
|
|
||||||
column = fi.column
|
|
||||||
info = fi
|
|
||||||
if jtl != nil {
|
|
||||||
name = jtl.name + ExprSep + fi.name
|
|
||||||
} else {
|
|
||||||
name = fi.name
|
|
||||||
}
|
|
||||||
|
|
||||||
switch fi.fieldType {
|
|
||||||
case RelManyToMany, RelReverseMany:
|
|
||||||
default:
|
|
||||||
exist = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ffi = fi
|
|
||||||
}
|
|
||||||
|
|
||||||
if exist == false {
|
|
||||||
index = ""
|
|
||||||
column = ""
|
|
||||||
name = ""
|
|
||||||
success = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
success = index != "" && column != ""
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
|
|
||||||
if cond == nil || cond.IsEmpty() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mi := d.mi
|
|
||||||
|
|
||||||
// outFor:
|
|
||||||
for i, p := range cond.params {
|
|
||||||
if i > 0 {
|
|
||||||
if p.isOr {
|
|
||||||
where += "OR "
|
|
||||||
} else {
|
|
||||||
where += "AND "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p.isNot {
|
|
||||||
where += "NOT "
|
|
||||||
}
|
|
||||||
if p.isCond {
|
|
||||||
w, ps := d.getCondSql(p.cond, true)
|
|
||||||
if w != "" {
|
|
||||||
w = fmt.Sprintf("( %s) ", w)
|
|
||||||
}
|
|
||||||
where += w
|
|
||||||
params = append(params, ps...)
|
|
||||||
} else {
|
|
||||||
exprs := p.exprs
|
|
||||||
|
|
||||||
num := len(exprs) - 1
|
|
||||||
operator := ""
|
|
||||||
if operators[exprs[num]] {
|
|
||||||
operator = exprs[num]
|
|
||||||
exprs = exprs[:num]
|
|
||||||
}
|
|
||||||
|
|
||||||
index, column, _, _, suc := d.parseExprs(mi, exprs)
|
|
||||||
if suc == false {
|
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if operator == "" {
|
|
||||||
operator = "exact"
|
|
||||||
}
|
|
||||||
|
|
||||||
operSql, args := d.base.GetOperatorSql(mi, operator, p.args)
|
|
||||||
|
|
||||||
where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql)
|
|
||||||
params = append(params, args...)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sub == false && where != "" {
|
|
||||||
where = "WHERE " + where
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
|
|
||||||
if len(orders) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
orderSqls := make([]string, 0, len(orders))
|
|
||||||
for _, order := range orders {
|
|
||||||
asc := "ASC"
|
|
||||||
if order[0] == '-' {
|
|
||||||
asc = "DESC"
|
|
||||||
order = order[1:]
|
|
||||||
}
|
|
||||||
exprs := strings.Split(order, ExprSep)
|
|
||||||
|
|
||||||
index, column, _, _, suc := d.parseExprs(d.mi, exprs)
|
|
||||||
if suc == false {
|
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
|
||||||
}
|
|
||||||
|
|
||||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc))
|
|
||||||
}
|
|
||||||
|
|
||||||
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) {
|
|
||||||
if limit == 0 {
|
|
||||||
limit = DefaultRowsLimit
|
|
||||||
}
|
|
||||||
if limit < 0 {
|
|
||||||
// no limit
|
|
||||||
if offset > 0 {
|
|
||||||
limits = fmt.Sprintf("LIMIT 18446744073709551615 OFFSET %d", offset)
|
|
||||||
}
|
|
||||||
} else if offset <= 0 {
|
|
||||||
limits = fmt.Sprintf("LIMIT %d", limit)
|
|
||||||
} else {
|
|
||||||
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
|
||||||
tables := &dbTables{}
|
|
||||||
tables.tablesM = make(map[string]*dbTable)
|
|
||||||
tables.mi = mi
|
|
||||||
tables.base = base
|
|
||||||
return tables
|
|
||||||
}
|
|
||||||
|
|
||||||
type dbBase struct {
|
type dbBase struct {
|
||||||
ins dbBaser
|
ins dbBaser
|
||||||
}
|
}
|
||||||
@ -528,6 +141,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
||||||
marks := make([]string, 0, len(mi.fields.dbcols))
|
marks := make([]string, 0, len(mi.fields.dbcols))
|
||||||
for _, fi := range mi.fields.fieldsDB {
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
@ -537,9 +152,13 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
qmarks := strings.Join(marks, ", ")
|
qmarks := strings.Join(marks, ", ")
|
||||||
columns := strings.Join(dbcols, "`,`")
|
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||||
|
columns := strings.Join(dbcols, sep)
|
||||||
|
|
||||||
|
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
|
|
||||||
stmt, err := q.Prepare(query)
|
stmt, err := q.Prepare(query)
|
||||||
return stmt, query, err
|
return stmt, query, err
|
||||||
}
|
}
|
||||||
@ -563,10 +182,13 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
|
|||||||
return ErrMissPK
|
return ErrMissPK
|
||||||
}
|
}
|
||||||
|
|
||||||
sels := strings.Join(mi.fields.dbcols, "`, `")
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||||
|
sels := strings.Join(mi.fields.dbcols, sep)
|
||||||
colsNum := len(mi.fields.dbcols)
|
colsNum := len(mi.fields.dbcols)
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn)
|
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
|
||||||
|
|
||||||
refs := make([]interface{}, colsNum)
|
refs := make([]interface{}, colsNum)
|
||||||
for i, _ := range refs {
|
for i, _ := range refs {
|
||||||
@ -574,6 +196,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
|
|||||||
refs[i] = &ref
|
refs[i] = &ref
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
row := q.QueryRow(query, pkValue)
|
row := q.QueryRow(query, pkValue)
|
||||||
if err := row.Scan(refs...); err != nil {
|
if err := row.Scan(refs...); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
@ -598,14 +222,20 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
marks := make([]string, len(names))
|
||||||
for i, _ := range marks {
|
for i, _ := range marks {
|
||||||
marks[i] = "?"
|
marks[i] = "?"
|
||||||
}
|
}
|
||||||
qmarks := strings.Join(marks, ", ")
|
|
||||||
columns := strings.Join(names, "`,`")
|
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
|
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||||
|
qmarks := strings.Join(marks, ", ")
|
||||||
|
columns := strings.Join(names, sep)
|
||||||
|
|
||||||
|
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if res, err := q.Exec(query, values...); err == nil {
|
if res, err := q.Exec(query, values...); err == nil {
|
||||||
return res.LastInsertId()
|
return res.LastInsertId()
|
||||||
@ -624,12 +254,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
setColumns := strings.Join(setNames, "` = ?, `")
|
|
||||||
|
|
||||||
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName)
|
|
||||||
|
|
||||||
setValues = append(setValues, pkValue)
|
setValues = append(setValues, pkValue)
|
||||||
|
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
sep := fmt.Sprintf("%s = ?, %s", Q, Q)
|
||||||
|
setColumns := strings.Join(setNames, sep)
|
||||||
|
|
||||||
|
query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if res, err := q.Exec(query, setValues...); err == nil {
|
if res, err := q.Exec(query, setValues...); err == nil {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
} else {
|
} else {
|
||||||
@ -644,7 +279,11 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
|
|||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName)
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if res, err := q.Exec(query, pkValue); err == nil {
|
if res, err := q.Exec(query, pkValue); err == nil {
|
||||||
|
|
||||||
@ -694,11 +333,24 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
|
|
||||||
where, args := tables.getCondSql(cond, false)
|
where, args := tables.getCondSql(cond, false)
|
||||||
|
|
||||||
|
values = append(values, args...)
|
||||||
|
|
||||||
join := tables.getJoinSql()
|
join := tables.getJoinSql()
|
||||||
|
|
||||||
query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where)
|
var query string
|
||||||
|
|
||||||
values = append(values, args...)
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
if d.ins.SupportUpdateJoin() {
|
||||||
|
cols := strings.Join(columns, fmt.Sprintf("%s = ?, T0.%s", Q, Q))
|
||||||
|
query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET T0.%s%s%s = ? %s", Q, mi.table, Q, join, Q, cols, Q, where)
|
||||||
|
} else {
|
||||||
|
cols := strings.Join(columns, fmt.Sprintf("%s = ?, %s", Q, Q))
|
||||||
|
supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where)
|
||||||
|
query = fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s IN ( %s )", Q, mi.table, Q, Q, cols, Q, Q, mi.fields.pk.column, Q, supQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if res, err := q.Exec(query, values...); err == nil {
|
if res, err := q.Exec(query, values...); err == nil {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -744,11 +396,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
panic("delete operation cannot execute without condition")
|
panic("delete operation cannot execute without condition")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
where, args := tables.getCondSql(cond, false)
|
where, args := tables.getCondSql(cond, false)
|
||||||
join := tables.getJoinSql()
|
join := tables.getJoinSql()
|
||||||
|
|
||||||
cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column)
|
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
|
||||||
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where)
|
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var rs *sql.Rows
|
var rs *sql.Rows
|
||||||
if r, err := q.Query(query, args...); err != nil {
|
if r, err := q.Query(query, args...); err != nil {
|
||||||
@ -773,8 +429,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, args := d.ins.GetOperatorSql(mi, "in", args)
|
sql, args := d.ins.GenerateOperatorSql(mi, "in", args)
|
||||||
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql)
|
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if res, err := q.Exec(query, args...); err == nil {
|
if res, err := q.Exec(query, args...); err == nil {
|
||||||
num, err := res.RowsAffected()
|
num, err := res.RowsAffected()
|
||||||
@ -831,24 +489,30 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
offset = 0
|
offset = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
|
|
||||||
where, args := tables.getCondSql(cond, false)
|
where, args := tables.getCondSql(cond, false)
|
||||||
orderBy := tables.getOrderSql(qs.orders)
|
orderBy := tables.getOrderSql(qs.orders)
|
||||||
limit := tables.getLimitSql(offset, rlimit)
|
limit := tables.getLimitSql(mi, offset, rlimit)
|
||||||
join := tables.getJoinSql()
|
join := tables.getJoinSql()
|
||||||
|
|
||||||
colsNum := len(mi.fields.dbcols)
|
colsNum := len(mi.fields.dbcols)
|
||||||
cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`"))
|
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
|
||||||
|
cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
|
||||||
for _, tbl := range tables.tables {
|
for _, tbl := range tables.tables {
|
||||||
if tbl.sel {
|
if tbl.sel {
|
||||||
colsNum += len(tbl.mi.fields.dbcols)
|
colsNum += len(tbl.mi.fields.dbcols)
|
||||||
cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`"))
|
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
|
||||||
|
cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit)
|
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var rs *sql.Rows
|
var rs *sql.Rows
|
||||||
if r, err := q.Query(query, args...); err != nil {
|
if r, err := q.Query(query, args...); err != nil {
|
||||||
@ -940,7 +604,11 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
|||||||
tables.getOrderSql(qs.orders)
|
tables.getOrderSql(qs.orders)
|
||||||
join := tables.getJoinSql()
|
join := tables.getJoinSql()
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where)
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s", Q, mi.table, Q, join, where)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
row := q.QueryRow(query, args...)
|
row := q.QueryRow(query, args...)
|
||||||
|
|
||||||
@ -1014,7 +682,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
|
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
|
||||||
sql := ""
|
sql := ""
|
||||||
params := d.getOperatorParams(operator, args)
|
params := d.getOperatorParams(operator, args)
|
||||||
|
|
||||||
@ -1028,7 +696,7 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
|
|||||||
if len(params) > 1 {
|
if len(params) > 1 {
|
||||||
panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
|
panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
|
||||||
}
|
}
|
||||||
sql = operatorsSQL[operator]
|
sql = d.ins.OperatorSql(operator)
|
||||||
arg := params[0]
|
arg := params[0]
|
||||||
switch operator {
|
switch operator {
|
||||||
case "exact":
|
case "exact":
|
||||||
@ -1073,13 +741,13 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
|
|||||||
|
|
||||||
value, err := d.getValue(fi, val)
|
value, err := d.getValue(fi, val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
|
panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = d.setValue(fi, value, &field)
|
_, err = d.setValue(fi, value, &field)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
|
panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1090,6 +758,7 @@ func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var value interface{}
|
var value interface{}
|
||||||
|
var tErr error
|
||||||
|
|
||||||
var str *StrTo
|
var str *StrTo
|
||||||
switch v := val.(type) {
|
switch v := val.(type) {
|
||||||
@ -1119,7 +788,8 @@ setValue:
|
|||||||
if str != nil {
|
if str != nil {
|
||||||
b, err := str.Bool()
|
b, err := str.Bool()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
tErr = err
|
||||||
|
goto end
|
||||||
}
|
}
|
||||||
value = b
|
value = b
|
||||||
}
|
}
|
||||||
@ -1140,14 +810,23 @@ setValue:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if str != nil {
|
if str != nil {
|
||||||
format := format_DateTime
|
s := str.String()
|
||||||
|
var format string
|
||||||
if fi.fieldType == TypeDateField {
|
if fi.fieldType == TypeDateField {
|
||||||
format = format_Date
|
format = format_Date
|
||||||
|
if len(s) > 10 {
|
||||||
|
s = s[:10]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
format = format_DateTime
|
||||||
|
if len(s) > 19 {
|
||||||
|
s = s[:19]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s := str.String()
|
|
||||||
t, err := timeParse(s, format)
|
t, err := timeParse(s, format)
|
||||||
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
|
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
|
||||||
return nil, err
|
tErr = err
|
||||||
|
goto end
|
||||||
}
|
}
|
||||||
value = t
|
value = t
|
||||||
}
|
}
|
||||||
@ -1173,7 +852,8 @@ setValue:
|
|||||||
_, err = str.Uint64()
|
_, err = str.Uint64()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
tErr = err
|
||||||
|
goto end
|
||||||
}
|
}
|
||||||
if fieldType&IsPostiveIntegerField > 0 {
|
if fieldType&IsPostiveIntegerField > 0 {
|
||||||
v, _ := str.Uint64()
|
v, _ := str.Uint64()
|
||||||
@ -1196,15 +876,23 @@ setValue:
|
|||||||
if str != nil {
|
if str != nil {
|
||||||
v, err := str.Float64()
|
v, err := str.Float64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
tErr = err
|
||||||
|
goto end
|
||||||
}
|
}
|
||||||
value = v
|
value = v
|
||||||
}
|
}
|
||||||
case fieldType&IsRelField > 0:
|
case fieldType&IsRelField > 0:
|
||||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
fi = fi.relModelInfo.fields.pk
|
||||||
|
fieldType = fi.fieldType
|
||||||
goto setValue
|
goto setValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end:
|
||||||
|
if tErr != nil {
|
||||||
|
err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return value, nil
|
return value, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -1275,6 +963,7 @@ setValue:
|
|||||||
fd := field.Addr().Interface().(Fielder)
|
fd := field.Addr().Interface().(Fielder)
|
||||||
err := fd.SetRaw(value)
|
err := fd.SetRaw(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1311,6 +1000,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
|
|
||||||
hasExprs := len(exprs) > 0
|
hasExprs := len(exprs) > 0
|
||||||
|
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
if hasExprs {
|
if hasExprs {
|
||||||
cols = make([]string, 0, len(exprs))
|
cols = make([]string, 0, len(exprs))
|
||||||
infos = make([]*fieldInfo, 0, len(exprs))
|
infos = make([]*fieldInfo, 0, len(exprs))
|
||||||
@ -1319,26 +1010,26 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
if suc == false {
|
if suc == false {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", ex))
|
panic(fmt.Errorf("unknown field/column name `%s`", ex))
|
||||||
}
|
}
|
||||||
cols = append(cols, fmt.Sprintf("%s.`%s` `%s`", index, col, name))
|
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q))
|
||||||
infos = append(infos, fi)
|
infos = append(infos, fi)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cols = make([]string, 0, len(mi.fields.dbcols))
|
cols = make([]string, 0, len(mi.fields.dbcols))
|
||||||
infos = make([]*fieldInfo, 0, len(exprs))
|
infos = make([]*fieldInfo, 0, len(exprs))
|
||||||
for _, fi := range mi.fields.fieldsDB {
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
cols = append(cols, fmt.Sprintf("T0.`%s` `%s`", fi.column, fi.name))
|
cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q))
|
||||||
infos = append(infos, fi)
|
infos = append(infos, fi)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
where, args := tables.getCondSql(cond, false)
|
where, args := tables.getCondSql(cond, false)
|
||||||
orderBy := tables.getOrderSql(qs.orders)
|
orderBy := tables.getOrderSql(qs.orders)
|
||||||
limit := tables.getLimitSql(qs.offset, qs.limit)
|
limit := tables.getLimitSql(mi, qs.offset, qs.limit)
|
||||||
join := tables.getJoinSql()
|
join := tables.getJoinSql()
|
||||||
|
|
||||||
sels := strings.Join(cols, ", ")
|
sels := strings.Join(cols, ", ")
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit)
|
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
|
||||||
|
|
||||||
var rs *sql.Rows
|
var rs *sql.Rows
|
||||||
if r, err := q.Query(query, args...); err != nil {
|
if r, err := q.Query(query, args...); err != nil {
|
||||||
@ -1430,3 +1121,19 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
|
|
||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) SupportUpdateJoin() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) MaxLimit() uint64 {
|
||||||
|
return 18446744073709551615
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) TableQuote() string {
|
||||||
|
return "`"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) ReplaceMarks(query *string) {
|
||||||
|
// default use `?` as mark, do nothing
|
||||||
|
}
|
||||||
|
@ -1,11 +1,30 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
|
var mysqlOperators = map[string]string{
|
||||||
|
"exact": "= ?",
|
||||||
|
"iexact": "LIKE ?",
|
||||||
|
"contains": "LIKE BINARY ?",
|
||||||
|
"icontains": "LIKE ?",
|
||||||
|
// "regex": "REGEXP BINARY ?",
|
||||||
|
// "iregex": "REGEXP ?",
|
||||||
|
"gt": "> ?",
|
||||||
|
"gte": ">= ?",
|
||||||
|
"lt": "< ?",
|
||||||
|
"lte": "<= ?",
|
||||||
|
"startswith": "LIKE BINARY ?",
|
||||||
|
"endswith": "LIKE BINARY ?",
|
||||||
|
"istartswith": "LIKE ?",
|
||||||
|
"iendswith": "LIKE ?",
|
||||||
|
}
|
||||||
|
|
||||||
type dbBaseMysql struct {
|
type dbBaseMysql struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) {
|
var _ dbBaser = new(dbBaseMysql)
|
||||||
return d.dbBase.GetOperatorSql(mi, operator, args)
|
|
||||||
|
func (d *dbBaseMysql) OperatorSql(operator string) string {
|
||||||
|
return mysqlOperators[operator]
|
||||||
}
|
}
|
||||||
|
|
||||||
func newdbBaseMysql() dbBaser {
|
func newdbBaseMysql() dbBaser {
|
||||||
|
@ -4,6 +4,12 @@ type dbBaseOracle struct {
|
|||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ dbBaser = new(dbBaseOracle)
|
||||||
|
|
||||||
|
func (d *dbBase) OperatorSql(operator string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func newdbBaseOracle() dbBaser {
|
func newdbBaseOracle() dbBaser {
|
||||||
b := new(dbBaseOracle)
|
b := new(dbBaseOracle)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -1,9 +1,66 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var postgresOperators = map[string]string{
|
||||||
|
"exact": "= ?",
|
||||||
|
"iexact": "= UPPER(?)",
|
||||||
|
"contains": "LIKE ?",
|
||||||
|
"icontains": "LIKE UPPER(?)",
|
||||||
|
"gt": "> ?",
|
||||||
|
"gte": ">= ?",
|
||||||
|
"lt": "< ?",
|
||||||
|
"lte": "<= ?",
|
||||||
|
"startswith": "LIKE ?",
|
||||||
|
"endswith": "LIKE ?",
|
||||||
|
"istartswith": "LIKE UPPER(?)",
|
||||||
|
"iendswith": "LIKE UPPER(?)",
|
||||||
|
}
|
||||||
|
|
||||||
type dbBasePostgres struct {
|
type dbBasePostgres struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ dbBaser = new(dbBasePostgres)
|
||||||
|
|
||||||
|
func (d *dbBasePostgres) OperatorSql(operator string) string {
|
||||||
|
return postgresOperators[operator]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBasePostgres) TableQuote() string {
|
||||||
|
return `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||||
|
q := *query
|
||||||
|
num := 0
|
||||||
|
for _, c := range q {
|
||||||
|
if c == '?' {
|
||||||
|
num += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if num == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := make([]byte, 0, len(q)+num)
|
||||||
|
num = 1
|
||||||
|
for i := 0; i < len(q); i++ {
|
||||||
|
c := q[i]
|
||||||
|
if c == '?' {
|
||||||
|
data = append(data, '$')
|
||||||
|
data = append(data, []byte(strconv.Itoa(num))...)
|
||||||
|
num += 1
|
||||||
|
} else {
|
||||||
|
data = append(data, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*query = string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (d *dbBasePostgres)
|
||||||
|
|
||||||
func newdbBasePostgres() dbBaser {
|
func newdbBasePostgres() dbBaser {
|
||||||
b := new(dbBasePostgres)
|
b := new(dbBasePostgres)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -1,9 +1,38 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
|
var sqliteOperators = map[string]string{
|
||||||
|
"exact": "= ?",
|
||||||
|
"iexact": "LIKE ? ESCAPE '\\'",
|
||||||
|
"contains": "LIKE ? ESCAPE '\\'",
|
||||||
|
"icontains": "LIKE ? ESCAPE '\\'",
|
||||||
|
"gt": "> ?",
|
||||||
|
"gte": ">= ?",
|
||||||
|
"lt": "< ?",
|
||||||
|
"lte": "<= ?",
|
||||||
|
"startswith": "LIKE ? ESCAPE '\\'",
|
||||||
|
"endswith": "LIKE ? ESCAPE '\\'",
|
||||||
|
"istartswith": "LIKE ? ESCAPE '\\'",
|
||||||
|
"iendswith": "LIKE ? ESCAPE '\\'",
|
||||||
|
}
|
||||||
|
|
||||||
type dbBaseSqlite struct {
|
type dbBaseSqlite struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ dbBaser = new(dbBaseSqlite)
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) OperatorSql(operator string) string {
|
||||||
|
return sqliteOperators[operator]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) MaxLimit() uint64 {
|
||||||
|
return 9223372036854775807
|
||||||
|
}
|
||||||
|
|
||||||
func newdbBaseSqlite() dbBaser {
|
func newdbBaseSqlite() dbBaser {
|
||||||
b := new(dbBaseSqlite)
|
b := new(dbBaseSqlite)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
384
orm/db_tables.go
Normal file
384
orm/db_tables.go
Normal file
@ -0,0 +1,384 @@
|
|||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dbTable struct {
|
||||||
|
id int
|
||||||
|
index string
|
||||||
|
name string
|
||||||
|
names []string
|
||||||
|
sel bool
|
||||||
|
inner bool
|
||||||
|
mi *modelInfo
|
||||||
|
fi *fieldInfo
|
||||||
|
jtl *dbTable
|
||||||
|
}
|
||||||
|
|
||||||
|
type dbTables struct {
|
||||||
|
tablesM map[string]*dbTable
|
||||||
|
tables []*dbTable
|
||||||
|
mi *modelInfo
|
||||||
|
base dbBaser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
||||||
|
name := strings.Join(names, ExprSep)
|
||||||
|
if j, ok := t.tablesM[name]; ok {
|
||||||
|
j.name = name
|
||||||
|
j.mi = mi
|
||||||
|
j.fi = fi
|
||||||
|
j.inner = inner
|
||||||
|
} else {
|
||||||
|
i := len(t.tables) + 1
|
||||||
|
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
||||||
|
t.tablesM[name] = jt
|
||||||
|
t.tables = append(t.tables, jt)
|
||||||
|
}
|
||||||
|
return t.tablesM[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
||||||
|
name := strings.Join(names, ExprSep)
|
||||||
|
if _, ok := t.tablesM[name]; ok == false {
|
||||||
|
i := len(t.tables) + 1
|
||||||
|
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
||||||
|
t.tablesM[name] = jt
|
||||||
|
t.tables = append(t.tables, jt)
|
||||||
|
return jt, true
|
||||||
|
}
|
||||||
|
return t.tablesM[name], false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dbTables) get(name string) (*dbTable, bool) {
|
||||||
|
j, ok := t.tablesM[name]
|
||||||
|
return j, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
||||||
|
if depth < 0 || fi.fieldType == RelManyToMany {
|
||||||
|
return related
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix == "" {
|
||||||
|
prefix = fi.name
|
||||||
|
} else {
|
||||||
|
prefix = prefix + ExprSep + fi.name
|
||||||
|
}
|
||||||
|
related = append(related, prefix)
|
||||||
|
|
||||||
|
depth--
|
||||||
|
for _, fi := range fi.relModelInfo.fields.fieldsRel {
|
||||||
|
related = t.loopDepth(depth, prefix, fi, related)
|
||||||
|
}
|
||||||
|
|
||||||
|
return related
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||||
|
|
||||||
|
relsNum := len(rels)
|
||||||
|
related := make([]string, relsNum)
|
||||||
|
copy(related, rels)
|
||||||
|
|
||||||
|
relDepth := depth
|
||||||
|
|
||||||
|
if relsNum != 0 {
|
||||||
|
relDepth = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
relDepth--
|
||||||
|
for _, fi := range t.mi.fields.fieldsRel {
|
||||||
|
related = t.loopDepth(relDepth, "", fi, related)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, s := range related {
|
||||||
|
var (
|
||||||
|
exs = strings.Split(s, ExprSep)
|
||||||
|
names = make([]string, 0, len(exs))
|
||||||
|
mmi = t.mi
|
||||||
|
cansel = true
|
||||||
|
jtl *dbTable
|
||||||
|
)
|
||||||
|
for _, ex := range exs {
|
||||||
|
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
|
||||||
|
names = append(names, fi.name)
|
||||||
|
mmi = fi.relModelInfo
|
||||||
|
|
||||||
|
jt := t.set(names, mmi, fi, fi.null == false)
|
||||||
|
jt.jtl = jtl
|
||||||
|
|
||||||
|
if fi.reverse {
|
||||||
|
cansel = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if cansel {
|
||||||
|
jt.sel = depth > 0
|
||||||
|
|
||||||
|
if i < relsNum {
|
||||||
|
jt.sel = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jtl = jt
|
||||||
|
|
||||||
|
} else {
|
||||||
|
panic(fmt.Sprintf("unknown model/table name `%s`", ex))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dbTables) getJoinSql() (join string) {
|
||||||
|
Q := t.base.TableQuote()
|
||||||
|
|
||||||
|
for _, jt := range t.tables {
|
||||||
|
if jt.inner {
|
||||||
|
join += "INNER JOIN "
|
||||||
|
} else {
|
||||||
|
join += "LEFT OUTER JOIN "
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
table string
|
||||||
|
t1, t2 string
|
||||||
|
c1, c2 string
|
||||||
|
)
|
||||||
|
t1 = "T0"
|
||||||
|
if jt.jtl != nil {
|
||||||
|
t1 = jt.jtl.index
|
||||||
|
}
|
||||||
|
t2 = jt.index
|
||||||
|
table = jt.mi.table
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
|
||||||
|
c1 = jt.fi.mi.fields.pk.column
|
||||||
|
for _, ffi := range jt.mi.fields.fieldsRel {
|
||||||
|
if jt.fi.mi == ffi.relModelInfo {
|
||||||
|
c2 = ffi.column
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
c1 = jt.fi.column
|
||||||
|
c2 = jt.fi.relModelInfo.fields.pk.column
|
||||||
|
|
||||||
|
if jt.fi.reverse {
|
||||||
|
c1 = jt.mi.fields.pk.column
|
||||||
|
c2 = jt.fi.reverseFieldInfo.column
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
|
||||||
|
t2, Q, c2, Q, t1, Q, c1, Q)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
|
||||||
|
var (
|
||||||
|
ffi *fieldInfo
|
||||||
|
jtl *dbTable
|
||||||
|
mmi = mi
|
||||||
|
)
|
||||||
|
|
||||||
|
num := len(exprs) - 1
|
||||||
|
names := make([]string, 0)
|
||||||
|
|
||||||
|
for i, ex := range exprs {
|
||||||
|
exist := false
|
||||||
|
|
||||||
|
check:
|
||||||
|
fi, ok := mmi.fields.GetByAny(ex)
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
|
||||||
|
if num != i {
|
||||||
|
names = append(names, fi.name)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case fi.rel:
|
||||||
|
mmi = fi.relModelInfo
|
||||||
|
if fi.fieldType == RelManyToMany {
|
||||||
|
mmi = fi.relThroughModelInfo
|
||||||
|
}
|
||||||
|
case fi.reverse:
|
||||||
|
mmi = fi.reverseFieldInfo.mi
|
||||||
|
if fi.reverseFieldInfo.fieldType == RelManyToMany {
|
||||||
|
mmi = fi.reverseFieldInfo.relThroughModelInfo
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jt, _ := d.add(names, mmi, fi, fi.null == false)
|
||||||
|
jt.jtl = jtl
|
||||||
|
jtl = jt
|
||||||
|
|
||||||
|
if fi.rel && fi.fieldType == RelManyToMany {
|
||||||
|
ex = fi.relModelInfo.name
|
||||||
|
goto check
|
||||||
|
}
|
||||||
|
|
||||||
|
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
|
||||||
|
ex = fi.reverseFieldInfo.mi.name
|
||||||
|
goto check
|
||||||
|
}
|
||||||
|
|
||||||
|
exist = true
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
if ffi == nil {
|
||||||
|
index = "T0"
|
||||||
|
} else {
|
||||||
|
index = jtl.index
|
||||||
|
}
|
||||||
|
column = fi.column
|
||||||
|
info = fi
|
||||||
|
if jtl != nil {
|
||||||
|
name = jtl.name + ExprSep + fi.name
|
||||||
|
} else {
|
||||||
|
name = fi.name
|
||||||
|
}
|
||||||
|
|
||||||
|
switch fi.fieldType {
|
||||||
|
case RelManyToMany, RelReverseMany:
|
||||||
|
default:
|
||||||
|
exist = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ffi = fi
|
||||||
|
}
|
||||||
|
|
||||||
|
if exist == false {
|
||||||
|
index = ""
|
||||||
|
column = ""
|
||||||
|
name = ""
|
||||||
|
success = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
success = index != "" && column != ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
|
||||||
|
if cond == nil || cond.IsEmpty() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Q := d.base.TableQuote()
|
||||||
|
|
||||||
|
mi := d.mi
|
||||||
|
|
||||||
|
// outFor:
|
||||||
|
for i, p := range cond.params {
|
||||||
|
if i > 0 {
|
||||||
|
if p.isOr {
|
||||||
|
where += "OR "
|
||||||
|
} else {
|
||||||
|
where += "AND "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.isNot {
|
||||||
|
where += "NOT "
|
||||||
|
}
|
||||||
|
if p.isCond {
|
||||||
|
w, ps := d.getCondSql(p.cond, true)
|
||||||
|
if w != "" {
|
||||||
|
w = fmt.Sprintf("( %s) ", w)
|
||||||
|
}
|
||||||
|
where += w
|
||||||
|
params = append(params, ps...)
|
||||||
|
} else {
|
||||||
|
exprs := p.exprs
|
||||||
|
|
||||||
|
num := len(exprs) - 1
|
||||||
|
operator := ""
|
||||||
|
if operators[exprs[num]] {
|
||||||
|
operator = exprs[num]
|
||||||
|
exprs = exprs[:num]
|
||||||
|
}
|
||||||
|
|
||||||
|
index, column, _, _, suc := d.parseExprs(mi, exprs)
|
||||||
|
if suc == false {
|
||||||
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if operator == "" {
|
||||||
|
operator = "exact"
|
||||||
|
}
|
||||||
|
|
||||||
|
operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args)
|
||||||
|
|
||||||
|
where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql)
|
||||||
|
params = append(params, args...)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub == false && where != "" {
|
||||||
|
where = "WHERE " + where
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
|
||||||
|
if len(orders) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Q := d.base.TableQuote()
|
||||||
|
|
||||||
|
orderSqls := make([]string, 0, len(orders))
|
||||||
|
for _, order := range orders {
|
||||||
|
asc := "ASC"
|
||||||
|
if order[0] == '-' {
|
||||||
|
asc = "DESC"
|
||||||
|
order = order[1:]
|
||||||
|
}
|
||||||
|
exprs := strings.Split(order, ExprSep)
|
||||||
|
|
||||||
|
index, column, _, _, suc := d.parseExprs(d.mi, exprs)
|
||||||
|
if suc == false {
|
||||||
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
||||||
|
}
|
||||||
|
|
||||||
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc))
|
||||||
|
}
|
||||||
|
|
||||||
|
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits string) {
|
||||||
|
if limit == 0 {
|
||||||
|
limit = DefaultRowsLimit
|
||||||
|
}
|
||||||
|
if limit < 0 {
|
||||||
|
// no limit
|
||||||
|
if offset > 0 {
|
||||||
|
maxLimit := d.base.MaxLimit()
|
||||||
|
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
|
||||||
|
}
|
||||||
|
} else if offset <= 0 {
|
||||||
|
limits = fmt.Sprintf("LIMIT %d", limit)
|
||||||
|
} else {
|
||||||
|
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
||||||
|
tables := &dbTables{}
|
||||||
|
tables.tablesM = make(map[string]*dbTable)
|
||||||
|
tables.mi = mi
|
||||||
|
tables.base = base
|
||||||
|
return tables
|
||||||
|
}
|
@ -79,7 +79,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
|
|||||||
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
|
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
|
||||||
info = new(modelInfo)
|
info = new(modelInfo)
|
||||||
info.fields = newFields()
|
info.fields = newFields()
|
||||||
info.table = m1.table + "_" + m2.table + "_rel"
|
info.table = m1.table + "_" + m2.table + "s"
|
||||||
info.name = camelString(info.table)
|
info.name = camelString(info.table)
|
||||||
info.fullName = m1.pkg + "." + info.name
|
info.fullName = m1.pkg + "." + info.name
|
||||||
|
|
||||||
|
@ -3,10 +3,11 @@ package orm
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/bmizerany/pq"
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -95,8 +96,178 @@ var DBARGS = struct {
|
|||||||
os.Getenv("ORM_DEBUG"),
|
os.Getenv("ORM_DEBUG"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
IsMysql = DBARGS.Driver == "mysql"
|
||||||
|
IsSqlite = DBARGS.Driver == "sqlite3"
|
||||||
|
IsPostgres = DBARGS.Driver == "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
var dORM Ormer
|
var dORM Ormer
|
||||||
|
|
||||||
|
var initSQLs = map[string]string{
|
||||||
|
"mysql": "DROP TABLE IF EXISTS `user_profile`;\n" +
|
||||||
|
"DROP TABLE IF EXISTS `user`;\n" +
|
||||||
|
"DROP TABLE IF EXISTS `post`;\n" +
|
||||||
|
"DROP TABLE IF EXISTS `tag`;\n" +
|
||||||
|
"DROP TABLE IF EXISTS `post_tags`;\n" +
|
||||||
|
"DROP TABLE IF EXISTS `comment`;\n" +
|
||||||
|
"CREATE TABLE `user_profile` (\n" +
|
||||||
|
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
|
||||||
|
" `age` smallint NOT NULL,\n" +
|
||||||
|
" `money` double precision NOT NULL\n" +
|
||||||
|
") ENGINE=INNODB;\n" +
|
||||||
|
"CREATE TABLE `user` (\n" +
|
||||||
|
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
|
||||||
|
" `user_name` varchar(30) NOT NULL UNIQUE,\n" +
|
||||||
|
" `email` varchar(100) NOT NULL,\n" +
|
||||||
|
" `password` varchar(100) NOT NULL,\n" +
|
||||||
|
" `status` smallint NOT NULL,\n" +
|
||||||
|
" `is_staff` bool NOT NULL,\n" +
|
||||||
|
" `is_active` bool NOT NULL,\n" +
|
||||||
|
" `created` date NOT NULL,\n" +
|
||||||
|
" `updated` datetime NOT NULL,\n" +
|
||||||
|
" `profile_id` integer\n" +
|
||||||
|
") ENGINE=INNODB;\n" +
|
||||||
|
"CREATE TABLE `post` (\n" +
|
||||||
|
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
|
||||||
|
" `user_id` integer NOT NULL,\n" +
|
||||||
|
" `title` varchar(60) NOT NULL,\n" +
|
||||||
|
" `content` longtext NOT NULL,\n" +
|
||||||
|
" `created` datetime NOT NULL,\n" +
|
||||||
|
" `updated` datetime NOT NULL\n" +
|
||||||
|
") ENGINE=INNODB;\n" +
|
||||||
|
"CREATE TABLE `tag` (\n" +
|
||||||
|
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
|
||||||
|
" `name` varchar(30) NOT NULL\n" +
|
||||||
|
") ENGINE=INNODB;\n" +
|
||||||
|
"CREATE TABLE `post_tags` (\n" +
|
||||||
|
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
|
||||||
|
" `post_id` integer NOT NULL,\n" +
|
||||||
|
" `tag_id` integer NOT NULL,\n" +
|
||||||
|
" UNIQUE (`post_id`, `tag_id`)\n" +
|
||||||
|
") ENGINE=INNODB;\n" +
|
||||||
|
"CREATE TABLE `comment` (\n" +
|
||||||
|
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
|
||||||
|
" `post_id` integer NOT NULL,\n" +
|
||||||
|
" `content` longtext NOT NULL,\n" +
|
||||||
|
" `parent_id` integer,\n" +
|
||||||
|
" `created` datetime NOT NULL\n" +
|
||||||
|
") ENGINE=INNODB;\n" +
|
||||||
|
"CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" +
|
||||||
|
"CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" +
|
||||||
|
"CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" +
|
||||||
|
"CREATE INDEX `comment_63f17a16` ON `comment` (`parent_id`);",
|
||||||
|
|
||||||
|
"sqlite3": `
|
||||||
|
DROP TABLE IF EXISTS "user_profile";
|
||||||
|
DROP TABLE IF EXISTS "user";
|
||||||
|
DROP TABLE IF EXISTS "post";
|
||||||
|
DROP TABLE IF EXISTS "tag";
|
||||||
|
DROP TABLE IF EXISTS "post_tags";
|
||||||
|
DROP TABLE IF EXISTS "comment";
|
||||||
|
CREATE TABLE "user_profile" (
|
||||||
|
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
"age" smallint NOT NULL,
|
||||||
|
"money" real NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE "user" (
|
||||||
|
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
"user_name" varchar(30) NOT NULL UNIQUE,
|
||||||
|
"email" varchar(100) NOT NULL,
|
||||||
|
"password" varchar(100) NOT NULL,
|
||||||
|
"status" smallint NOT NULL,
|
||||||
|
"is_staff" bool NOT NULL,
|
||||||
|
"is_active" bool NOT NULL,
|
||||||
|
"created" date NOT NULL,
|
||||||
|
"updated" datetime NOT NULL,
|
||||||
|
"profile_id" integer
|
||||||
|
);
|
||||||
|
CREATE TABLE "post" (
|
||||||
|
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
"user_id" integer NOT NULL,
|
||||||
|
"title" varchar(60) NOT NULL,
|
||||||
|
"content" text NOT NULL,
|
||||||
|
"created" datetime NOT NULL,
|
||||||
|
"updated" datetime NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE "tag" (
|
||||||
|
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
"name" varchar(30) NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE "post_tags" (
|
||||||
|
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
"post_id" integer NOT NULL,
|
||||||
|
"tag_id" integer NOT NULL,
|
||||||
|
UNIQUE ("post_id", "tag_id")
|
||||||
|
);
|
||||||
|
CREATE TABLE "comment" (
|
||||||
|
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
"post_id" integer NOT NULL,
|
||||||
|
"content" text NOT NULL,
|
||||||
|
"parent_id" integer,
|
||||||
|
"created" datetime NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX "user_141c6eec" ON "user" ("profile_id");
|
||||||
|
CREATE INDEX "post_fbfc09f1" ON "post" ("user_id");
|
||||||
|
CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id");
|
||||||
|
CREATE INDEX "comment_63f17a16" ON "comment" ("parent_id");
|
||||||
|
`,
|
||||||
|
|
||||||
|
"postgres": `
|
||||||
|
DROP TABLE IF EXISTS "user_profile";
|
||||||
|
DROP TABLE IF EXISTS "user";
|
||||||
|
DROP TABLE IF EXISTS "post";
|
||||||
|
DROP TABLE IF EXISTS "tag";
|
||||||
|
DROP TABLE IF EXISTS "post_tags";
|
||||||
|
DROP TABLE IF EXISTS "comment";
|
||||||
|
CREATE TABLE "user_profile" (
|
||||||
|
"id" serial NOT NULL PRIMARY KEY,
|
||||||
|
"age" smallint NOT NULL,
|
||||||
|
"money" double precision NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE "user" (
|
||||||
|
"id" serial NOT NULL PRIMARY KEY,
|
||||||
|
"user_name" varchar(30) NOT NULL UNIQUE,
|
||||||
|
"email" varchar(100) NOT NULL,
|
||||||
|
"password" varchar(100) NOT NULL,
|
||||||
|
"status" smallint NOT NULL,
|
||||||
|
"is_staff" boolean NOT NULL,
|
||||||
|
"is_active" boolean NOT NULL,
|
||||||
|
"created" date NOT NULL,
|
||||||
|
"updated" timestamp with time zone NOT NULL,
|
||||||
|
"profile_id" integer
|
||||||
|
);
|
||||||
|
CREATE TABLE "post" (
|
||||||
|
"id" serial NOT NULL PRIMARY KEY,
|
||||||
|
"user_id" integer NOT NULL,
|
||||||
|
"title" varchar(60) NOT NULL,
|
||||||
|
"content" text NOT NULL,
|
||||||
|
"created" timestamp with time zone NOT NULL,
|
||||||
|
"updated" timestamp with time zone NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE "tag" (
|
||||||
|
"id" serial NOT NULL PRIMARY KEY,
|
||||||
|
"name" varchar(30) NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE "post_tags" (
|
||||||
|
"id" serial NOT NULL PRIMARY KEY,
|
||||||
|
"post_id" integer NOT NULL,
|
||||||
|
"tag_id" integer NOT NULL,
|
||||||
|
UNIQUE ("post_id", "tag_id")
|
||||||
|
);
|
||||||
|
CREATE TABLE "comment" (
|
||||||
|
"id" serial NOT NULL PRIMARY KEY,
|
||||||
|
"post_id" integer NOT NULL,
|
||||||
|
"content" text NOT NULL,
|
||||||
|
"parent_id" integer,
|
||||||
|
"created" timestamp with time zone NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX "user_profile_id" ON "user" ("profile_id");
|
||||||
|
CREATE INDEX "post_user_id" ON "post" ("user_id");
|
||||||
|
CREATE INDEX "comment_post_id" ON "comment" ("post_id");
|
||||||
|
CREATE INDEX "comment_parent_id" ON "comment" ("parent_id");
|
||||||
|
`}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RegisterModel(new(User))
|
RegisterModel(new(User))
|
||||||
RegisterModel(new(Profile))
|
RegisterModel(new(Profile))
|
||||||
@ -114,7 +285,7 @@ Default DB Drivers.
|
|||||||
driver: url
|
driver: url
|
||||||
mysql: https://github.com/go-sql-driver/mysql
|
mysql: https://github.com/go-sql-driver/mysql
|
||||||
sqlite3: https://github.com/mattn/go-sqlite3
|
sqlite3: https://github.com/mattn/go-sqlite3
|
||||||
postgres: https://github.com/bmizerany/pq
|
postgres: https://github.com/lib/pq
|
||||||
|
|
||||||
eg: mysql
|
eg: mysql
|
||||||
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
|
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
|
||||||
@ -126,20 +297,16 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a
|
|||||||
|
|
||||||
BootStrap()
|
BootStrap()
|
||||||
|
|
||||||
truncateTables()
|
|
||||||
|
|
||||||
dORM = NewOrm()
|
dORM = NewOrm()
|
||||||
}
|
|
||||||
|
|
||||||
func truncateTables() {
|
queries := strings.Split(initSQLs[DBARGS.Driver], ";")
|
||||||
logs := "truncate tables for test\n"
|
|
||||||
o := NewOrm()
|
for _, query := range queries {
|
||||||
for _, m := range modelCache.allOrdered() {
|
if strings.TrimSpace(query) == "" {
|
||||||
query := fmt.Sprintf("truncate table `%s`", m.table)
|
continue
|
||||||
_, err := o.Raw(query).Exec()
|
}
|
||||||
logs += query + "\n"
|
_, err := dORM.Raw(query).Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(logs)
|
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
@ -135,7 +135,7 @@ func (d *dbQueryLog) Commit() error {
|
|||||||
|
|
||||||
func (d *dbQueryLog) Rollback() error {
|
func (d *dbQueryLog) Rollback() error {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
err := d.db.(txEnder).Commit()
|
err := d.db.(txEnder).Rollback()
|
||||||
debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
|
debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -6,39 +6,17 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getResult(res sql.Result) (int64, error) {
|
|
||||||
if num, err := res.LastInsertId(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
} else {
|
|
||||||
if num > 0 {
|
|
||||||
return num, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if num, err := res.RowsAffected(); err != nil {
|
|
||||||
return num, err
|
|
||||||
} else {
|
|
||||||
if num > 0 {
|
|
||||||
return num, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type rawPrepare struct {
|
type rawPrepare struct {
|
||||||
rs *rawSet
|
rs *rawSet
|
||||||
stmt stmtQuerier
|
stmt stmtQuerier
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawPrepare) Exec(args ...interface{}) (int64, error) {
|
func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) {
|
||||||
if o.closed {
|
if o.closed {
|
||||||
return 0, ErrStmtClosed
|
return nil, ErrStmtClosed
|
||||||
}
|
}
|
||||||
res, err := o.stmt.Exec(args...)
|
return o.stmt.Exec(args...)
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return getResult(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawPrepare) Close() error {
|
func (o *rawPrepare) Close() error {
|
||||||
@ -74,12 +52,8 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter {
|
|||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) Exec() (int64, error) {
|
func (o *rawSet) Exec() (sql.Result, error) {
|
||||||
res, err := o.orm.db.Exec(o.query, o.args...)
|
return o.orm.db.Exec(o.query, o.args...)
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return getResult(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) QueryRow(...interface{}) error {
|
func (o *rawSet) QueryRow(...interface{}) error {
|
||||||
|
219
orm/orm_test.go
219
orm/orm_test.go
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -12,6 +13,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ = os.PathSeparator
|
||||||
|
|
||||||
type T_Code int
|
type T_Code int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -60,9 +63,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e
|
|||||||
ok = is && ok || !is && !ok
|
ok = is && ok || !is && !ok
|
||||||
if !ok {
|
if !ok {
|
||||||
if is {
|
if is {
|
||||||
err = fmt.Errorf("should: a == b, a = `%v`, b = `%v`", a, b)
|
err = fmt.Errorf("expected: a == `%v`, get `%v`", b, a)
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("should: a != b, a = `%v`, b = `%v`", a, b)
|
err = fmt.Errorf("expected: a != `%v`, get `%v`", b, a)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case T_Less, T_Large:
|
case T_Less, T_Large:
|
||||||
@ -89,9 +92,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e
|
|||||||
ok = is && ok || !is && !ok
|
ok = is && ok || !is && !ok
|
||||||
if !ok {
|
if !ok {
|
||||||
if is {
|
if is {
|
||||||
err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[0], f1, f2)
|
err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[0], f1, f2)
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[1], f1, f2)
|
err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[1], f1, f2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -122,32 +125,51 @@ func getCaller(skip int) string {
|
|||||||
fun := runtime.FuncForPC(pc)
|
fun := runtime.FuncForPC(pc)
|
||||||
_, fn := filepath.Split(file)
|
_, fn := filepath.Split(file)
|
||||||
data, err := ioutil.ReadFile(file)
|
data, err := ioutil.ReadFile(file)
|
||||||
code := ""
|
var codes []string
|
||||||
if err == nil {
|
if err == nil {
|
||||||
lines := bytes.Split(data, []byte{'\n'})
|
lines := bytes.Split(data, []byte{'\n'})
|
||||||
code = strings.TrimSpace(string(lines[line-1]))
|
n := 10
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
o := line - n
|
||||||
|
if o < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cur := o + i + 1
|
||||||
|
flag := " "
|
||||||
|
if cur == line {
|
||||||
|
flag = ">>"
|
||||||
|
}
|
||||||
|
code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.TrimSpace(string(lines[o+i])))
|
||||||
|
if code != "" {
|
||||||
|
codes = append(codes, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
funName := fun.Name()
|
funName := fun.Name()
|
||||||
if i := strings.LastIndex(funName, "."); i > -1 {
|
if i := strings.LastIndex(funName, "."); i > -1 {
|
||||||
funName = funName[i+1:]
|
funName = funName[i+1:]
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s:%d: %s: %s", fn, line, funName, code)
|
return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func throwFail(t *testing.T, err error, args ...interface{}) {
|
func throwFail(t *testing.T, err error, args ...interface{}) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
params := []interface{}{"\n", getCaller(2), "\n", err, "\n"}
|
con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
|
||||||
params = append(params, args...)
|
if len(args) > 0 {
|
||||||
t.Error(params...)
|
con += fmt.Sprint(args...)
|
||||||
|
}
|
||||||
|
t.Error(con)
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func throwFailNow(t *testing.T, err error, args ...interface{}) {
|
func throwFailNow(t *testing.T, err error, args ...interface{}) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
params := []interface{}{"\n", getCaller(2), "\n", err, "\n"}
|
con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
|
||||||
params = append(params, args...)
|
if len(args) > 0 {
|
||||||
t.Error(params...)
|
con += fmt.Sprint(args...)
|
||||||
|
}
|
||||||
|
t.Error(con)
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -165,8 +187,8 @@ func TestCRUD(t *testing.T) {
|
|||||||
profile.Age = 30
|
profile.Age = 30
|
||||||
profile.Money = 1234.12
|
profile.Money = 1234.12
|
||||||
id, err := dORM.Insert(profile)
|
id, err := dORM.Insert(profile)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 1))
|
||||||
|
|
||||||
user := NewUser()
|
user := NewUser()
|
||||||
user.UserName = "slene"
|
user.UserName = "slene"
|
||||||
@ -177,51 +199,53 @@ func TestCRUD(t *testing.T) {
|
|||||||
user.IsActive = true
|
user.IsActive = true
|
||||||
|
|
||||||
id, err = dORM.Insert(user)
|
id, err = dORM.Insert(user)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 1))
|
||||||
|
|
||||||
u := &User{Id: user.Id}
|
u := &User{Id: user.Id}
|
||||||
err = dORM.Read(u)
|
err = dORM.Read(u)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
throwFailNow(t, AssertIs(u.UserName, T_Equal, "slene"))
|
throwFail(t, AssertIs(u.UserName, T_Equal, "slene"))
|
||||||
throwFailNow(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com"))
|
throwFail(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com"))
|
||||||
throwFailNow(t, AssertIs(u.Password, T_Equal, "pass"))
|
throwFail(t, AssertIs(u.Password, T_Equal, "pass"))
|
||||||
throwFailNow(t, AssertIs(u.Status, T_Equal, 3))
|
throwFail(t, AssertIs(u.Status, T_Equal, 3))
|
||||||
throwFailNow(t, AssertIs(u.IsStaff, T_Equal, true))
|
throwFail(t, AssertIs(u.IsStaff, T_Equal, true))
|
||||||
throwFailNow(t, AssertIs(u.IsActive, T_Equal, true))
|
throwFail(t, AssertIs(u.IsActive, T_Equal, true))
|
||||||
throwFailNow(t, AssertIs(u.Created, T_Equal, user.Created, format_Date))
|
throwFail(t, AssertIs(u.Created, T_Equal, user.Created, format_Date))
|
||||||
throwFailNow(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime))
|
throwFail(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime))
|
||||||
|
|
||||||
user.UserName = "astaxie"
|
user.UserName = "astaxie"
|
||||||
user.Profile = profile
|
user.Profile = profile
|
||||||
num, err := dORM.Update(user)
|
num, err := dORM.Update(user)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
u = &User{Id: user.Id}
|
u = &User{Id: user.Id}
|
||||||
err = dORM.Read(u)
|
err = dORM.Read(u)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
throwFailNow(t, AssertIs(u.UserName, T_Equal, "astaxie"))
|
if err == nil {
|
||||||
throwFailNow(t, AssertIs(u.Profile.Id, T_Equal, profile.Id))
|
throwFail(t, AssertIs(u.UserName, T_Equal, "astaxie"))
|
||||||
|
throwFail(t, AssertIs(u.Profile.Id, T_Equal, profile.Id))
|
||||||
|
}
|
||||||
|
|
||||||
num, err = dORM.Delete(profile)
|
num, err = dORM.Delete(profile)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
u = &User{Id: user.Id}
|
u = &User{Id: user.Id}
|
||||||
err = dORM.Read(u)
|
err = dORM.Read(u)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(true, T_Equal, u.Profile == nil))
|
throwFail(t, AssertIs(true, T_Equal, u.Profile == nil))
|
||||||
|
|
||||||
num, err = dORM.Delete(user)
|
num, err = dORM.Delete(user)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
u = &User{Id: 100}
|
u = &User{Id: 100}
|
||||||
err = dORM.Read(u)
|
err = dORM.Read(u)
|
||||||
throwFailNow(t, AssertIs(err, T_Equal, ErrNoRows))
|
throwFail(t, AssertIs(err, T_Equal, ErrNoRows))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertTestData(t *testing.T) {
|
func TestInsertTestData(t *testing.T) {
|
||||||
@ -232,8 +256,8 @@ func TestInsertTestData(t *testing.T) {
|
|||||||
profile.Money = 1234.12
|
profile.Money = 1234.12
|
||||||
|
|
||||||
id, err := dORM.Insert(profile)
|
id, err := dORM.Insert(profile)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 2))
|
||||||
|
|
||||||
user := NewUser()
|
user := NewUser()
|
||||||
user.UserName = "slene"
|
user.UserName = "slene"
|
||||||
@ -247,16 +271,16 @@ func TestInsertTestData(t *testing.T) {
|
|||||||
users = append(users, user)
|
users = append(users, user)
|
||||||
|
|
||||||
id, err = dORM.Insert(user)
|
id, err = dORM.Insert(user)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 2))
|
||||||
|
|
||||||
profile = NewProfile()
|
profile = NewProfile()
|
||||||
profile.Age = 30
|
profile.Age = 30
|
||||||
profile.Money = 4321.09
|
profile.Money = 4321.09
|
||||||
|
|
||||||
id, err = dORM.Insert(profile)
|
id, err = dORM.Insert(profile)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 3))
|
||||||
|
|
||||||
user = NewUser()
|
user = NewUser()
|
||||||
user.UserName = "astaxie"
|
user.UserName = "astaxie"
|
||||||
@ -270,8 +294,8 @@ func TestInsertTestData(t *testing.T) {
|
|||||||
users = append(users, user)
|
users = append(users, user)
|
||||||
|
|
||||||
id, err = dORM.Insert(user)
|
id, err = dORM.Insert(user)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 3))
|
||||||
|
|
||||||
user = NewUser()
|
user = NewUser()
|
||||||
user.UserName = "nobody"
|
user.UserName = "nobody"
|
||||||
@ -284,8 +308,8 @@ func TestInsertTestData(t *testing.T) {
|
|||||||
users = append(users, user)
|
users = append(users, user)
|
||||||
|
|
||||||
id, err = dORM.Insert(user)
|
id, err = dORM.Insert(user)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Equal, 4))
|
||||||
|
|
||||||
tags := []*Tag{
|
tags := []*Tag{
|
||||||
&Tag{Name: "golang"},
|
&Tag{Name: "golang"},
|
||||||
@ -315,21 +339,21 @@ The program—and web server—godoc processes Go source files to extract docume
|
|||||||
|
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
id, err := dORM.Insert(tag)
|
id, err := dORM.Insert(tag)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Large, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, post := range posts {
|
for _, post := range posts {
|
||||||
id, err := dORM.Insert(post)
|
id, err := dORM.Insert(post)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Large, 0))
|
||||||
// dORM.M2mAdd(post, "tags", post.Tags)
|
// dORM.M2mAdd(post, "tags", post.Tags)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, comment := range comments {
|
for _, comment := range comments {
|
||||||
id, err := dORM.Insert(comment)
|
id, err := dORM.Insert(comment)
|
||||||
throwFailNow(t, err)
|
throwFail(t, err)
|
||||||
throwFailNow(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Large, 0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -359,9 +383,17 @@ func TestOperators(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 2))
|
throwFail(t, AssertIs(num, T_Equal, 2))
|
||||||
|
|
||||||
|
var shouldNum int
|
||||||
|
|
||||||
|
if IsSqlite {
|
||||||
|
shouldNum = 2
|
||||||
|
} else {
|
||||||
|
shouldNum = 0
|
||||||
|
}
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__contains", "E").Count()
|
num, err = qs.Filter("user_name__contains", "E").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 0))
|
throwFail(t, AssertIs(num, T_Equal, shouldNum))
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__icontains", "E").Count()
|
num, err = qs.Filter("user_name__icontains", "E").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -391,9 +423,15 @@ func TestOperators(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
|
if IsSqlite {
|
||||||
|
shouldNum = 1
|
||||||
|
} else {
|
||||||
|
shouldNum = 0
|
||||||
|
}
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__startswith", "S").Count()
|
num, err = qs.Filter("user_name__startswith", "S").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 0))
|
throwFail(t, AssertIs(num, T_Equal, shouldNum))
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__istartswith", "S").Count()
|
num, err = qs.Filter("user_name__istartswith", "S").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -403,9 +441,15 @@ func TestOperators(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 2))
|
throwFail(t, AssertIs(num, T_Equal, 2))
|
||||||
|
|
||||||
|
if IsSqlite {
|
||||||
|
shouldNum = 2
|
||||||
|
} else {
|
||||||
|
shouldNum = 0
|
||||||
|
}
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__endswith", "E").Count()
|
num, err = qs.Filter("user_name__endswith", "E").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 0))
|
throwFail(t, AssertIs(num, T_Equal, shouldNum))
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__iendswith", "E").Count()
|
num, err = qs.Filter("user_name__iendswith", "E").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -537,7 +581,6 @@ func TestRelatedSel(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
throwFail(t, AssertNot(user.Profile, T_Equal, nil))
|
throwFail(t, AssertNot(user.Profile, T_Equal, nil))
|
||||||
throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
|
|
||||||
if user.Profile != nil {
|
if user.Profile != nil {
|
||||||
throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
|
throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
|
||||||
}
|
}
|
||||||
@ -617,7 +660,7 @@ func TestOrderBy(t *testing.T) {
|
|||||||
func TestPrepareInsert(t *testing.T) {
|
func TestPrepareInsert(t *testing.T) {
|
||||||
qs := dORM.QueryTable("user")
|
qs := dORM.QueryTable("user")
|
||||||
i, err := qs.PrepareInsert()
|
i, err := qs.PrepareInsert()
|
||||||
throwFail(t, err)
|
throwFailNow(t, err)
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
user.UserName = "testing1"
|
user.UserName = "testing1"
|
||||||
@ -641,15 +684,18 @@ func TestPrepareInsert(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRaw(t *testing.T) {
|
func TestRaw(t *testing.T) {
|
||||||
switch dORM.Driver().Type() {
|
switch {
|
||||||
case DR_MySQL:
|
case IsMysql || IsSqlite:
|
||||||
num, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec()
|
|
||||||
throwFail(t, err)
|
|
||||||
throwFail(t, AssertIs(num, T_Equal, 1))
|
|
||||||
|
|
||||||
num, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec()
|
res, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 1))
|
num, err := res.RowsAffected()
|
||||||
|
throwFail(t, AssertIs(num, T_Equal, 1), err)
|
||||||
|
|
||||||
|
res, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec()
|
||||||
|
throwFail(t, err)
|
||||||
|
num, err = res.RowsAffected()
|
||||||
|
throwFail(t, AssertIs(num, T_Equal, 1), err)
|
||||||
|
|
||||||
var maps []Params
|
var maps []Params
|
||||||
num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps)
|
num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps)
|
||||||
@ -681,11 +727,18 @@ func TestRaw(t *testing.T) {
|
|||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
func TestUpdate(t *testing.T) {
|
||||||
qs := dORM.QueryTable("user")
|
qs := dORM.QueryTable("user")
|
||||||
num, err := qs.Filter("user_name", "slene").Update(Params{
|
num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{
|
||||||
"is_staff": true,
|
"is_staff": true,
|
||||||
})
|
})
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
|
// with join
|
||||||
|
num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{
|
||||||
|
"is_staff": false,
|
||||||
|
})
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDelete(t *testing.T) {
|
func TestDelete(t *testing.T) {
|
||||||
@ -701,48 +754,54 @@ func TestDelete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTransaction(t *testing.T) {
|
func TestTransaction(t *testing.T) {
|
||||||
|
// this test worked when database support transaction
|
||||||
|
|
||||||
o := NewOrm()
|
o := NewOrm()
|
||||||
err := o.Begin()
|
err := o.Begin()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
var names = []string{"1", "2", "3"}
|
var names = []string{"1", "2", "3"}
|
||||||
|
|
||||||
var user User
|
var tag Tag
|
||||||
user.UserName = names[0]
|
tag.Name = names[0]
|
||||||
id, err := o.Insert(&user)
|
id, err := o.Insert(&tag)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Large, 0))
|
||||||
|
|
||||||
num, err := o.QueryTable("user").Filter("user_name", "slene").Update(Params{"user_name": names[1]})
|
num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]})
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Large, 0))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
switch o.Driver().Type() {
|
switch {
|
||||||
case DR_MySQL:
|
case IsMysql || IsSqlite:
|
||||||
id, err := o.Raw("INSERT INTO user (user_name) VALUES (?)", names[2]).Exec()
|
res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(id, T_Large, 0))
|
if err == nil {
|
||||||
|
id, err = res.LastInsertId()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(id, T_Large, 0))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = o.Rollback()
|
err = o.Rollback()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
num, err = o.QueryTable("user").Filter("user_name__in", &user).Count()
|
num, err = o.QueryTable("tag").Filter("name__in", names).Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 0))
|
throwFail(t, AssertIs(num, T_Equal, 0))
|
||||||
|
|
||||||
err = o.Begin()
|
err = o.Begin()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
user.UserName = "commit"
|
tag.Name = "commit"
|
||||||
id, err = o.Insert(&user)
|
id, err = o.Insert(&tag)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(id, T_Large, 0))
|
throwFail(t, AssertIs(id, T_Large, 0))
|
||||||
|
|
||||||
o.Commit()
|
o.Commit()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
num, err = o.QueryTable("user").Filter("user_name", "commit").Delete()
|
num, err = o.QueryTable("tag").Filter("name", "commit").Delete()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, T_Equal, 1))
|
throwFail(t, AssertIs(num, T_Equal, 1))
|
||||||
|
|
||||||
|
11
orm/types.go
11
orm/types.go
@ -60,12 +60,12 @@ type QuerySeter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RawPreparer interface {
|
type RawPreparer interface {
|
||||||
Exec(...interface{}) (int64, error)
|
Exec(...interface{}) (sql.Result, error)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type RawSeter interface {
|
type RawSeter interface {
|
||||||
Exec() (int64, error)
|
Exec() (sql.Result, error)
|
||||||
QueryRow(...interface{}) error
|
QueryRow(...interface{}) error
|
||||||
QueryRows(...interface{}) (int64, error)
|
QueryRows(...interface{}) (int64, error)
|
||||||
SetArgs(...interface{}) RawSeter
|
SetArgs(...interface{}) RawSeter
|
||||||
@ -116,10 +116,15 @@ type dbBaser interface {
|
|||||||
Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
|
Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
|
||||||
Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
|
Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
|
||||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
|
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
|
||||||
|
SupportUpdateJoin() bool
|
||||||
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
|
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
|
||||||
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
|
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
|
||||||
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
|
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
|
||||||
GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
|
OperatorSql(string) string
|
||||||
|
GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
|
||||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
|
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
|
||||||
|
MaxLimit() uint64
|
||||||
|
TableQuote() string
|
||||||
|
ReplaceMarks(*string)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user