1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-29 12:01:28 +00:00

add context for db operation

This commit is contained in:
nlimpid 2018-11-18 21:54:25 +08:00
parent f64e6b72e9
commit e56d1b718f
4 changed files with 88 additions and 21 deletions

View File

@ -762,7 +762,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, values...) var err error
var res sql.Result
if qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
if err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} }
@ -851,11 +857,16 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
for i := range marks { for i := range marks {
marks[i] = "?" marks[i] = "?"
} }
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...) var res sql.Result
if qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
if err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
@ -978,11 +989,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
r, err := q.Query(query, args...) var err error
if qs.forContext {
rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
rs = r } else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
}
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
@ -1111,8 +1129,12 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...) var row *sql.Row
if qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
err = row.Scan(&cnt) err = row.Scan(&cnt)
return return
} }

View File

@ -123,6 +123,13 @@ func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
return stmt, err return stmt, err
} }
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.PrepareContext(ctx, query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now() a := time.Now()
res, err := d.db.Exec(query, args...) res, err := d.db.Exec(query, args...)
@ -130,6 +137,13 @@ func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error)
return res, err return res, err
} }
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.ExecContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now() a := time.Now()
res, err := d.db.Query(query, args...) res, err := d.db.Query(query, args...)
@ -137,6 +151,13 @@ func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error)
return res, err return res, err
} }
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.QueryContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now() a := time.Now()
res := d.db.QueryRow(query, args...) res := d.db.QueryRow(query, args...)
@ -144,6 +165,13 @@ func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
return res return res
} }
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRowContext(ctx, query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) Begin() (*sql.Tx, error) { func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now() a := time.Now()
tx, err := d.db.(txer).Begin() tx, err := d.db.(txer).Begin()

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
) )
@ -66,6 +67,8 @@ type querySet struct {
distinct bool distinct bool
forupdate bool forupdate bool
orm *orm orm *orm
ctx context.Context
forContext bool
} }
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
@ -266,6 +269,13 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter. // create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)

View File

@ -390,16 +390,23 @@ type RawSeter interface {
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
//ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error) Query(args ...interface{}) (*sql.Rows, error)
//QueryContext(args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
//QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
} }
// db querier // db querier
type dbQuerier interface { type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
// type DB interface { // type DB interface {