1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 20:10:56 +00:00

orm add postgres support

This commit is contained in:
slene 2013-08-11 22:27:45 +08:00
parent 449fbe82f6
commit 45345fa782
9 changed files with 316 additions and 122 deletions

136
orm/db.go
View File

@ -49,28 +49,8 @@ type dbBase struct {
ins dbBaser
}
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = vu > 0
value = vu
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
return
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
_, pkValue, _ := d.existPk(mi, ind)
_, pkValue, _ := getExistPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
if fi.dbcol == false || fi.auto && skipAuto {
@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
if field.IsNil() {
value = nil
} else {
if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok {
value = vu
} else {
value = nil
@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
d.ins.ReplaceMarks(&query)
d.ins.HasReturningID(mi, &query)
stmt, err := q.Prepare(query)
return stmt, query, err
}
@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value)
return 0, err
}
if d.ins.HasReturningID(mi, nil) {
row := stmt.QueryRow(values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := stmt.Exec(values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
pkColumn, pkValue, ok := d.existPk(mi, ind)
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
}
@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
d.ins.ReplaceMarks(&query)
if d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := q.Exec(query, values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkName, pkValue, ok := d.existPk(mi, ind)
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
}
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkName, pkValue, ok := d.existPk(mi, ind)
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
sql, args := d.ins.GenerateOperatorSql(mi, "in", args)
sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args)
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return
}
func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) {
for _, arg := range args {
val := reflect.ValueOf(arg)
if arg == nil {
params = append(params, arg)
continue
}
kind := val.Kind()
switch kind {
case reflect.Slice, reflect.Array:
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := d.getOperatorParams(operator, args)
params = append(params, p...)
}
case reflect.Ptr, reflect.Struct:
ind := reflect.Indirect(val)
if ind.Kind() == reflect.Struct {
typ := ind.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := d.existPk(mmi, ind); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name))
}
} else {
arg = ind.Interface()
}
params = append(params, arg)
default:
params = append(params, arg)
}
}
return
}
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) {
sql := ""
params := d.getOperatorParams(operator, args)
params := getFlatParams(fi, args)
if len(params) == 0 {
panic(fmt.Sprintf("operator `%s` need at least one args", operator))
}
arg := params[0]
if operator == "in" {
marks := make([]string, len(params))
@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte
panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
}
sql = d.ins.OperatorSql(operator)
arg := params[0]
switch operator {
case "exact":
if arg == nil {
@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte
return sql, params
}
func (d *dbBase) GenerateOperatorLeftCol(string, *string) {
}
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
cols = make([]string, 0, len(exprs))
infos = make([]*fieldInfo, 0, len(exprs))
for _, ex := range exprs {
index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", ex))
}
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q))
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
infos = append(infos, fi)
}
} else {
@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string {
func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing
}
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false
}

View File

@ -1,6 +1,7 @@
package orm
import (
"fmt"
"strconv"
)
@ -29,6 +30,23 @@ func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator]
}
func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol)
case "iexact", "icontains", "istartswith", "iendswith":
*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
}
}
func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false
}
func (d *dbBasePostgres) MaxLimit() uint64 {
return 0
}
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
@ -59,7 +77,15 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
*query = string(data)
}
// func (d *dbBasePostgres)
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto {
if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column)
}
has = true
}
return
}
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)

View File

@ -177,7 +177,7 @@ func (t *dbTables) getJoinSql() (join string) {
return
}
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
ffi *fieldInfo
jtl *dbTable
@ -236,7 +236,6 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
} else {
index = jtl.index
}
column = fi.column
info = fi
if jtl != nil {
name = jtl.name + ExprSep + fi.name
@ -256,14 +255,14 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
if exist == false {
index = ""
column = ""
name = ""
info = nil
success = false
return
}
}
success = index != "" && column != ""
success = index != "" && info != nil
return
}
@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
exprs = exprs[:num]
}
index, column, _, _, suc := d.parseExprs(mi, exprs)
index, _, fi, suc := d.parseExprs(mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
@ -314,9 +313,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
operator = "exact"
}
operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args)
operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args)
where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
d.base.GenerateOperatorLeftCol(operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql)
params = append(params, args...)
}
@ -345,12 +347,12 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
}
exprs := strings.Split(order, ExprSep)
index, column, _, _, suc := d.parseExprs(d.mi, exprs)
index, _, fi, suc := d.parseExprs(d.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc))
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
@ -365,8 +367,12 @@ func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits s
// no limit
if offset > 0 {
maxLimit := d.base.MaxLimit()
if maxLimit == 0 {
limits = fmt.Sprintf("OFFSET %d", offset)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
}
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {

98
orm/db_utils.go Normal file
View File

@ -0,0 +1,98 @@
package orm
import (
"fmt"
"reflect"
"time"
)
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = vu > 0
value = vu
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
return
}
func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) {
outFor:
for _, arg := range args {
val := reflect.ValueOf(arg)
if arg == nil {
params = append(params, arg)
continue
}
switch v := arg.(type) {
case []byte:
case time.Time:
if fi != nil && fi.fieldType == TypeDateField {
arg = v.Format(format_Date)
} else {
arg = v.Format(format_DateTime)
}
default:
kind := val.Kind()
switch kind {
case reflect.Slice, reflect.Array:
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args)
params = append(params, p...)
}
continue outFor
case reflect.Ptr, reflect.Struct:
ind := reflect.Indirect(val)
if ind.Kind() == reflect.Struct {
typ := ind.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, ind); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
}
} else {
arg = ind.Interface()
}
}
}
params = append(params, arg)
}
return
}

