From e56d1b718fa74099e3d7a52632ba5e5aad65febf Mon Sep 17 00:00:00 2001 From: nlimpid Date: Sun, 18 Nov 2018 21:54:25 +0800 Subject: [PATCH] add context for db operation --- orm/db.go | 42 ++++++++++++++++++++++++++++++++---------- orm/orm_log.go | 28 ++++++++++++++++++++++++++++ orm/orm_queryset.go | 32 +++++++++++++++++++++----------- orm/types.go | 7 +++++++ 4 files changed, 88 insertions(+), 21 deletions(-) diff --git a/orm/db.go b/orm/db.go index cb807b5a..37422001 100644 --- a/orm/db.go +++ b/orm/db.go @@ -762,7 +762,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, values...) + var err error + var res sql.Result + if qs.forContext { + res, err = q.ExecContext(qs.ctx, query, values...) + } else { + res, err = q.Exec(query, values...) + } if err == nil { return res.RowsAffected() } @@ -851,11 +857,16 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con for i := range marks { marks[i] = "?" } - sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) + sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + var res sql.Result + if qs.forContext { + res, err = q.ExecContext(qs.ctx, query, args...) + } else { + res, err = q.Exec(query, args...) + } if err == nil { num, err := res.RowsAffected() if err != nil { @@ -978,11 +989,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) - if err != nil { - return 0, err + var err error + if qs.forContext { + rs, err = q.QueryContext(qs.ctx, query, args...) + if err != nil { + return 0, err + } + } else { + rs, err = q.Query(query, args...) + if err != nil { + return 0, err + } } - rs = r refs := make([]interface{}, colsNum) for i := range refs { @@ -1111,8 +1129,12 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) - + var row *sql.Row + if qs.forContext { + row = q.QueryRowContext(qs.ctx, query, args...) + } else { + row = q.QueryRow(query, args...) + } err = row.Scan(&cnt) return } diff --git a/orm/orm_log.go b/orm/orm_log.go index 979dbbc6..2a879c13 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -123,6 +123,13 @@ func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { return stmt, err } +func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.PrepareContext(ctx, query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { a := time.Now() res, err := d.db.Exec(query, args...) @@ -130,6 +137,13 @@ func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) return res, err } +func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.ExecContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { a := time.Now() res, err := d.db.Query(query, args...) @@ -137,6 +151,13 @@ func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) return res, err } +func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.QueryContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { a := time.Now() res := d.db.QueryRow(query, args...) @@ -144,6 +165,13 @@ func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { return res } +func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRowContext(ctx, query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + func (d *dbQueryLog) Begin() (*sql.Tx, error) { a := time.Now() tx, err := d.db.(txer).Begin() diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 4bab1d98..c1b5e0fb 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -55,17 +56,19 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forupdate bool - orm *orm + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []string + distinct bool + forupdate bool + orm *orm + ctx context.Context + forContext bool } var _ QuerySeter = new(querySet) @@ -266,6 +269,13 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) panic(ErrNotImplement) } +// set context to QuerySeter. +func (o querySet) WithContext(ctx context.Context) QuerySeter { + o.ctx = ctx + o.forContext = true + return &o +} + // create new QuerySeter. func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { o := new(querySet) diff --git a/orm/types.go b/orm/types.go index 2fdc98c7..2623924f 100644 --- a/orm/types.go +++ b/orm/types.go @@ -390,16 +390,23 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) + //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) + //QueryContext(args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row + //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier type dbQuerier interface { Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } // type DB interface {