diff --git a/orm/db.go b/orm/db.go index 69908c57..fbee123d 100644 --- a/orm/db.go +++ b/orm/db.go @@ -49,28 +49,8 @@ type dbBase struct { ins dbBaser } -func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { - - fi := mi.fields.pk - - v := ind.Field(fi.fieldIndex) - if fi.fieldType&IsIntegerField > 0 { - vu := v.Int() - exist = vu > 0 - value = vu - } else { - vu := v.String() - exist = vu != "" - value = vu - } - - column = fi.column - - return -} - func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { - _, pkValue, _ := d.existPk(mi, ind) + _, pkValue, _ := getExistPk(mi, ind) for _, column := range mi.fields.orders { fi := mi.fields.columns[column] if fi.dbcol == false || fi.auto && skipAuto { @@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, if field.IsNil() { value = nil } else { - if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok { + if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { value = vu } else { value = nil @@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, d.ins.ReplaceMarks(&query) + d.ins.HasReturningID(mi, &query) + stmt, err := q.Prepare(query) return stmt, query, err } @@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) return 0, err } - if res, err := stmt.Exec(values...); err == nil { - return res.LastInsertId() + if d.ins.HasReturningID(mi, nil) { + row := stmt.QueryRow(values...) + var id int64 + err := row.Scan(&id) + return id, err } else { - return 0, err + if res, err := stmt.Exec(values...); err == nil { + return res.LastInsertId() + } else { + return 0, err + } } } func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { - pkColumn, pkValue, ok := d.existPk(mi, ind) + pkColumn, pkValue, ok := getExistPk(mi, ind) if ok == false { return ErrMissPK } @@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e d.ins.ReplaceMarks(&query) - if res, err := q.Exec(query, values...); err == nil { - return res.LastInsertId() + if d.ins.HasReturningID(mi, &query) { + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err } else { - return 0, err + if res, err := q.Exec(query, values...); err == nil { + return res.LastInsertId() + } else { + return 0, err + } } } func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { - pkName, pkValue, ok := d.existPk(mi, ind) + pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { return 0, ErrMissPK } @@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e } func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { - pkName, pkValue, ok := d.existPk(mi, ind) + pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { return 0, ErrMissPK } @@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } - sql, args := d.ins.GenerateOperatorSql(mi, "in", args) + sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) d.ins.ReplaceMarks(&query) @@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition return } -func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) { - for _, arg := range args { - val := reflect.ValueOf(arg) - - if arg == nil { - params = append(params, arg) - continue - } - - kind := val.Kind() - - switch kind { - case reflect.Slice, reflect.Array: - var args []interface{} - for i := 0; i < val.Len(); i++ { - v := val.Index(i) - - var vu interface{} - if v.CanInterface() { - vu = v.Interface() - } - - if vu == nil { - continue - } - - args = append(args, vu) - } - - if len(args) > 0 { - p := d.getOperatorParams(operator, args) - params = append(params, p...) - } - - case reflect.Ptr, reflect.Struct: - ind := reflect.Indirect(val) - - if ind.Kind() == reflect.Struct { - typ := ind.Type() - name := getFullName(typ) - var value interface{} - if mmi, ok := modelCache.getByFN(name); ok { - if _, vu, exist := d.existPk(mmi, ind); exist { - value = vu - } - } - arg = value - - if arg == nil { - panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name)) - } - } else { - arg = ind.Interface() - } - - params = append(params, arg) - - default: - params = append(params, arg) - } - - } - - return -} - -func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { +func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) { sql := "" - params := d.getOperatorParams(operator, args) + params := getFlatParams(fi, args) + + if len(params) == 0 { + panic(fmt.Sprintf("operator `%s` need at least one args", operator)) + } + arg := params[0] if operator == "in" { marks := make([]string, len(params)) @@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) } sql = d.ins.OperatorSql(operator) - arg := params[0] switch operator { case "exact": if arg == nil { @@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte return sql, params } +func (d *dbBase) GenerateOperatorLeftCol(string, *string) { + +} + func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() @@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond cols = make([]string, 0, len(exprs)) infos = make([]*fieldInfo, 0, len(exprs)) for _, ex := range exprs { - index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) + index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } - cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q)) + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) infos = append(infos, fi) } } else { @@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string { func (d *dbBase) ReplaceMarks(query *string) { // default use `?` as mark, do nothing } + +func (d *dbBase) HasReturningID(*modelInfo, *string) bool { + return false +} diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 58562036..4bfc5a83 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -1,6 +1,7 @@ package orm import ( + "fmt" "strconv" ) @@ -29,6 +30,23 @@ func (d *dbBasePostgres) OperatorSql(operator string) string { return postgresOperators[operator] } +func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) { + switch operator { + case "contains", "startswith", "endswith": + *leftCol = fmt.Sprintf("%s::text", *leftCol) + case "iexact", "icontains", "istartswith", "iendswith": + *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) + } +} + +func (d *dbBasePostgres) SupportUpdateJoin() bool { + return false +} + +func (d *dbBasePostgres) MaxLimit() uint64 { + return 0 +} + func (d *dbBasePostgres) TableQuote() string { return `"` } @@ -59,7 +77,15 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { *query = string(data) } -// func (d *dbBasePostgres) +func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { + if mi.fields.pk.auto { + if query != nil { + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column) + } + has = true + } + return +} func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) diff --git a/orm/db_tables.go b/orm/db_tables.go index 009ea58f..5cf567cc 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -177,7 +177,7 @@ func (t *dbTables) getJoinSql() (join string) { return } -func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) { +func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { var ( ffi *fieldInfo jtl *dbTable @@ -236,7 +236,6 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam } else { index = jtl.index } - column = fi.column info = fi if jtl != nil { name = jtl.name + ExprSep + fi.name @@ -256,14 +255,14 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam if exist == false { index = "" - column = "" name = "" + info = nil success = false return } } - success = index != "" && column != "" + success = index != "" && info != nil return } @@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ exprs = exprs[:num] } - index, column, _, _, suc := d.parseExprs(mi, exprs) + index, _, fi, suc := d.parseExprs(mi, exprs) if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) } @@ -314,9 +313,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ operator = "exact" } - operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args) + operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args) - where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql) + leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) + d.base.GenerateOperatorLeftCol(operator, &leftCol) + + where += fmt.Sprintf("%s %s ", leftCol, operSql) params = append(params, args...) } @@ -345,12 +347,12 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) { } exprs := strings.Split(order, ExprSep) - index, column, _, _, suc := d.parseExprs(d.mi, exprs) + index, _, fi, suc := d.parseExprs(d.mi, exprs) if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc)) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) } orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) @@ -365,7 +367,11 @@ func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits s // no limit if offset > 0 { maxLimit := d.base.MaxLimit() - limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) + if maxLimit == 0 { + limits = fmt.Sprintf("OFFSET %d", offset) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) + } } } else if offset <= 0 { limits = fmt.Sprintf("LIMIT %d", limit) diff --git a/orm/db_utils.go b/orm/db_utils.go new file mode 100644 index 00000000..a31c5221 --- /dev/null +++ b/orm/db_utils.go @@ -0,0 +1,98 @@ +package orm + +import ( + "fmt" + "reflect" + "time" +) + +func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { + fi := mi.fields.pk + + v := ind.Field(fi.fieldIndex) + if fi.fieldType&IsIntegerField > 0 { + vu := v.Int() + exist = vu > 0 + value = vu + } else { + vu := v.String() + exist = vu != "" + value = vu + } + + column = fi.column + return +} + +func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) { + +outFor: + for _, arg := range args { + val := reflect.ValueOf(arg) + + if arg == nil { + params = append(params, arg) + continue + } + + switch v := arg.(type) { + case []byte: + case time.Time: + if fi != nil && fi.fieldType == TypeDateField { + arg = v.Format(format_Date) + } else { + arg = v.Format(format_DateTime) + } + default: + kind := val.Kind() + switch kind { + case reflect.Slice, reflect.Array: + + var args []interface{} + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + + var vu interface{} + if v.CanInterface() { + vu = v.Interface() + } + + if vu == nil { + continue + } + + args = append(args, vu) + } + + if len(args) > 0 { + p := getFlatParams(fi, args) + params = append(params, p...) + } + continue outFor + + case reflect.Ptr, reflect.Struct: + ind := reflect.Indirect(val) + + if ind.Kind() == reflect.Struct { + typ := ind.Type() + name := getFullName(typ) + var value interface{} + if mmi, ok := modelCache.getByFN(name); ok { + if _, vu, exist := getExistPk(mmi, ind); exist { + value = vu + } + } + arg = value + + if arg == nil { + panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) + } + } else { + arg = ind.Interface() + } + } + } + params = append(params, arg) + } + return +} diff --git a/orm/models_test.go b/orm/models_test.go index fb46fb01..eca35553 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -302,7 +302,8 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a queries := strings.Split(initSQLs[DBARGS.Driver], ";") for _, query := range queries { - if strings.TrimSpace(query) == "" { + query = strings.TrimSpace(query) + if len(query) == 0 { continue } _, err := dORM.Raw(query).Exec() diff --git a/orm/orm_log.go b/orm/orm_log.go index 20d2ed90..0bb5d6f9 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -22,7 +22,11 @@ func NewLog(out io.Writer) *Log { func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { sub := time.Now().Sub(t) / 1e5 elsp := float64(int(sub)) / 10.0 - con := fmt.Sprintf(" - %s - [Queries/%s] - [%11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, operaton, elsp, query) + flag := " OK" + if err != nil { + flag = "FAIL" + } + con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query) cons := make([]string, 0, len(args)) for _, arg := range args { cons = append(cons, fmt.Sprintf("%v", arg)) diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 4a7ee998..669a354f 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -27,12 +27,16 @@ func (o *rawPrepare) Close() error { func newRawPreparer(rs *rawSet) (RawPreparer, error) { o := new(rawPrepare) o.rs = rs - st, err := rs.orm.db.Prepare(rs.query) + + query := rs.query + rs.orm.alias.DbBaser.ReplaceMarks(&query) + + st, err := rs.orm.db.Prepare(query) if err != nil { return nil, err } if Debug { - o.stmt = newStmtQueryLog(rs.orm.alias, st, rs.query) + o.stmt = newStmtQueryLog(rs.orm.alias, st, query) } else { o.stmt = st } @@ -53,7 +57,11 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { } func (o *rawSet) Exec() (sql.Result, error) { - return o.orm.db.Exec(o.query, o.args...) + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args) + return o.orm.db.Exec(query, args...) } func (o *rawSet) QueryRow(...interface{}) error { @@ -85,8 +93,13 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { panic(fmt.Sprintf("unsupport read values type `%T`", container)) } + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args) + var rs *sql.Rows - if r, err := o.orm.db.Query(o.query, o.args...); err != nil { + if r, err := o.orm.db.Query(query, args...); err != nil { return 0, err } else { rs = r diff --git a/orm/orm_test.go b/orm/orm_test.go index a682934d..ad804114 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -51,7 +51,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e if v2, vo := b.(time.Time); vo { if arg.Get(1) != nil { format := ToStr(arg.Get(1)) - ok = v.Format(format) == v2.Format(format) + a = v.Format(format) + b = v2.Format(format) + ok = a == b } else { err = fmt.Errorf("compare datetime miss format") goto wrongArg @@ -363,6 +365,10 @@ func TestExpr(t *testing.T) { num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() throwFail(t, err) throwFail(t, AssertIs(num, T_Equal, 1)) + + num, err = qs.Filter("created", time.Now()).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 3)) } func TestOperators(t *testing.T) { @@ -722,6 +728,102 @@ func TestRaw(t *testing.T) { throwFail(t, AssertIs(list[1], T_Equal, "3")) throwFail(t, AssertIs(list[2], T_Equal, "")) } + + pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + throwFail(t, err) + if pre != nil { + r, err := pre.Exec("name1") + throwFail(t, err) + + tid, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(tid, T_Large, 0)) + + r, err = pre.Exec("name2") + throwFail(t, err) + + id, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, tid+1)) + + r, err = pre.Exec("name3") + throwFail(t, err) + + id, err = r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, T_Equal, tid+2)) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 3)) + } + + case IsPostgres: + + res, err := dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "testing", "slene").Exec() + throwFail(t, err) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, T_Equal, 1), err) + + res, err = dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, T_Equal, 1), err) + + var maps []Params + num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) + if num == 1 { + throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene")) + } + + var lists []ParamsList + num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).ValuesList(&lists) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 1)) + if num == 1 { + throwFail(t, AssertIs(lists[0][0], T_Equal, "slene")) + } + + var list ParamsList + num, err = dORM.Raw(`SELECT "profile_id" FROM "user" ORDER BY id ASC`).ValuesFlat(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], T_Equal, "2")) + throwFail(t, AssertIs(list[1], T_Equal, "3")) + throwFail(t, AssertIs(list[2], T_Equal, "")) + } + + pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + throwFail(t, err) + if pre != nil { + _, err := pre.Exec("name1") + throwFail(t, err) + + _, err = pre.Exec("name2") + throwFail(t, err) + + _, err = pre.Exec("name3") + throwFail(t, err) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, T_Equal, 3)) + } } } diff --git a/orm/types.go b/orm/types.go index 0f2cd6b6..5c6538f5 100644 --- a/orm/types.go +++ b/orm/types.go @@ -121,10 +121,12 @@ type dbBaser interface { DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) OperatorSql(string) string - GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) + GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{}) + GenerateOperatorLeftCol(string, *string) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) + HasReturningID(*modelInfo, *string) bool }