View File

@ -302,7 +302,8 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a
queries := strings.Split(initSQLs[DBARGS.Driver], ";")
for _, query := range queries {
if strings.TrimSpace(query) == "" {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
_, err := dORM.Raw(query).Exec()

View File

@ -22,7 +22,11 @@ func NewLog(out io.Writer) *Log {
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
sub := time.Now().Sub(t) / 1e5
elsp := float64(int(sub)) / 10.0
con := fmt.Sprintf(" - %s - [Queries/%s] - [%11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, operaton, elsp, query)
flag := " OK"
if err != nil {
flag = "FAIL"
}
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))

View File

@ -27,12 +27,16 @@ func (o *rawPrepare) Close() error {
func newRawPreparer(rs *rawSet) (RawPreparer, error) {
o := new(rawPrepare)
o.rs = rs
st, err := rs.orm.db.Prepare(rs.query)
query := rs.query
rs.orm.alias.DbBaser.ReplaceMarks(&query)
st, err := rs.orm.db.Prepare(query)
if err != nil {
return nil, err
}
if Debug {
o.stmt = newStmtQueryLog(rs.orm.alias, st, rs.query)
o.stmt = newStmtQueryLog(rs.orm.alias, st, query)
} else {
o.stmt = st
}
@ -53,7 +57,11 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter {
}
func (o *rawSet) Exec() (sql.Result, error) {
return o.orm.db.Exec(o.query, o.args...)
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args)
return o.orm.db.Exec(query, args...)
}
func (o *rawSet) QueryRow(...interface{}) error {
@ -85,8 +93,13 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
panic(fmt.Sprintf("unsupport read values type `%T`", container))
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args)
var rs *sql.Rows
if r, err := o.orm.db.Query(o.query, o.args...); err != nil {
if r, err := o.orm.db.Query(query, args...); err != nil {
return 0, err
} else {
rs = r

View File

@ -51,7 +51,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e
if v2, vo := b.(time.Time); vo {
if arg.Get(1) != nil {
format := ToStr(arg.Get(1))
ok = v.Format(format) == v2.Format(format)
a = v.Format(format)
b = v2.Format(format)
ok = a == b
} else {
err = fmt.Errorf("compare datetime miss format")
goto wrongArg
@ -363,6 +365,10 @@ func TestExpr(t *testing.T) {
num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
num, err = qs.Filter("created", time.Now()).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
}
func TestOperators(t *testing.T) {
@ -722,6 +728,102 @@ func TestRaw(t *testing.T) {
throwFail(t, AssertIs(list[1], T_Equal, "3"))
throwFail(t, AssertIs(list[2], T_Equal, ""))
}
pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
throwFail(t, err)
if pre != nil {
r, err := pre.Exec("name1")
throwFail(t, err)
tid, err := r.LastInsertId()
throwFail(t, err)
throwFail(t, AssertIs(tid, T_Large, 0))
r, err = pre.Exec("name2")
throwFail(t, err)
id, err := r.LastInsertId()
throwFail(t, err)
throwFail(t, AssertIs(id, T_Equal, tid+1))
r, err = pre.Exec("name3")
throwFail(t, err)
id, err = r.LastInsertId()
throwFail(t, err)
throwFail(t, AssertIs(id, T_Equal, tid+2))
err = pre.Close()
throwFail(t, err)
res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec()
throwFail(t, err)
num, err := res.RowsAffected()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
}
case IsPostgres:
res, err := dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "testing", "slene").Exec()
throwFail(t, err)
num, err := res.RowsAffected()
throwFail(t, AssertIs(num, T_Equal, 1), err)
res, err = dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "slene", "testing").Exec()
throwFail(t, err)
num, err = res.RowsAffected()
throwFail(t, AssertIs(num, T_Equal, 1), err)
var maps []Params
num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).Values(&maps)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
if num == 1 {
throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene"))
}
var lists []ParamsList
num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).ValuesList(&lists)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 1))
if num == 1 {
throwFail(t, AssertIs(lists[0][0], T_Equal, "slene"))
}
var list ParamsList
num, err = dORM.Raw(`SELECT "profile_id" FROM "user" ORDER BY id ASC`).ValuesFlat(&list)
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
if num == 3 {
throwFail(t, AssertIs(list[0], T_Equal, "2"))
throwFail(t, AssertIs(list[1], T_Equal, "3"))
throwFail(t, AssertIs(list[2], T_Equal, ""))
}
pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare()
throwFail(t, err)
if pre != nil {
_, err := pre.Exec("name1")
throwFail(t, err)
_, err = pre.Exec("name2")
throwFail(t, err)
_, err = pre.Exec("name3")
throwFail(t, err)
err = pre.Close()
throwFail(t, err)
res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec()
throwFail(t, err)
num, err := res.RowsAffected()
throwFail(t, err)
throwFail(t, AssertIs(num, T_Equal, 3))
}
}
}

View File

@ -121,10 +121,12 @@ type dbBaser interface {
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
OperatorSql(string) string
GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{})
GenerateOperatorLeftCol(string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool
}