1
0
mirror of https://github.com/astaxie/beego.git synced 2025-07-02 23:40:19 +00:00

orm add postgres support

This commit is contained in:
slene
2013-08-11 22:27:45 +08:00
parent 449fbe82f6
commit 45345fa782
9 changed files with 316 additions and 122 deletions

148
orm/db.go
View File

@ -49,28 +49,8 @@ type dbBase struct {
ins dbBaser
}
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = vu > 0
value = vu
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
return
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
_, pkValue, _ := d.existPk(mi, ind)
_, pkValue, _ := getExistPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
if fi.dbcol == false || fi.auto && skipAuto {
@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
if field.IsNil() {
value = nil
} else {
if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok {
value = vu
} else {
value = nil
@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
d.ins.ReplaceMarks(&query)
d.ins.HasReturningID(mi, &query)
stmt, err := q.Prepare(query)
return stmt, query, err
}
@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value)
return 0, err
}
if res, err := stmt.Exec(values...); err == nil {
return res.LastInsertId()
if d.ins.HasReturningID(mi, nil) {
row := stmt.QueryRow(values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
return 0, err
if res, err := stmt.Exec(values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
}
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
pkColumn, pkValue, ok := d.existPk(mi, ind)
pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return ErrMissPK
}
@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, values...); err == nil {
return res.LastInsertId()
if d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
return 0, err
if res, err := q.Exec(query, values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkName, pkValue, ok := d.existPk(mi, ind)
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
}
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkName, pkValue, ok := d.existPk(mi, ind)
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
sql, args := d.ins.GenerateOperatorSql(mi, "in", args)
sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args)
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return
}
func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) {
for _, arg := range args {
val := reflect.ValueOf(arg)
if arg == nil {
params = append(params, arg)
continue
}
kind := val.Kind()
switch kind {
case reflect.Slice, reflect.Array:
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := d.getOperatorParams(operator, args)
params = append(params, p...)
}
case reflect.Ptr, reflect.Struct:
ind := reflect.Indirect(val)
if ind.Kind() == reflect.Struct {
typ := ind.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := d.existPk(mmi, ind); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name))
}
} else {
arg = ind.Interface()
}
params = append(params, arg)
default:
params = append(params, arg)
}
}
return
}
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) {
sql := ""
params := d.getOperatorParams(operator, args)
params := getFlatParams(fi, args)
if len(params) == 0 {
panic(fmt.Sprintf("operator `%s` need at least one args", operator))
}
arg := params[0]
if operator == "in" {
marks := make([]string, len(params))
@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte
panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
}
sql = d.ins.OperatorSql(operator)
arg := params[0]
switch operator {
case "exact":
if arg == nil {
@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte
return sql, params
}
func (d *dbBase) GenerateOperatorLeftCol(string, *string) {
}
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
cols = make([]string, 0, len(exprs))
infos = make([]*fieldInfo, 0, len(exprs))
for _, ex := range exprs {
index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", ex))
}
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q))
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
infos = append(infos, fi)
}
} else {
@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string {
func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing
}
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false
}