diff --git a/orm/db.go b/orm/db.go index 8f47e235..69908c57 100644 --- a/orm/db.go +++ b/orm/db.go @@ -43,395 +43,8 @@ var ( "isnull": true, // "search": true, } - operatorsSQL = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ?", - "contains": "LIKE BINARY ?", - "icontains": "LIKE ?", - // "regex": "REGEXP BINARY ?", - // "iregex": "REGEXP ?", - "gt": "> ?", - "gte": ">= ?", - "lt": "< ?", - "lte": "<= ?", - "startswith": "LIKE BINARY ?", - "endswith": "LIKE BINARY ?", - "istartswith": "LIKE ?", - "iendswith": "LIKE ?", - } ) -type dbTable struct { - id int - index string - name string - names []string - sel bool - inner bool - mi *modelInfo - fi *fieldInfo - jtl *dbTable -} - -type dbTables struct { - tablesM map[string]*dbTable - tables []*dbTable - mi *modelInfo - base dbBaser -} - -func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { - name := strings.Join(names, ExprSep) - if j, ok := t.tablesM[name]; ok { - j.name = name - j.mi = mi - j.fi = fi - j.inner = inner - } else { - i := len(t.tables) + 1 - jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} - t.tablesM[name] = jt - t.tables = append(t.tables, jt) - } - return t.tablesM[name] -} - -func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { - name := strings.Join(names, ExprSep) - if _, ok := t.tablesM[name]; ok == false { - i := len(t.tables) + 1 - jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} - t.tablesM[name] = jt - t.tables = append(t.tables, jt) - return jt, true - } - return t.tablesM[name], false -} - -func (t *dbTables) get(name string) (*dbTable, bool) { - j, ok := t.tablesM[name] - return j, ok -} - -func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { - if depth < 0 || fi.fieldType == RelManyToMany { - return related - } - - if prefix == "" { - prefix = fi.name - } else { - prefix = prefix + ExprSep + fi.name - } - related = append(related, prefix) - - depth-- - for _, fi := range fi.relModelInfo.fields.fieldsRel { - related = t.loopDepth(depth, prefix, fi, related) - } - - return related -} - -func (t *dbTables) parseRelated(rels []string, depth int) { - - relsNum := len(rels) - related := make([]string, relsNum) - copy(related, rels) - - relDepth := depth - - if relsNum != 0 { - relDepth = 0 - } - - relDepth-- - for _, fi := range t.mi.fields.fieldsRel { - related = t.loopDepth(relDepth, "", fi, related) - } - - for i, s := range related { - var ( - exs = strings.Split(s, ExprSep) - names = make([]string, 0, len(exs)) - mmi = t.mi - cansel = true - jtl *dbTable - ) - for _, ex := range exs { - if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { - names = append(names, fi.name) - mmi = fi.relModelInfo - - jt := t.set(names, mmi, fi, fi.null == false) - jt.jtl = jtl - - if fi.reverse { - cansel = false - } - - if cansel { - jt.sel = depth > 0 - - if i < relsNum { - jt.sel = true - } - } - - jtl = jt - - } else { - panic(fmt.Sprintf("unknown model/table name `%s`", ex)) - } - } - } -} - -func (t *dbTables) getJoinSql() (join string) { - for _, jt := range t.tables { - if jt.inner { - join += "INNER JOIN " - } else { - join += "LEFT OUTER JOIN " - } - var ( - table string - t1, t2 string - c1, c2 string - ) - t1 = "T0" - if jt.jtl != nil { - t1 = jt.jtl.index - } - t2 = jt.index - table = jt.mi.table - - switch { - case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: - c1 = jt.fi.mi.fields.pk.column - for _, ffi := range jt.mi.fields.fieldsRel { - if jt.fi.mi == ffi.relModelInfo { - c2 = ffi.column - break - } - } - default: - c1 = jt.fi.column - c2 = jt.fi.relModelInfo.fields.pk.column - - if jt.fi.reverse { - c1 = jt.mi.fields.pk.column - c2 = jt.fi.reverseFieldInfo.column - } - } - - join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2, - t2, c2, t1, c1) - } - return -} - -func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) { - var ( - ffi *fieldInfo - jtl *dbTable - mmi = mi - ) - - num := len(exprs) - 1 - names := make([]string, 0) - - for i, ex := range exprs { - exist := false - - check: - fi, ok := mmi.fields.GetByAny(ex) - - if ok { - - if num != i { - names = append(names, fi.name) - - switch { - case fi.rel: - mmi = fi.relModelInfo - if fi.fieldType == RelManyToMany { - mmi = fi.relThroughModelInfo - } - case fi.reverse: - mmi = fi.reverseFieldInfo.mi - if fi.reverseFieldInfo.fieldType == RelManyToMany { - mmi = fi.reverseFieldInfo.relThroughModelInfo - } - default: - return - } - - jt, _ := d.add(names, mmi, fi, fi.null == false) - jt.jtl = jtl - jtl = jt - - if fi.rel && fi.fieldType == RelManyToMany { - ex = fi.relModelInfo.name - goto check - } - - if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany { - ex = fi.reverseFieldInfo.mi.name - goto check - } - - exist = true - - } else { - - if ffi == nil { - index = "T0" - } else { - index = jtl.index - } - column = fi.column - info = fi - if jtl != nil { - name = jtl.name + ExprSep + fi.name - } else { - name = fi.name - } - - switch fi.fieldType { - case RelManyToMany, RelReverseMany: - default: - exist = true - } - } - - ffi = fi - } - - if exist == false { - index = "" - column = "" - name = "" - success = false - return - } - } - - success = index != "" && column != "" - return -} - -func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) { - if cond == nil || cond.IsEmpty() { - return - } - - mi := d.mi - - // outFor: - for i, p := range cond.params { - if i > 0 { - if p.isOr { - where += "OR " - } else { - where += "AND " - } - } - if p.isNot { - where += "NOT " - } - if p.isCond { - w, ps := d.getCondSql(p.cond, true) - if w != "" { - w = fmt.Sprintf("( %s) ", w) - } - where += w - params = append(params, ps...) - } else { - exprs := p.exprs - - num := len(exprs) - 1 - operator := "" - if operators[exprs[num]] { - operator = exprs[num] - exprs = exprs[:num] - } - - index, column, _, _, suc := d.parseExprs(mi, exprs) - if suc == false { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) - } - - if operator == "" { - operator = "exact" - } - - operSql, args := d.base.GetOperatorSql(mi, operator, p.args) - - where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql) - params = append(params, args...) - - } - } - - if sub == false && where != "" { - where = "WHERE " + where - } - - return -} - -func (d *dbTables) getOrderSql(orders []string) (orderSql string) { - if len(orders) == 0 { - return - } - - orderSqls := make([]string, 0, len(orders)) - for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) - - index, column, _, _, suc := d.parseExprs(d.mi, exprs) - if suc == false { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } - - orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc)) - } - - orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) - return -} - -func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) { - if limit == 0 { - limit = DefaultRowsLimit - } - if limit < 0 { - // no limit - if offset > 0 { - limits = fmt.Sprintf("LIMIT 18446744073709551615 OFFSET %d", offset) - } - } else if offset <= 0 { - limits = fmt.Sprintf("LIMIT %d", limit) - } else { - limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) - } - return -} - -func newDbTables(mi *modelInfo, base dbBaser) *dbTables { - tables := &dbTables{} - tables.tablesM = make(map[string]*dbTable) - tables.mi = mi - tables.base = base - return tables -} - type dbBase struct { ins dbBaser } @@ -528,6 +141,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, } func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { + Q := d.ins.TableQuote() + dbcols := make([]string, 0, len(mi.fields.dbcols)) marks := make([]string, 0, len(mi.fields.dbcols)) for _, fi := range mi.fields.fieldsDB { @@ -537,9 +152,13 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, } } qmarks := strings.Join(marks, ", ") - columns := strings.Join(dbcols, "`,`") + sep := fmt.Sprintf("%s, %s", Q, Q) + columns := strings.Join(dbcols, sep) + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) - query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) stmt, err := q.Prepare(query) return stmt, query, err } @@ -563,10 +182,13 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { return ErrMissPK } - sels := strings.Join(mi.fields.dbcols, "`, `") + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s, %s", Q, Q) + sels := strings.Join(mi.fields.dbcols, sep) colsNum := len(mi.fields.dbcols) - query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn) + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q) refs := make([]interface{}, colsNum) for i, _ := range refs { @@ -574,6 +196,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { refs[i] = &ref } + d.ins.ReplaceMarks(&query) + row := q.QueryRow(query, pkValue) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { @@ -598,14 +222,20 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e return 0, err } + Q := d.ins.TableQuote() + marks := make([]string, len(names)) for i, _ := range marks { marks[i] = "?" } - qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, "`,`") - query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) if res, err := q.Exec(query, values...); err == nil { return res.LastInsertId() @@ -624,12 +254,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e return 0, err } - setColumns := strings.Join(setNames, "` = ?, `") - - query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName) - setValues = append(setValues, pkValue) + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ?, %s", Q, Q) + setColumns := strings.Join(setNames, sep) + + query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) + + d.ins.ReplaceMarks(&query) + if res, err := q.Exec(query, setValues...); err == nil { return res.RowsAffected() } else { @@ -644,7 +279,11 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e return 0, ErrMissPK } - query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName) + Q := d.ins.TableQuote() + + query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q) + + d.ins.ReplaceMarks(&query) if res, err := q.Exec(query, pkValue); err == nil { @@ -694,11 +333,24 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con where, args := tables.getCondSql(cond, false) + values = append(values, args...) + join := tables.getJoinSql() - query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where) + var query string - values = append(values, args...) + Q := d.ins.TableQuote() + + if d.ins.SupportUpdateJoin() { + cols := strings.Join(columns, fmt.Sprintf("%s = ?, T0.%s", Q, Q)) + query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET T0.%s%s%s = ? %s", Q, mi.table, Q, join, Q, cols, Q, where) + } else { + cols := strings.Join(columns, fmt.Sprintf("%s = ?, %s", Q, Q)) + supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) + query = fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s IN ( %s )", Q, mi.table, Q, Q, cols, Q, Q, mi.fields.pk.column, Q, supQuery) + } + + d.ins.ReplaceMarks(&query) if res, err := q.Exec(query, values...); err == nil { return res.RowsAffected() @@ -744,11 +396,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con panic("delete operation cannot execute without condition") } + Q := d.ins.TableQuote() + where, args := tables.getCondSql(cond, false) join := tables.getJoinSql() - cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column) - query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where) + cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + + d.ins.ReplaceMarks(&query) var rs *sql.Rows if r, err := q.Query(query, args...); err != nil { @@ -773,8 +429,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } - sql, args := d.ins.GetOperatorSql(mi, "in", args) - query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql) + sql, args := d.ins.GenerateOperatorSql(mi, "in", args) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) + + d.ins.ReplaceMarks(&query) if res, err := q.Exec(query, args...); err == nil { num, err := res.RowsAffected() @@ -831,24 +489,30 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi offset = 0 } + Q := d.ins.TableQuote() + tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) where, args := tables.getCondSql(cond, false) orderBy := tables.getOrderSql(qs.orders) - limit := tables.getLimitSql(offset, rlimit) + limit := tables.getLimitSql(mi, offset, rlimit) join := tables.getJoinSql() colsNum := len(mi.fields.dbcols) - cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`")) + sep := fmt.Sprintf("%s, T0.%s", Q, Q) + cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q) for _, tbl := range tables.tables { if tbl.sel { colsNum += len(tbl.mi.fields.dbcols) - cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`")) + sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) + cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) } } - query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit) + + d.ins.ReplaceMarks(&query) var rs *sql.Rows if r, err := q.Query(query, args...); err != nil { @@ -940,7 +604,11 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition tables.getOrderSql(qs.orders) join := tables.getJoinSql() - query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where) + Q := d.ins.TableQuote() + + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s", Q, mi.table, Q, join, where) + + d.ins.ReplaceMarks(&query) row := q.QueryRow(query, args...) @@ -1014,7 +682,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params return } -func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { +func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { sql := "" params := d.getOperatorParams(operator, args) @@ -1028,7 +696,7 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface if len(params) > 1 { panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) } - sql = operatorsSQL[operator] + sql = d.ins.OperatorSql(operator) arg := params[0] switch operator { case "exact": @@ -1073,13 +741,13 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, value, err := d.getValue(fi, val) if err != nil { - panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) + panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error())) } _, err = d.setValue(fi, value, &field) if err != nil { - panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) + panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error())) } } } @@ -1090,6 +758,7 @@ func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) { } var value interface{} + var tErr error var str *StrTo switch v := val.(type) { @@ -1119,7 +788,8 @@ setValue: if str != nil { b, err := str.Bool() if err != nil { - return nil, err + tErr = err + goto end } value = b } @@ -1140,14 +810,23 @@ setValue: } } if str != nil { - format := format_DateTime + s := str.String() + var format string if fi.fieldType == TypeDateField { format = format_Date + if len(s) > 10 { + s = s[:10] + } + } else { + format = format_DateTime + if len(s) > 19 { + s = s[:19] + } } - s := str.String() t, err := timeParse(s, format) if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { - return nil, err + tErr = err + goto end } value = t } @@ -1173,7 +852,8 @@ setValue: _, err = str.Uint64() } if err != nil { - return nil, err + tErr = err + goto end } if fieldType&IsPostiveIntegerField > 0 { v, _ := str.Uint64() @@ -1196,15 +876,23 @@ setValue: if str != nil { v, err := str.Float64() if err != nil { - return nil, err + tErr = err + goto end } value = v } case fieldType&IsRelField > 0: - fieldType = fi.relModelInfo.fields.pk.fieldType + fi = fi.relModelInfo.fields.pk + fieldType = fi.fieldType goto setValue } +end: + if tErr != nil { + err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) + return nil, err + } + return value, nil } @@ -1275,6 +963,7 @@ setValue: fd := field.Addr().Interface().(Fielder) err := fd.SetRaw(value) if err != nil { + err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) return nil, err } } @@ -1311,6 +1000,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond hasExprs := len(exprs) > 0 + Q := d.ins.TableQuote() + if hasExprs { cols = make([]string, 0, len(exprs)) infos = make([]*fieldInfo, 0, len(exprs)) @@ -1319,26 +1010,26 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } - cols = append(cols, fmt.Sprintf("%s.`%s` `%s`", index, col, name)) + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q)) infos = append(infos, fi) } } else { cols = make([]string, 0, len(mi.fields.dbcols)) infos = make([]*fieldInfo, 0, len(exprs)) for _, fi := range mi.fields.fieldsDB { - cols = append(cols, fmt.Sprintf("T0.`%s` `%s`", fi.column, fi.name)) + cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) infos = append(infos, fi) } } where, args := tables.getCondSql(cond, false) orderBy := tables.getOrderSql(qs.orders) - limit := tables.getLimitSql(qs.offset, qs.limit) + limit := tables.getLimitSql(mi, qs.offset, qs.limit) join := tables.getJoinSql() sels := strings.Join(cols, ", ") - query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) var rs *sql.Rows if r, err := q.Query(query, args...); err != nil { @@ -1430,3 +1121,19 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } + +func (d *dbBase) SupportUpdateJoin() bool { + return true +} + +func (d *dbBase) MaxLimit() uint64 { + return 18446744073709551615 +} + +func (d *dbBase) TableQuote() string { + return "`" +} + +func (d *dbBase) ReplaceMarks(query *string) { + // default use `?` as mark, do nothing +} diff --git a/orm/db_mysql.go b/orm/db_mysql.go index c7cacd90..b6c3a118 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -1,11 +1,30 @@ package orm +var mysqlOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ?", + "contains": "LIKE BINARY ?", + "icontains": "LIKE ?", + // "regex": "REGEXP BINARY ?", + // "iregex": "REGEXP ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "startswith": "LIKE BINARY ?", + "endswith": "LIKE BINARY ?", + "istartswith": "LIKE ?", + "iendswith": "LIKE ?", +} + type dbBaseMysql struct { dbBase } -func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) { - return d.dbBase.GetOperatorSql(mi, operator, args) +var _ dbBaser = new(dbBaseMysql) + +func (d *dbBaseMysql) OperatorSql(operator string) string { + return mysqlOperators[operator] } func newdbBaseMysql() dbBaser { diff --git a/orm/db_oracle.go b/orm/db_oracle.go index b5a27cad..3769e07c 100644 --- a/orm/db_oracle.go +++ b/orm/db_oracle.go @@ -4,6 +4,12 @@ type dbBaseOracle struct { dbBase } +var _ dbBaser = new(dbBaseOracle) + +func (d *dbBase) OperatorSql(operator string) string { + return "" +} + func newdbBaseOracle() dbBaser { b := new(dbBaseOracle) b.ins = b diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 1a8a2e3a..58562036 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -1,9 +1,66 @@ package orm +import ( + "strconv" +) + +var postgresOperators = map[string]string{ + "exact": "= ?", + "iexact": "= UPPER(?)", + "contains": "LIKE ?", + "icontains": "LIKE UPPER(?)", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "startswith": "LIKE ?", + "endswith": "LIKE ?", + "istartswith": "LIKE UPPER(?)", + "iendswith": "LIKE UPPER(?)", +} + type dbBasePostgres struct { dbBase } +var _ dbBaser = new(dbBasePostgres) + +func (d *dbBasePostgres) OperatorSql(operator string) string { + return postgresOperators[operator] +} + +func (d *dbBasePostgres) TableQuote() string { + return `"` +} + +func (d *dbBasePostgres) ReplaceMarks(query *string) { + q := *query + num := 0 + for _, c := range q { + if c == '?' { + num += 1 + } + } + if num == 0 { + return + } + data := make([]byte, 0, len(q)+num) + num = 1 + for i := 0; i < len(q); i++ { + c := q[i] + if c == '?' { + data = append(data, '$') + data = append(data, []byte(strconv.Itoa(num))...) + num += 1 + } else { + data = append(data, c) + } + } + *query = string(data) +} + +// func (d *dbBasePostgres) + func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) b.ins = b diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index c3c0322e..e5a234e5 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -1,9 +1,38 @@ package orm +var sqliteOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ? ESCAPE '\\'", + "contains": "LIKE ? ESCAPE '\\'", + "icontains": "LIKE ? ESCAPE '\\'", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "startswith": "LIKE ? ESCAPE '\\'", + "endswith": "LIKE ? ESCAPE '\\'", + "istartswith": "LIKE ? ESCAPE '\\'", + "iendswith": "LIKE ? ESCAPE '\\'", +} + type dbBaseSqlite struct { dbBase } +var _ dbBaser = new(dbBaseSqlite) + +func (d *dbBaseSqlite) OperatorSql(operator string) string { + return sqliteOperators[operator] +} + +func (d *dbBaseSqlite) SupportUpdateJoin() bool { + return false +} + +func (d *dbBaseSqlite) MaxLimit() uint64 { + return 9223372036854775807 +} + func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) b.ins = b diff --git a/orm/db_tables.go b/orm/db_tables.go new file mode 100644 index 00000000..009ea58f --- /dev/null +++ b/orm/db_tables.go @@ -0,0 +1,384 @@ +package orm + +import ( + "fmt" + "strings" +) + +type dbTable struct { + id int + index string + name string + names []string + sel bool + inner bool + mi *modelInfo + fi *fieldInfo + jtl *dbTable +} + +type dbTables struct { + tablesM map[string]*dbTable + tables []*dbTable + mi *modelInfo + base dbBaser +} + +func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { + name := strings.Join(names, ExprSep) + if j, ok := t.tablesM[name]; ok { + j.name = name + j.mi = mi + j.fi = fi + j.inner = inner + } else { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + } + return t.tablesM[name] +} + +func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { + name := strings.Join(names, ExprSep) + if _, ok := t.tablesM[name]; ok == false { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + return jt, true + } + return t.tablesM[name], false +} + +func (t *dbTables) get(name string) (*dbTable, bool) { + j, ok := t.tablesM[name] + return j, ok +} + +func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { + if depth < 0 || fi.fieldType == RelManyToMany { + return related + } + + if prefix == "" { + prefix = fi.name + } else { + prefix = prefix + ExprSep + fi.name + } + related = append(related, prefix) + + depth-- + for _, fi := range fi.relModelInfo.fields.fieldsRel { + related = t.loopDepth(depth, prefix, fi, related) + } + + return related +} + +func (t *dbTables) parseRelated(rels []string, depth int) { + + relsNum := len(rels) + related := make([]string, relsNum) + copy(related, rels) + + relDepth := depth + + if relsNum != 0 { + relDepth = 0 + } + + relDepth-- + for _, fi := range t.mi.fields.fieldsRel { + related = t.loopDepth(relDepth, "", fi, related) + } + + for i, s := range related { + var ( + exs = strings.Split(s, ExprSep) + names = make([]string, 0, len(exs)) + mmi = t.mi + cansel = true + jtl *dbTable + ) + for _, ex := range exs { + if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { + names = append(names, fi.name) + mmi = fi.relModelInfo + + jt := t.set(names, mmi, fi, fi.null == false) + jt.jtl = jtl + + if fi.reverse { + cansel = false + } + + if cansel { + jt.sel = depth > 0 + + if i < relsNum { + jt.sel = true + } + } + + jtl = jt + + } else { + panic(fmt.Sprintf("unknown model/table name `%s`", ex)) + } + } + } +} + +func (t *dbTables) getJoinSql() (join string) { + Q := t.base.TableQuote() + + for _, jt := range t.tables { + if jt.inner { + join += "INNER JOIN " + } else { + join += "LEFT OUTER JOIN " + } + var ( + table string + t1, t2 string + c1, c2 string + ) + t1 = "T0" + if jt.jtl != nil { + t1 = jt.jtl.index + } + t2 = jt.index + table = jt.mi.table + + switch { + case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: + c1 = jt.fi.mi.fields.pk.column + for _, ffi := range jt.mi.fields.fieldsRel { + if jt.fi.mi == ffi.relModelInfo { + c2 = ffi.column + break + } + } + default: + c1 = jt.fi.column + c2 = jt.fi.relModelInfo.fields.pk.column + + if jt.fi.reverse { + c1 = jt.mi.fields.pk.column + c2 = jt.fi.reverseFieldInfo.column + } + } + + join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, + t2, Q, c2, Q, t1, Q, c1, Q) + } + return +} + +func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) { + var ( + ffi *fieldInfo + jtl *dbTable + mmi = mi + ) + + num := len(exprs) - 1 + names := make([]string, 0) + + for i, ex := range exprs { + exist := false + + check: + fi, ok := mmi.fields.GetByAny(ex) + + if ok { + + if num != i { + names = append(names, fi.name) + + switch { + case fi.rel: + mmi = fi.relModelInfo + if fi.fieldType == RelManyToMany { + mmi = fi.relThroughModelInfo + } + case fi.reverse: + mmi = fi.reverseFieldInfo.mi + if fi.reverseFieldInfo.fieldType == RelManyToMany { + mmi = fi.reverseFieldInfo.relThroughModelInfo + } + default: + return + } + + jt, _ := d.add(names, mmi, fi, fi.null == false) + jt.jtl = jtl + jtl = jt + + if fi.rel && fi.fieldType == RelManyToMany { + ex = fi.relModelInfo.name + goto check + } + + if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany { + ex = fi.reverseFieldInfo.mi.name + goto check + } + + exist = true + + } else { + + if ffi == nil { + index = "T0" + } else { + index = jtl.index + } + column = fi.column + info = fi + if jtl != nil { + name = jtl.name + ExprSep + fi.name + } else { + name = fi.name + } + + switch fi.fieldType { + case RelManyToMany, RelReverseMany: + default: + exist = true + } + } + + ffi = fi + } + + if exist == false { + index = "" + column = "" + name = "" + success = false + return + } + } + + success = index != "" && column != "" + return +} + +func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) { + if cond == nil || cond.IsEmpty() { + return + } + + Q := d.base.TableQuote() + + mi := d.mi + + // outFor: + for i, p := range cond.params { + if i > 0 { + if p.isOr { + where += "OR " + } else { + where += "AND " + } + } + if p.isNot { + where += "NOT " + } + if p.isCond { + w, ps := d.getCondSql(p.cond, true) + if w != "" { + w = fmt.Sprintf("( %s) ", w) + } + where += w + params = append(params, ps...) + } else { + exprs := p.exprs + + num := len(exprs) - 1 + operator := "" + if operators[exprs[num]] { + operator = exprs[num] + exprs = exprs[:num] + } + + index, column, _, _, suc := d.parseExprs(mi, exprs) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) + } + + if operator == "" { + operator = "exact" + } + + operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args) + + where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql) + params = append(params, args...) + + } + } + + if sub == false && where != "" { + where = "WHERE " + where + } + + return +} + +func (d *dbTables) getOrderSql(orders []string) (orderSql string) { + if len(orders) == 0 { + return + } + + Q := d.base.TableQuote() + + orderSqls := make([]string, 0, len(orders)) + for _, order := range orders { + asc := "ASC" + if order[0] == '-' { + asc = "DESC" + order = order[1:] + } + exprs := strings.Split(order, ExprSep) + + index, column, _, _, suc := d.parseExprs(d.mi, exprs) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc)) + } + + orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) + return +} + +func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits string) { + if limit == 0 { + limit = DefaultRowsLimit + } + if limit < 0 { + // no limit + if offset > 0 { + maxLimit := d.base.MaxLimit() + limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) + } + } else if offset <= 0 { + limits = fmt.Sprintf("LIMIT %d", limit) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) + } + return +} + +func newDbTables(mi *modelInfo, base dbBaser) *dbTables { + tables := &dbTables{} + tables.tablesM = make(map[string]*dbTable) + tables.mi = mi + tables.base = base + return tables +} diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 6737ced0..dfbad42e 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -79,7 +79,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) { func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { info = new(modelInfo) info.fields = newFields() - info.table = m1.table + "_" + m2.table + "_rel" + info.table = m1.table + "_" + m2.table + "s" info.name = camelString(info.table) info.fullName = m1.pkg + "." + info.name diff --git a/orm/models_test.go b/orm/models_test.go index 78853fcb..fb46fb01 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -3,10 +3,11 @@ package orm import ( "fmt" "os" + "strings" "time" - _ "github.com/bmizerany/pq" _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -95,8 +96,178 @@ var DBARGS = struct { os.Getenv("ORM_DEBUG"), } +var ( + IsMysql = DBARGS.Driver == "mysql" + IsSqlite = DBARGS.Driver == "sqlite3" + IsPostgres = DBARGS.Driver == "postgres" +) + var dORM Ormer +var initSQLs = map[string]string{ + "mysql": "DROP TABLE IF EXISTS `user_profile`;\n" + + "DROP TABLE IF EXISTS `user`;\n" + + "DROP TABLE IF EXISTS `post`;\n" + + "DROP TABLE IF EXISTS `tag`;\n" + + "DROP TABLE IF EXISTS `post_tags`;\n" + + "DROP TABLE IF EXISTS `comment`;\n" + + "CREATE TABLE `user_profile` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `age` smallint NOT NULL,\n" + + " `money` double precision NOT NULL\n" + + ") ENGINE=INNODB;\n" + + "CREATE TABLE `user` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `user_name` varchar(30) NOT NULL UNIQUE,\n" + + " `email` varchar(100) NOT NULL,\n" + + " `password` varchar(100) NOT NULL,\n" + + " `status` smallint NOT NULL,\n" + + " `is_staff` bool NOT NULL,\n" + + " `is_active` bool NOT NULL,\n" + + " `created` date NOT NULL,\n" + + " `updated` datetime NOT NULL,\n" + + " `profile_id` integer\n" + + ") ENGINE=INNODB;\n" + + "CREATE TABLE `post` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `user_id` integer NOT NULL,\n" + + " `title` varchar(60) NOT NULL,\n" + + " `content` longtext NOT NULL,\n" + + " `created` datetime NOT NULL,\n" + + " `updated` datetime NOT NULL\n" + + ") ENGINE=INNODB;\n" + + "CREATE TABLE `tag` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `name` varchar(30) NOT NULL\n" + + ") ENGINE=INNODB;\n" + + "CREATE TABLE `post_tags` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `post_id` integer NOT NULL,\n" + + " `tag_id` integer NOT NULL,\n" + + " UNIQUE (`post_id`, `tag_id`)\n" + + ") ENGINE=INNODB;\n" + + "CREATE TABLE `comment` (\n" + + " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" + + " `post_id` integer NOT NULL,\n" + + " `content` longtext NOT NULL,\n" + + " `parent_id` integer,\n" + + " `created` datetime NOT NULL\n" + + ") ENGINE=INNODB;\n" + + "CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" + + "CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" + + "CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" + + "CREATE INDEX `comment_63f17a16` ON `comment` (`parent_id`);", + + "sqlite3": ` +DROP TABLE IF EXISTS "user_profile"; +DROP TABLE IF EXISTS "user"; +DROP TABLE IF EXISTS "post"; +DROP TABLE IF EXISTS "tag"; +DROP TABLE IF EXISTS "post_tags"; +DROP TABLE IF EXISTS "comment"; +CREATE TABLE "user_profile" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "age" smallint NOT NULL, + "money" real NOT NULL +); +CREATE TABLE "user" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "user_name" varchar(30) NOT NULL UNIQUE, + "email" varchar(100) NOT NULL, + "password" varchar(100) NOT NULL, + "status" smallint NOT NULL, + "is_staff" bool NOT NULL, + "is_active" bool NOT NULL, + "created" date NOT NULL, + "updated" datetime NOT NULL, + "profile_id" integer +); +CREATE TABLE "post" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "user_id" integer NOT NULL, + "title" varchar(60) NOT NULL, + "content" text NOT NULL, + "created" datetime NOT NULL, + "updated" datetime NOT NULL +); +CREATE TABLE "tag" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "name" varchar(30) NOT NULL +); +CREATE TABLE "post_tags" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "post_id" integer NOT NULL, + "tag_id" integer NOT NULL, + UNIQUE ("post_id", "tag_id") +); +CREATE TABLE "comment" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "post_id" integer NOT NULL, + "content" text NOT NULL, + "parent_id" integer, + "created" datetime NOT NULL +); +CREATE INDEX "user_141c6eec" ON "user" ("profile_id"); +CREATE INDEX "post_fbfc09f1" ON "post" ("user_id"); +CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id"); +CREATE INDEX "comment_63f17a16" ON "comment" ("parent_id"); +`, + + "postgres": ` +DROP TABLE IF EXISTS "user_profile"; +DROP TABLE IF EXISTS "user"; +DROP TABLE IF EXISTS "post"; +DROP TABLE IF EXISTS "tag"; +DROP TABLE IF EXISTS "post_tags"; +DROP TABLE IF EXISTS "comment"; +CREATE TABLE "user_profile" ( + "id" serial NOT NULL PRIMARY KEY, + "age" smallint NOT NULL, + "money" double precision NOT NULL +); +CREATE TABLE "user" ( + "id" serial NOT NULL PRIMARY KEY, + "user_name" varchar(30) NOT NULL UNIQUE, + "email" varchar(100) NOT NULL, + "password" varchar(100) NOT NULL, + "status" smallint NOT NULL, + "is_staff" boolean NOT NULL, + "is_active" boolean NOT NULL, + "created" date NOT NULL, + "updated" timestamp with time zone NOT NULL, + "profile_id" integer +); +CREATE TABLE "post" ( + "id" serial NOT NULL PRIMARY KEY, + "user_id" integer NOT NULL, + "title" varchar(60) NOT NULL, + "content" text NOT NULL, + "created" timestamp with time zone NOT NULL, + "updated" timestamp with time zone NOT NULL +); +CREATE TABLE "tag" ( + "id" serial NOT NULL PRIMARY KEY, + "name" varchar(30) NOT NULL +); +CREATE TABLE "post_tags" ( + "id" serial NOT NULL PRIMARY KEY, + "post_id" integer NOT NULL, + "tag_id" integer NOT NULL, + UNIQUE ("post_id", "tag_id") +); +CREATE TABLE "comment" ( + "id" serial NOT NULL PRIMARY KEY, + "post_id" integer NOT NULL, + "content" text NOT NULL, + "parent_id" integer, + "created" timestamp with time zone NOT NULL +); +CREATE INDEX "user_profile_id" ON "user" ("profile_id"); +CREATE INDEX "post_user_id" ON "post" ("user_id"); +CREATE INDEX "comment_post_id" ON "comment" ("post_id"); +CREATE INDEX "comment_parent_id" ON "comment" ("parent_id"); +`} + func init() { RegisterModel(new(User)) RegisterModel(new(Profile)) @@ -114,7 +285,7 @@ Default DB Drivers. driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 -postgres: https://github.com/bmizerany/pq +postgres: https://github.com/lib/pq eg: mysql ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm @@ -126,20 +297,16 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a BootStrap() - truncateTables() - dORM = NewOrm() -} -func truncateTables() { - logs := "truncate tables for test\n" - o := NewOrm() - for _, m := range modelCache.allOrdered() { - query := fmt.Sprintf("truncate table `%s`", m.table) - _, err := o.Raw(query).Exec() - logs += query + "\n" + queries := strings.Split(initSQLs[DBARGS.Driver], ";") + + for _, query := range queries { + if strings.TrimSpace(query) == "" { + continue + } + _, err := dORM.Raw(query).Exec() if err != nil { - fmt.Println(logs) fmt.Println(err) os.Exit(2) } diff --git a/orm/orm_log.go b/orm/orm_log.go index 30be1152..20d2ed90 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -135,7 +135,7 @@ func (d *dbQueryLog) Commit() error { func (d *dbQueryLog) Rollback() error { a := time.Now() - err := d.db.(txEnder).Commit() + err := d.db.(txEnder).Rollback() debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) return err } diff --git a/orm/orm_raw.go b/orm/orm_raw.go index af2d62e6..4a7ee998 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -6,39 +6,17 @@ import ( "reflect" ) -func getResult(res sql.Result) (int64, error) { - if num, err := res.LastInsertId(); err != nil { - return 0, err - } else { - if num > 0 { - return num, nil - } - } - if num, err := res.RowsAffected(); err != nil { - return num, err - } else { - if num > 0 { - return num, nil - } - } - return 0, nil -} - type rawPrepare struct { rs *rawSet stmt stmtQuerier closed bool } -func (o *rawPrepare) Exec(args ...interface{}) (int64, error) { +func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { if o.closed { - return 0, ErrStmtClosed + return nil, ErrStmtClosed } - res, err := o.stmt.Exec(args...) - if err != nil { - return 0, err - } - return getResult(res) + return o.stmt.Exec(args...) } func (o *rawPrepare) Close() error { @@ -74,12 +52,8 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { return &o } -func (o *rawSet) Exec() (int64, error) { - res, err := o.orm.db.Exec(o.query, o.args...) - if err != nil { - return 0, err - } - return getResult(res) +func (o *rawSet) Exec() (sql.Result, error) { + return o.orm.db.Exec(o.query, o.args...) } func (o *rawSet) QueryRow(...interface{}) error { diff --git a/orm/orm_test.go b/orm/orm_test.go index 119fe995..a682934d 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io/ioutil" + "os" "path/filepath" "reflect" "runtime" @@ -12,6 +13,8 @@ import ( "time" ) +var _ = os.PathSeparator + type T_Code int const ( @@ -60,9 +63,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e ok = is && ok || !is && !ok if !ok { if is { - err = fmt.Errorf("should: a == b, a = `%v`, b = `%v`", a, b) + err = fmt.Errorf("expected: a == `%v`, get `%v`", b, a) } else { - err = fmt.Errorf("should: a != b, a = `%v`, b = `%v`", a, b) + err = fmt.Errorf("expected: a != `%v`, get `%v`", b, a) } } case T_Less, T_Large: @@ -89,9 +92,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e ok = is && ok || !is && !ok if !ok { if is { - err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[0], f1, f2) + err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[0], f1, f2) } else { - err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[1], f1, f2) + err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[1], f1, f2) } } } @@ -122,32 +125,51 @@ func getCaller(skip int) string { fun := runtime.FuncForPC(pc) _, fn := filepath.Split(file) data, err := ioutil.ReadFile(file) - code := "" + var codes []string if err == nil { lines := bytes.Split(data, []byte{'\n'}) - code = strings.TrimSpace(string(lines[line-1])) + n := 10 + for i := 0; i < n; i++ { + o := line - n + if o < 0 { + continue + } + cur := o + i + 1 + flag := " " + if cur == line { + flag = ">>" + } + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.TrimSpace(string(lines[o+i]))) + if code != "" { + codes = append(codes, code) + } + } } funName := fun.Name() if i := strings.LastIndex(funName, "."); i > -1 { funName = funName[i+1:] } - return fmt.Sprintf("%s:%d: %s: %s", fn, line, funName, code) + return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n")) } func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { - params := []interface{}{"\n", getCaller(2), "\n", err, "\n"} - params = append(params, args...) - t.Error(params...) + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + con += fmt.Sprint(args...) + } + t.Error(con) t.Fail() } } func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { - params := []interface{}{"\n", getCaller(2), "\n", err, "\n"} - params = append(params, args...) - t.Error(params...) + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + con += fmt.Sprint(args...) + } + t.Error(con) t.FailNow() } } @@ -165,8 +187,8 @@ func TestCRUD(t *testing.T) { profile.Age = 30 profile.Money = 1234.12 id, err := dORM.Insert(profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 1)) user := NewUser() user.UserName = "slene" @@ -177,51 +199,53 @@ func TestCRUD(t *testing.T) { user.IsActive = true id, err = dORM.Insert(user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 1)) u := &User{Id: user.Id} err = dORM.Read(u) - throwFailNow(t, err) + throwFail(t, err) - throwFailNow(t, AssertIs(u.UserName, T_Equal, "slene")) - throwFailNow(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com")) - throwFailNow(t, AssertIs(u.Password, T_Equal, "pass")) - throwFailNow(t, AssertIs(u.Status, T_Equal, 3)) - throwFailNow(t, AssertIs(u.IsStaff, T_Equal, true)) - throwFailNow(t, AssertIs(u.IsActive, T_Equal, true)) - throwFailNow(t, AssertIs(u.Created, T_Equal, user.Created, format_Date)) - throwFailNow(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime)) + throwFail(t, AssertIs(u.UserName, T_Equal, "slene")) + throwFail(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, T_Equal, "pass")) + throwFail(t, AssertIs(u.Status, T_Equal, 3)) + throwFail(t, AssertIs(u.IsStaff, T_Equal, true)) + throwFail(t, AssertIs(u.IsActive, T_Equal, true)) + throwFail(t, AssertIs(u.Created, T_Equal, user.Created, format_Date)) + throwFail(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime)) user.UserName = "astaxie" user.Profile = profile num, err := dORM.Update(user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 1)) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) u = &User{Id: user.Id} err = dORM.Read(u) - throwFailNow(t, err) + throwFail(t, err) - throwFailNow(t, AssertIs(u.UserName, T_Equal, "astaxie")) - throwFailNow(t, AssertIs(u.Profile.Id, T_Equal, profile.Id)) + if err == nil { + throwFail(t, AssertIs(u.UserName, T_Equal, "astaxie")) + throwFail(t, AssertIs(u.Profile.Id, T_Equal, profile.Id)) + } num, err = dORM.Delete(profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 1)) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) u = &User{Id: user.Id} err = dORM.Read(u) - throwFailNow(t, err) - throwFailNow(t, AssertIs(true, T_Equal, u.Profile == nil)) + throwFail(t, err) + throwFail(t, AssertIs(true, T_Equal, u.Profile == nil)) num, err = dORM.Delete(user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, T_Equal, 1)) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) u = &User{Id: 100} err = dORM.Read(u) - throwFailNow(t, AssertIs(err, T_Equal, ErrNoRows)) + throwFail(t, AssertIs(err, T_Equal, ErrNoRows)) } func TestInsertTestData(t *testing.T) { @@ -232,8 +256,8 @@ func TestInsertTestData(t *testing.T) { profile.Money = 1234.12 id, err := dORM.Insert(profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 2)) user := NewUser() user.UserName = "slene" @@ -247,16 +271,16 @@ func TestInsertTestData(t *testing.T) { users = append(users, user) id, err = dORM.Insert(user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 2)) profile = NewProfile() profile.Age = 30 profile.Money = 4321.09 id, err = dORM.Insert(profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 3)) user = NewUser() user.UserName = "astaxie" @@ -270,8 +294,8 @@ func TestInsertTestData(t *testing.T) { users = append(users, user) id, err = dORM.Insert(user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 3)) user = NewUser() user.UserName = "nobody" @@ -284,8 +308,8 @@ func TestInsertTestData(t *testing.T) { users = append(users, user) id, err = dORM.Insert(user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, 4)) tags := []*Tag{ &Tag{Name: "golang"}, @@ -315,21 +339,21 @@ The program—and web server—godoc processes Go source files to extract docume for _, tag := range tags { id, err := dORM.Insert(tag) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) } for _, post := range posts { id, err := dORM.Insert(post) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) // dORM.M2mAdd(post, "tags", post.Tags) } for _, comment := range comments { id, err := dORM.Insert(comment) - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, T_Large, 0)) + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) } } @@ -359,9 +383,17 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 2)) + var shouldNum int + + if IsSqlite { + shouldNum = 2 + } else { + shouldNum = 0 + } + num, err = qs.Filter("user_name__contains", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 0)) + throwFail(t, AssertIs(num, T_Equal, shouldNum)) num, err = qs.Filter("user_name__icontains", "E").Count() throwFail(t, err) @@ -391,9 +423,15 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 1)) + if IsSqlite { + shouldNum = 1 + } else { + shouldNum = 0 + } + num, err = qs.Filter("user_name__startswith", "S").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 0)) + throwFail(t, AssertIs(num, T_Equal, shouldNum)) num, err = qs.Filter("user_name__istartswith", "S").Count() throwFail(t, err) @@ -403,9 +441,15 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 2)) + if IsSqlite { + shouldNum = 2 + } else { + shouldNum = 0 + } + num, err = qs.Filter("user_name__endswith", "E").Count() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 0)) + throwFail(t, AssertIs(num, T_Equal, shouldNum)) num, err = qs.Filter("user_name__iendswith", "E").Count() throwFail(t, err) @@ -537,7 +581,6 @@ func TestRelatedSel(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 1)) throwFail(t, AssertNot(user.Profile, T_Equal, nil)) - throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28)) if user.Profile != nil { throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28)) } @@ -617,7 +660,7 @@ func TestOrderBy(t *testing.T) { func TestPrepareInsert(t *testing.T) { qs := dORM.QueryTable("user") i, err := qs.PrepareInsert() - throwFail(t, err) + throwFailNow(t, err) var user User user.UserName = "testing1" @@ -641,15 +684,18 @@ func TestPrepareInsert(t *testing.T) { } func TestRaw(t *testing.T) { - switch dORM.Driver().Type() { - case DR_MySQL: - num, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec() - throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + switch { + case IsMysql || IsSqlite: - num, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec() + res, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec() throwFail(t, err) - throwFail(t, AssertIs(num, T_Equal, 1)) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, T_Equal, 1), err) + + res, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, T_Equal, 1), err) var maps []Params num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps) @@ -681,11 +727,18 @@ func TestRaw(t *testing.T) { func TestUpdate(t *testing.T) { qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Update(Params{ + num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ "is_staff": true, }) throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 1)) + + // with join + num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ + "is_staff": false, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) } func TestDelete(t *testing.T) { @@ -701,48 +754,54 @@ func TestDelete(t *testing.T) { } func TestTransaction(t *testing.T) { + // this test worked when database support transaction + o := NewOrm() err := o.Begin() throwFail(t, err) var names = []string{"1", "2", "3"} - var user User - user.UserName = names[0] - id, err := o.Insert(&user) + var tag Tag + tag.Name = names[0] + id, err := o.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id, T_Large, 0)) - num, err := o.QueryTable("user").Filter("user_name", "slene").Update(Params{"user_name": names[1]}) + num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) throwFail(t, err) - throwFail(t, AssertIs(num, T_Large, 0)) + throwFail(t, AssertIs(num, T_Equal, 1)) - switch o.Driver().Type() { - case DR_MySQL: - id, err := o.Raw("INSERT INTO user (user_name) VALUES (?)", names[2]).Exec() + switch { + case IsMysql || IsSqlite: + res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) - throwFail(t, AssertIs(id, T_Large, 0)) + if err == nil { + id, err = res.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, T_Large, 0)) + } } err = o.Rollback() throwFail(t, err) - num, err = o.QueryTable("user").Filter("user_name__in", &user).Count() + num, err = o.QueryTable("tag").Filter("name__in", names).Count() throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 0)) err = o.Begin() throwFail(t, err) - user.UserName = "commit" - id, err = o.Insert(&user) + tag.Name = "commit" + id, err = o.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id, T_Large, 0)) o.Commit() throwFail(t, err) - num, err = o.QueryTable("user").Filter("user_name", "commit").Delete() + num, err = o.QueryTable("tag").Filter("name", "commit").Delete() throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 1)) diff --git a/orm/types.go b/orm/types.go index fc0c6bb3..0f2cd6b6 100644 --- a/orm/types.go +++ b/orm/types.go @@ -60,12 +60,12 @@ type QuerySeter interface { } type RawPreparer interface { - Exec(...interface{}) (int64, error) + Exec(...interface{}) (sql.Result, error) Close() error } type RawSeter interface { - Exec() (int64, error) + Exec() (sql.Result, error) QueryRow(...interface{}) error QueryRows(...interface{}) (int64, error) SetArgs(...interface{}) RawSeter @@ -116,10 +116,15 @@ type dbBaser interface { Update(dbQuerier, *modelInfo, reflect.Value) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error) ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error) + SupportUpdateJoin() bool UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) - GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) + OperatorSql(string) string + GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) + MaxLimit() uint64 + TableQuote() string + ReplaceMarks(*string) }