1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 04:21:00 +00:00

Merge branch 'develop-2.0' of https://github.com/astaxie/beego into frt/fix_3830

# Conflicts:
#	pkg/orm/orm_test.go
This commit is contained in:
jianzhiyao 2020-08-11 17:37:24 +08:00
commit 2d1c02e1c1
53 changed files with 782 additions and 394 deletions

View File

@ -15,13 +15,13 @@
package beego package beego
import ( import (
"crypto/tls"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"crypto/tls"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"

View File

@ -27,15 +27,18 @@ type fakeConfigContainer struct {
func (c *fakeConfigContainer) getData(key string) string { func (c *fakeConfigContainer) getData(key string) string {
return c.data[strings.ToLower(key)] return c.data[strings.ToLower(key)]
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) Set(key, val string) error { func (c *fakeConfigContainer) Set(key, val string) error {
c.data[strings.ToLower(key)] = val c.data[strings.ToLower(key)] = val
return nil return nil
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) String(key string) string { func (c *fakeConfigContainer) String(key string) string {
return c.getData(key) return c.getData(key)
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key) v := c.String(key)
@ -44,6 +47,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin
} }
return v return v
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) Strings(key string) []string { func (c *fakeConfigContainer) Strings(key string) []string {
v := c.String(key) v := c.String(key)
@ -52,6 +56,7 @@ func (c *fakeConfigContainer) Strings(key string) []string {
} }
return strings.Split(v, ";") return strings.Split(v, ";")
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key) v := c.Strings(key)
@ -60,10 +65,12 @@ func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []
} }
return v return v
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) Int(key string) (int, error) { func (c *fakeConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getData(key)) return strconv.Atoi(c.getData(key))
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key) v, err := c.Int(key)
@ -72,10 +79,12 @@ func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int {
} }
return v return v
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) Int64(key string) (int64, error) { func (c *fakeConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64) return strconv.ParseInt(c.getData(key), 10, 64)
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key) v, err := c.Int64(key)
@ -84,10 +93,12 @@ func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
} }
return v return v
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) Bool(key string) (bool, error) { func (c *fakeConfigContainer) Bool(key string) (bool, error) {
return ParseBool(c.getData(key)) return ParseBool(c.getData(key))
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool {
v, err := c.Bool(key) v, err := c.Bool(key)
@ -96,10 +107,12 @@ func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool {
} }
return v return v
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) Float(key string) (float64, error) { func (c *fakeConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64) return strconv.ParseFloat(c.getData(key), 64)
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key) v, err := c.Float(key)
@ -108,6 +121,7 @@ func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float
} }
return v return v
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
if v, ok := c.data[strings.ToLower(key)]; ok { if v, ok := c.data[strings.ToLower(key)]; ok {
@ -115,10 +129,12 @@ func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
} }
return nil, errors.New("key not find") return nil, errors.New("key not find")
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) {
return nil, errors.New("not implement in the fakeConfigContainer") return nil, errors.New("not implement in the fakeConfigContainer")
} }
// Deprecated: using pkg/config, we will delete this in v2.1.0 // Deprecated: using pkg/config, we will delete this in v2.1.0
func (c *fakeConfigContainer) SaveConfigFile(filename string) error { func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
return errors.New("not implement in the fakeConfigContainer") return errors.New("not implement in the fakeConfigContainer")

View File

@ -58,7 +58,6 @@ func (output *BeegoOutput) Clear() {
output.Status = 0 output.Status = 0
} }
// Header sets response header item string via given key. // Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) { func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val) output.Context.ResponseWriter.Header().Set(key, val)

4
go.sum
View File

@ -185,6 +185,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -219,6 +220,9 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c=
golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=

View File

@ -6,10 +6,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/toolbox" "github.com/astaxie/beego/pkg/toolbox"
) )
@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) {
t.Errorf("invalid response map length: got %d want %d", t.Errorf("invalid response map length: got %d want %d",
len(decodedResponseBody), len(expectedResponseBody)) len(decodedResponseBody), len(expectedResponseBody))
} }
assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody))
assert.Equal(t, 2, len(decodedResponseBody))
if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { var database, cache map[string]interface{}
t.Errorf("handler returned unexpected body: got %v want %v", if decodedResponseBody[0]["message"] == "database" {
decodedResponseBody, expectedResponseBody) database = decodedResponseBody[0]
cache = decodedResponseBody[1]
} else {
database = decodedResponseBody[1]
cache = decodedResponseBody[0]
} }
assert.Equal(t, expectedResponseBody[0], database)
assert.Equal(t, expectedResponseBody[1], cache)
} }

View File

@ -498,6 +498,7 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A
// InsertFilterChain adds a FilterFunc built by filterChain. // InsertFilterChain adds a FilterFunc built by filterChain.
// This filter will be executed before all filters. // This filter will be executed before all filters.
// the filter's behavior is like stack
func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App {
BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...)
return BeeApp return BeeApp

View File

@ -36,14 +36,23 @@ func (s *SimpleKV) GetValue() interface{} {
return s.Value return s.Value
} }
// KVs will store SimpleKV collection as map // KVs interface
type KVs struct { type KVs interface {
GetValueOr(key interface{}, defValue interface{}) interface{}
Contains(key interface{}) bool
IfContains(key interface{}, action func(value interface{})) KVs
}
// SimpleKVs will store SimpleKV collection as map
type SimpleKVs struct {
kvs map[interface{}]interface{} kvs map[interface{}]interface{}
} }
var _ KVs = new(SimpleKVs)
// GetValueOr returns the value for a given key, if non-existant // GetValueOr returns the value for a given key, if non-existant
// it returns defValue // it returns defValue
func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} { func (kvs *SimpleKVs) GetValueOr(key interface{}, defValue interface{}) interface{} {
v, ok := kvs.kvs[key] v, ok := kvs.kvs[key]
if ok { if ok {
return v return v
@ -52,13 +61,13 @@ func (kvs *KVs) GetValueOr(key interface{}, defValue interface{}) interface{} {
} }
// Contains checks if a key exists // Contains checks if a key exists
func (kvs *KVs) Contains(key interface{}) bool { func (kvs *SimpleKVs) Contains(key interface{}) bool {
_, ok := kvs.kvs[key] _, ok := kvs.kvs[key]
return ok return ok
} }
// IfContains invokes the action on a key if it exists // IfContains invokes the action on a key if it exists
func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs { func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{})) KVs {
v, ok := kvs.kvs[key] v, ok := kvs.kvs[key]
if ok { if ok {
action(v) action(v)
@ -66,15 +75,9 @@ func (kvs *KVs) IfContains(key interface{}, action func(value interface{})) *KVs
return kvs return kvs
} }
// Put stores the value
func (kvs *KVs) Put(key interface{}, value interface{}) *KVs {
kvs.kvs[key] = value
return kvs
}
// NewKVs creates the *KVs instance // NewKVs creates the *KVs instance
func NewKVs(kvs ...KV) *KVs { func NewKVs(kvs ...KV) KVs {
res := &KVs{ res := &SimpleKVs{
kvs: make(map[interface{}]interface{}, len(kvs)), kvs: make(map[interface{}]interface{}, len(kvs)),
} }
for _, kv := range kvs { for _, kv := range kvs {

View File

@ -29,12 +29,10 @@ func TestKVs(t *testing.T) {
assert.True(t, kvs.Contains(key)) assert.True(t, kvs.Contains(key))
kvs.IfContains(key, func(value interface{}) {
kvs.Put("my-key1", "")
})
assert.True(t, kvs.Contains("my-key1"))
v := kvs.GetValueOr(key, 13) v := kvs.GetValueOr(key, 13)
assert.Equal(t, 12, v) assert.Equal(t, 12, v)
v = kvs.GetValueOr(`key-not-exists`, 8546)
assert.Equal(t, 8546, v)
} }

View File

@ -33,6 +33,7 @@ type FilterFunc func(ctx *context.Context)
// when a request with a matching URL arrives. // when a request with a matching URL arrives.
type FilterRouter struct { type FilterRouter struct {
filterFunc FilterFunc filterFunc FilterFunc
next *FilterRouter
tree *Tree tree *Tree
pattern string pattern string
returnOnOutput bool returnOnOutput bool
@ -81,6 +82,8 @@ func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterPar
ctx.Input.SetParam(k, v) ctx.Input.SetParam(k, v)
} }
} }
} else if f.next != nil {
return f.next.filter(ctx, urlPath, preFilterParams)
} }
if f.returnOnOutput && ctx.ResponseWriter.Started { if f.returnOnOutput && ctx.ResponseWriter.Started {
return true, true return true, true

View File

@ -39,7 +39,6 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
ctx.Output.Body([]byte("hello")) ctx.Output.Body([]byte("hello"))
}) })
r, _ := http.NewRequest("GET", "/chain/user", nil) r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@ -17,14 +17,11 @@ package opentracing
import ( import (
"context" "context"
"net/http" "net/http"
"strconv"
"github.com/astaxie/beego/pkg/httplib"
logKit "github.com/go-kit/kit/log" logKit "github.com/go-kit/kit/log"
opentracingKit "github.com/go-kit/kit/tracing/opentracing" opentracingKit "github.com/go-kit/kit/tracing/opentracing"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/log"
"github.com/astaxie/beego/pkg/httplib"
) )
type FilterChainBuilder struct { type FilterChainBuilder struct {
@ -38,14 +35,8 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
method := req.GetRequest().Method method := req.GetRequest().Method
host := req.GetRequest().URL.Host
path := req.GetRequest().URL.Path
proto := req.GetRequest().Proto operationName := method + "#" + req.GetRequest().URL.String()
scheme := req.GetRequest().URL.Scheme
operationName := host + path + "#" + method
span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName)
defer span.Finish() defer span.Finish()
@ -54,21 +45,24 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt
resp, err := next(spanCtx, req) resp, err := next(spanCtx, req)
if resp != nil { if resp != nil {
span.SetTag("status", strconv.Itoa(resp.StatusCode)) span.SetTag("http.status_code", resp.StatusCode)
} }
span.SetTag("http.method", method)
span.SetTag("method", method) span.SetTag("peer.hostname", req.GetRequest().URL.Host)
span.SetTag("host", host) span.SetTag("http.url", req.GetRequest().URL.String())
span.SetTag("path", path) span.SetTag("http.scheme", req.GetRequest().URL.Scheme)
span.SetTag("proto", proto) span.SetTag("span.kind", "client")
span.SetTag("scheme", scheme) span.SetTag("component", "beego")
span.LogFields(log.String("url", req.GetRequest().URL.String()))
if err != nil { if err != nil {
span.LogFields(log.String("error", err.Error())) span.SetTag("error", true)
span.SetTag("message", err.Error())
} else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) {
span.SetTag("error", true)
} }
span.SetTag("peer.address", req.GetRequest().RemoteAddr)
span.SetTag("http.proto", req.GetRequest().Proto)
if builder.CustomSpanFunc != nil { if builder.CustomSpanFunc != nil {
builder.CustomSpanFunc(span, ctx, req, resp, err) builder.CustomSpanFunc(span, ctx, req, resp, err)
} }

View File

@ -63,11 +63,13 @@ func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time
host := req.GetRequest().URL.Host host := req.GetRequest().URL.Host
path := req.GetRequest().URL.Path path := req.GetRequest().URL.Path
status := resp.StatusCode status := -1
if resp != nil {
status = resp.StatusCode
}
dur := int(endTime.Sub(startTime) / time.Millisecond) dur := int(endTime.Sub(startTime) / time.Millisecond)
builder.summaryVec.WithLabelValues(proto, scheme, method, host, path, builder.summaryVec.WithLabelValues(proto, scheme, method, host, path,
strconv.Itoa(status), strconv.Itoa(dur), strconv.FormatBool(err == nil)) strconv.Itoa(status), strconv.Itoa(dur), strconv.FormatBool(err == nil))
} }

View File

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -743,8 +744,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
var specifyIndexes string
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
} }
where, args := tables.getCondSQL(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
@ -795,9 +798,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
sets := strings.Join(cols, ", ") + " " sets := strings.Join(cols, ", ") + " "
if d.ins.SupportUpdateJoin() { if d.ins.SupportUpdateJoin() {
query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where)
} else { } else {
supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s",
Q, mi.fields.pk.column, Q,
Q, mi.table, Q,
specifyIndexes, join, where)
query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery)
} }
@ -848,8 +854,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true tables.skipEnd = true
var specifyIndexes string
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
} }
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
@ -862,7 +870,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
join := tables.getJoinSQL() join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1007,6 +1015,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
orderBy := tables.getOrderSQL(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, offset, rlimit) limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
for _, tbl := range tables.tables { for _, tbl := range tables.tables {
if tbl.sel { if tbl.sel {
@ -1020,9 +1029,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if qs.distinct { if qs.distinct {
sqlSelect += " DISTINCT" sqlSelect += " DISTINCT"
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
if qs.forupdate { if qs.forUpdate {
query += " FOR UPDATE" query += " FOR UPDATE"
} }
@ -1158,10 +1169,13 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
groupBy := tables.getGroupSQL(qs.groups) groupBy := tables.getGroupSQL(qs.groups)
tables.getOrderSQL(qs.orders) tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s",
Q, mi.table, Q,
specifyIndexes, join, where, groupBy)
if groupBy != "" { if groupBy != "" {
query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query)
@ -1685,6 +1699,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
orderBy := tables.getOrderSQL(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, qs.offset, qs.limit) limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes)
sels := strings.Join(cols, ", ") sels := strings.Join(cols, ", ")
@ -1692,7 +1707,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
if qs.distinct { if qs.distinct {
sqlSelect += " DISTINCT" sqlSelect += " DISTINCT"
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels,
Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1786,10 +1804,6 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil return cnt, nil
} }
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
return 0, nil
}
// flag of update joined record. // flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool { func (d *dbBase) SupportUpdateJoin() bool {
return true return true
@ -1905,3 +1919,31 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
func (d *dbBase) IndexExists(dbQuerier, string, string) bool { func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
var useWay string
switch useIndex {
case hints.KeyUseIndex:
useWay = `USE`
case hints.KeyForceIndex:
useWay = `FORCE`
case hints.KeyIgnoreIndex:
useWay = `IGNORE`
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`))
}

View File

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"sync" "sync"
"time" "time"
@ -363,7 +364,7 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K
var stmtCache *lru.Cache var stmtCache *lru.Cache
var stmtCacheSize int var stmtCacheSize int
maxStmtCacheSize := kvs.GetValueOr(maxStmtCacheSizeKey, 0).(int) maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int)
if maxStmtCacheSize > 0 { if maxStmtCacheSize > 0 {
_stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize)
if errC != nil { if errC != nil {
@ -398,15 +399,15 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K
detectTZ(al) detectTZ(al)
kvs.IfContains(maxIdleConnectionsKey, func(value interface{}) { kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) {
if m, ok := value.(int); ok { if m, ok := value.(int); ok {
SetMaxIdleConns(al, m) SetMaxIdleConns(al, m)
} }
}).IfContains(maxOpenConnectionsKey, func(value interface{}) { }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) {
if m, ok := value.(int); ok { if m, ok := value.(int); ok {
SetMaxOpenConns(al, m) SetMaxOpenConns(al, m)
} }
}).IfContains(connMaxLifetimeKey, func(value interface{}) { }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) {
if m, ok := value.(time.Duration); ok { if m, ok := value.(time.Duration); ok {
SetConnMaxLifetime(al, m) SetConnMaxLifetime(al, m)
} }
@ -422,21 +423,20 @@ func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV
} }
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, hints ...common.KV) error { func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error {
var ( var (
err error err error
db *sql.DB db *sql.DB
al *alias al *alias
) )
db, err = sql.Open(driverName, dataSource) db, err = sql.Open(driverName, dataSource)
if err != nil { if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end goto end
} }
al, err = addAliasWthDB(aliasName, driverName, db, hints...) al, err = addAliasWthDB(aliasName, driverName, db, params...)
if err != nil { if err != nil {
goto end goto end
} }

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"github.com/astaxie/beego/pkg/orm/hints"
"testing" "testing"
"time" "time"
@ -23,9 +24,9 @@ import (
func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase(t *testing.T) {
err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source,
MaxIdleConnections(20), hints.MaxIdleConnections(20),
MaxOpenConnections(300), hints.MaxOpenConnections(300),
ConnMaxLifetime(time.Minute)) hints.ConnMaxLifetime(time.Minute))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias("test-params") al := getDbAlias("test-params")
@ -37,7 +38,7 @@ func TestRegisterDataBase(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -47,7 +48,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" aliasName := "TestRegisterDataBase_MaxStmtCacheSize0"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -57,7 +58,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" aliasName := "TestRegisterDataBase_MaxStmtCacheSize1"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -67,7 +68,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) {
func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) {
aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" aliasName := "TestRegisterDataBase_MaxStmtCacheSize841"
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841))
assert.Nil(t, err) assert.Nil(t, err)
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
@ -75,7 +76,6 @@ func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) {
assert.Equal(t, al.DB.stmtDecoratorsLimit, 841) assert.Equal(t, al.DB.stmtDecoratorsLimit, 841)
} }
func TestDBCache(t *testing.T) { func TestDBCache(t *testing.T) {
dataBaseCache.add("test1", &alias{}) dataBaseCache.add("test1", &alias{})
dataBaseCache.add("default", &alias{}) dataBaseCache.add("default", &alias{})

View File

@ -1,76 +0,0 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewHint_time(t *testing.T) {
key := "qweqwe"
value := time.Second
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_int(t *testing.T) {
key := "qweqwe"
value := 281230
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_float(t *testing.T) {
key := "qweqwe"
value := 21.2459753
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestMaxOpenConnections(t *testing.T) {
i := 887423
hint := MaxOpenConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), maxOpenConnectionsKey)
}
func TestConnMaxLifetime(t *testing.T) {
i := time.Hour
hint := ConnMaxLifetime(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), connMaxLifetimeKey)
}
func TestMaxIdleConnections(t *testing.T) {
i := 42316
hint := MaxIdleConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), maxIdleConnectionsKey)
}
func TestMaxStmtCacheSize(t *testing.T) {
i := 94157
hint := MaxStmtCacheSize(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), maxStmtCacheSizeKey)
}

View File

@ -16,6 +16,7 @@ package orm
import ( import (
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"strings" "strings"
) )
@ -96,6 +97,29 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0 return cnt > 0
} }
func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
var hint string
switch useIndex {
case hints.KeyUseIndex, hints.KeyForceIndex:
hint = `INDEX`
case hints.KeyIgnoreIndex:
hint = `NO_INDEX`
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`))
}
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
// insert the given values, not the field values in struct. // insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {

View File

@ -92,6 +92,7 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0 return 0
} }
// postgresql quote is ". // postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string { func (d *dbBasePostgres) TableQuote() string {
return `"` return `"`
@ -181,6 +182,12 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0 return cnt > 0
} }
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
return ``
}
// create new postgresql dbBaser. // create new postgresql dbBaser.
func newdbBasePostgres() dbBaser { func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres) b := new(dbBasePostgres)

View File

@ -17,7 +17,9 @@ package orm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"reflect" "reflect"
"strings"
"time" "time"
) )
@ -153,6 +155,25 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false return false
} }
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
switch useIndex {
case hints.KeyUseIndex, hints.KeyForceIndex:
return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`))
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
}
// create new sqlite dbBaser. // create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser { func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite) b := new(dbBaseSqlite)

View File

@ -472,6 +472,15 @@ func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits
return return
} }
// getIndexSql generate index sql.
func (t *dbTables) getIndexSql(tableName string,useIndex int, indexes []string) (clause string) {
if len(indexes) == 0 {
return
}
return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes)
}
// crete new tables collection. // crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables { func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{} tables := &dbTables{}

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/astaxie/beego/pkg/common"
) )
// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation
@ -26,6 +27,7 @@ import (
var _ Ormer = new(DoNothingOrm) var _ Ormer = new(DoNothingOrm)
type DoNothingOrm struct { type DoNothingOrm struct {
} }
func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { func (d *DoNothingOrm) Read(md interface{}, cols ...string) error {
@ -52,11 +54,11 @@ func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{},
return false, 0, nil return false, 0, nil
} }
func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) {
return 0, nil return 0, nil
} }
func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
return 0, nil return 0, nil
} }
@ -148,19 +150,19 @@ func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOpti
return nil, nil return nil, nil
} }
func (d *DoNothingOrm) DoTx(task func(txOrm TxOrmer) error) error { func (d *DoNothingOrm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil return nil
} }
func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil return nil
} }
func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil return nil
} }
func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil return nil
} }

View File

@ -22,7 +22,7 @@ import (
// don't forget to call next(...) inside your Filter // don't forget to call next(...) inside your Filter
type FilterChain func(next Filter) Filter type FilterChain func(next Filter) Filter
// Filter's behavior is a little big strang. // Filter's behavior is a little big strange.
// it's only be called when users call methods of Ormer // it's only be called when users call methods of Ormer
type Filter func(ctx context.Context, inv *Invocation) type Filter func(ctx context.Context, inv *Invocation)
@ -31,6 +31,6 @@ var globalFilterChains = make([]FilterChain, 0, 4)
// AddGlobalFilterChain adds a new FilterChain // AddGlobalFilterChain adds a new FilterChain
// All orm instances built after this invocation will use this filterChain, // All orm instances built after this invocation will use this filterChain,
// but instances built before this invocation will not be affected // but instances built before this invocation will not be affected
func AddGlobalFilterChain(filterChain FilterChain) { func AddGlobalFilterChain(filterChain ...FilterChain) {
globalFilterChains = append(globalFilterChains, filterChain) globalFilterChains = append(globalFilterChains, filterChain...)
} }

View File

@ -16,6 +16,7 @@ package opentracing
import ( import (
"context" "context"
"strings"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -27,6 +28,8 @@ import (
// for example: // for example:
// if we want to trace QuerySetter // if we want to trace QuerySetter
// actually we trace invoking "QueryTable" and "QueryTableWithCtx" // actually we trace invoking "QueryTable" and "QueryTableWithCtx"
// the method Begin*, Commit and Rollback are ignored.
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
type FilterChainBuilder struct { type FilterChainBuilder struct {
// CustomSpanFunc users are able to custom their span // CustomSpanFunc users are able to custom their span
CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation)
@ -35,25 +38,34 @@ type FilterChainBuilder struct {
func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
return func(ctx context.Context, inv *orm.Invocation) { return func(ctx context.Context, inv *orm.Invocation) {
operationName := builder.operationName(ctx, inv) operationName := builder.operationName(ctx, inv)
span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" {
defer span.Finish() next(ctx, inv)
return
next(spanCtx, inv)
span.SetTag("Method", inv.Method)
span.SetTag("Table", inv.GetTableName())
span.SetTag("InsideTx", inv.InsideTx)
span.SetTag("TxName", spanCtx.Value(orm.TxNameKey))
if builder.CustomSpanFunc != nil {
builder.CustomSpanFunc(span, spanCtx, inv)
} }
span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName)
defer span.Finish()
next(spanCtx, inv)
builder.buildSpan(span, spanCtx, inv)
}
}
func (builder *FilterChainBuilder) buildSpan(span opentracing.Span, ctx context.Context, inv *orm.Invocation) {
span.SetTag("orm.method", inv.Method)
span.SetTag("orm.table", inv.GetTableName())
span.SetTag("orm.insideTx", inv.InsideTx)
span.SetTag("orm.txName", ctx.Value(orm.TxNameKey))
span.SetTag("span.kind", "client")
span.SetTag("component", "beego")
if builder.CustomSpanFunc != nil {
builder.CustomSpanFunc(span, ctx, inv)
} }
} }
func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string { func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string {
if n, ok := ctx.Value(orm.TxNameKey).(string); ok { if n, ok := ctx.Value(orm.TxNameKey).(string); ok {
return inv.Method + "#" + n return inv.Method + "#tx(" + n + ")"
} }
return inv.Method + "#" + inv.GetTableName() return inv.Method + "#" + inv.GetTableName()
} }

View File

@ -17,11 +17,17 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/astaxie/beego/pkg/common"
"reflect" "reflect"
"time" "time"
) )
const TxNameKey = "TxName" const (
TxNameKey = "TxName"
)
var _ Ormer = new(filterOrmDecorator)
var _ TxOrmer = new(filterOrmDecorator)
type filterOrmDecorator struct { type filterOrmDecorator struct {
ormer ormer
@ -40,7 +46,7 @@ func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer {
ormer: delegate, ormer: delegate,
TxBeginner: delegate, TxBeginner: delegate,
root: func(ctx context.Context, inv *Invocation) { root: func(ctx context.Context, inv *Invocation) {
inv.execute() inv.execute(ctx)
}, },
} }
@ -76,8 +82,8 @@ func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, co
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
err = f.ormer.ReadWithCtx(ctx, md, cols...) err = f.ormer.ReadWithCtx(c, md, cols...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -98,8 +104,8 @@ func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interf
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
err = f.ormer.ReadForUpdateWithCtx(ctx, md, cols...) err = f.ormer.ReadForUpdateWithCtx(c, md, cols...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -125,19 +131,19 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
ok, res, err = f.ormer.ReadOrCreateWithCtx(ctx, md, col1, cols...) ok, res, err = f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
return ok, res, err return ok, res, err
} }
func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) {
return f.LoadRelatedWithCtx(context.Background(), md, name, args...) return f.LoadRelatedWithCtx(context.Background(), md, name, args...)
} }
func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
var ( var (
res int64 res int64
err error err error
@ -151,8 +157,8 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.ormer.LoadRelatedWithCtx(ctx, md, name, args...) res, err = f.ormer.LoadRelatedWithCtx(c, md, name, args...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -176,8 +182,8 @@ func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res = f.ormer.QueryM2MWithCtx(ctx, md, name) res = f.ormer.QueryM2MWithCtx(c, md, name)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -214,8 +220,8 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
Md: md, Md: md,
mi: mi, mi: mi,
f: func() { f: func(c context.Context) {
res = f.ormer.QueryTableWithCtx(ctx, ptrStructOrTableName) res = f.ormer.QueryTableWithCtx(c, ptrStructOrTableName)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -230,7 +236,7 @@ func (f *filterOrmDecorator) DBStats() *sql.DBStats {
Method: "DBStats", Method: "DBStats",
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res = f.ormer.DBStats() res = f.ormer.DBStats()
}, },
} }
@ -255,8 +261,8 @@ func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{})
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.ormer.InsertWithCtx(ctx, md) res, err = f.ormer.InsertWithCtx(c, md)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -280,8 +286,8 @@ func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md inter
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.ormer.InsertOrUpdateWithCtx(ctx, md, colConflitAndArgs...) res, err = f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -316,8 +322,8 @@ func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, m
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.ormer.InsertMultiWithCtx(ctx, bulk, mds) res, err = f.ormer.InsertMultiWithCtx(c, bulk, mds)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -341,8 +347,8 @@ func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{},
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.ormer.UpdateWithCtx(ctx, md, cols...) res, err = f.ormer.UpdateWithCtx(c, md, cols...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -366,8 +372,8 @@ func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{},
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.ormer.DeleteWithCtx(ctx, md, cols...) res, err = f.ormer.DeleteWithCtx(c, md, cols...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -387,8 +393,8 @@ func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args
Args: []interface{}{query, args}, Args: []interface{}{query, args},
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res = f.ormer.RawWithCtx(ctx, query, args...) res = f.ormer.RawWithCtx(c, query, args...)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -403,7 +409,7 @@ func (f *filterOrmDecorator) Driver() Driver {
Method: "Driver", Method: "Driver",
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res = f.ormer.Driver() res = f.ormer.Driver()
}, },
} }
@ -433,28 +439,28 @@ func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql.
Args: []interface{}{opts}, Args: []interface{}{opts},
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func() { f: func(c context.Context) {
res, err = f.TxBeginner.BeginWithCtxAndOpts(ctx, opts) res, err = f.TxBeginner.BeginWithCtxAndOpts(c, opts)
res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(ctx)) res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(c))
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
return res, err return res, err
} }
func (f *filterOrmDecorator) DoTx(task func(txOrm TxOrmer) error) error { func (f *filterOrmDecorator) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
return f.DoTxWithCtxAndOpts(context.Background(), nil, task) return f.DoTxWithCtxAndOpts(context.Background(), nil, task)
} }
func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
return f.DoTxWithCtxAndOpts(ctx, nil, task) return f.DoTxWithCtxAndOpts(ctx, nil, task)
} }
func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return f.DoTxWithCtxAndOpts(context.Background(), opts, task) return f.DoTxWithCtxAndOpts(context.Background(), opts, task)
} }
func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
var ( var (
err error err error
) )
@ -465,8 +471,8 @@ func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.T
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
TxName: getTxNameFromCtx(ctx), TxName: getTxNameFromCtx(ctx),
f: func() { f: func(c context.Context) {
err = f.TxBeginner.DoTxWithCtxAndOpts(ctx, opts, task) err = doTxTemplate(f, c, opts, task)
}, },
} }
f.root(ctx, inv) f.root(ctx, inv)
@ -483,7 +489,7 @@ func (f *filterOrmDecorator) Commit() error {
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
TxName: f.txName, TxName: f.txName,
f: func() { f: func(c context.Context) {
err = f.TxCommitter.Commit() err = f.TxCommitter.Commit()
}, },
} }
@ -501,7 +507,7 @@ func (f *filterOrmDecorator) Rollback() error {
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
TxName: f.txName, TxName: f.txName,
f: func() { f: func(c context.Context) {
err = f.TxCommitter.Rollback() err = f.TxCommitter.Rollback()
}, },
} }
@ -516,4 +522,3 @@ func getTxNameFromCtx(ctx context.Context) string {
} }
return txName return txName
} }

View File

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/astaxie/beego/pkg/common"
"sync" "sync"
"testing" "testing"
@ -130,49 +131,49 @@ func TestFilterOrmDecorator_DoTx(t *testing.T) {
o := &filterMockOrm{} o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter { od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) { return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) if inv.Method == "DoTxWithCtxAndOpts" {
assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "", inv.GetTableName()) assert.Equal(t, "", inv.GetTableName())
assert.False(t, inv.InsideTx) assert.False(t, inv.InsideTx)
}
next(ctx, inv) next(ctx, inv)
} }
}) })
err := od.DoTx(func(txOrm TxOrmer) error { err := od.DoTx(func(c context.Context, txOrm TxOrmer) error {
return errors.New("tx error") return nil
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "tx error", err.Error())
err = od.DoTxWithCtx(context.Background(), func(txOrm TxOrmer) error { err = od.DoTxWithCtx(context.Background(), func(c context.Context, txOrm TxOrmer) error {
return errors.New("tx ctx error") return nil
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "tx ctx error", err.Error())
err = od.DoTxWithOpts(nil, func(txOrm TxOrmer) error { err = od.DoTxWithOpts(nil, func(c context.Context, txOrm TxOrmer) error {
return errors.New("tx opts error") return nil
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "tx opts error", err.Error())
od = NewFilterOrmDecorator(o, func(next Filter) Filter { od = NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) { return func(ctx context.Context, inv *Invocation) {
assert.Equal(t, "DoTxWithCtxAndOpts", inv.Method) if inv.Method == "DoTxWithCtxAndOpts" {
assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "", inv.GetTableName()) assert.Equal(t, "", inv.GetTableName())
assert.Equal(t, "do tx name", inv.TxName) assert.Equal(t, "do tx name", inv.TxName)
assert.False(t, inv.InsideTx) assert.False(t, inv.InsideTx)
}
next(ctx, inv) next(ctx, inv)
} }
}) })
ctx := context.WithValue(context.Background(), TxNameKey, "do tx name") ctx := context.WithValue(context.Background(), TxNameKey, "do tx name")
err = od.DoTxWithCtxAndOpts(ctx, nil, func(txOrm TxOrmer) error { err = od.DoTxWithCtxAndOpts(ctx, nil, func(c context.Context, txOrm TxOrmer) error {
return errors.New("tx ctx opts error") return nil
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "tx ctx opts error", err.Error())
} }
func TestFilterOrmDecorator_Driver(t *testing.T) { func TestFilterOrmDecorator_Driver(t *testing.T) {
@ -347,6 +348,8 @@ func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) {
assert.Equal(t, int64(13), i) assert.Equal(t, int64(13), i)
} }
var _ Ormer = new(filterMockOrm)
// filterMockOrm is only used in this test file // filterMockOrm is only used in this test file
type filterMockOrm struct { type filterMockOrm struct {
DoNothingOrm DoNothingOrm
@ -360,7 +363,7 @@ func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}
return errors.New("read for update error") return errors.New("read for update error")
} }
func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) { func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
return 99, errors.New("load related error") return 99, errors.New("load related error")
} }
@ -376,8 +379,8 @@ func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int6
return 100, errors.New("insert error") return 100, errors.New("insert error")
} }
func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(c context.Context, txOrm TxOrmer) error) error {
return task(nil) return task(ctx, nil)
} }
func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {

View File

@ -28,4 +28,5 @@ func TestAddGlobalFilterChain(t *testing.T) {
} }
}) })
assert.Equal(t, 1, len(globalFilterChains)) assert.Equal(t, 1, len(globalFilterChains))
globalFilterChains = nil
} }

View File

@ -12,13 +12,31 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package orm package hints
import ( import (
"github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/common"
"time" "time"
) )
const (
//db level
KeyMaxIdleConnections = iota
KeyMaxOpenConnections
KeyConnMaxLifetime
KeyMaxStmtCacheSize
//query level
KeyForceIndex
KeyUseIndex
KeyIgnoreIndex
KeyForUpdate
KeyLimit
KeyOffset
KeyOrderBy
KeyRelDepth
)
type Hint struct { type Hint struct {
key interface{} key interface{}
value interface{} value interface{}
@ -36,33 +54,71 @@ func (s *Hint) GetValue() interface{} {
return s.value return s.value
} }
const (
maxIdleConnectionsKey = "MaxIdleConnections"
maxOpenConnectionsKey = "MaxOpenConnections"
connMaxLifetimeKey = "ConnMaxLifetime"
maxStmtCacheSizeKey = "MaxStmtCacheSize"
)
var _ common.KV = new(Hint) var _ common.KV = new(Hint)
// MaxIdleConnections return a hint about MaxIdleConnections // MaxIdleConnections return a hint about MaxIdleConnections
func MaxIdleConnections(v int) *Hint { func MaxIdleConnections(v int) *Hint {
return NewHint(maxIdleConnectionsKey, v) return NewHint(KeyMaxIdleConnections, v)
} }
// MaxOpenConnections return a hint about MaxOpenConnections // MaxOpenConnections return a hint about MaxOpenConnections
func MaxOpenConnections(v int) *Hint { func MaxOpenConnections(v int) *Hint {
return NewHint(maxOpenConnectionsKey, v) return NewHint(KeyMaxOpenConnections, v)
} }
// ConnMaxLifetime return a hint about ConnMaxLifetime // ConnMaxLifetime return a hint about ConnMaxLifetime
func ConnMaxLifetime(v time.Duration) *Hint { func ConnMaxLifetime(v time.Duration) *Hint {
return NewHint(connMaxLifetimeKey, v) return NewHint(KeyConnMaxLifetime, v)
} }
// MaxStmtCacheSize return a hint about MaxStmtCacheSize // MaxStmtCacheSize return a hint about MaxStmtCacheSize
func MaxStmtCacheSize(v int) *Hint { func MaxStmtCacheSize(v int) *Hint {
return NewHint(maxStmtCacheSizeKey, v) return NewHint(KeyMaxStmtCacheSize, v)
}
// ForceIndex return a hint about ForceIndex
func ForceIndex(indexes ...string) *Hint {
return NewHint(KeyForceIndex, indexes)
}
// UseIndex return a hint about UseIndex
func UseIndex(indexes ...string) *Hint {
return NewHint(KeyUseIndex, indexes)
}
// IgnoreIndex return a hint about IgnoreIndex
func IgnoreIndex(indexes ...string) *Hint {
return NewHint(KeyIgnoreIndex, indexes)
}
// ForUpdate return a hint about ForUpdate
func ForUpdate() *Hint {
return NewHint(KeyForUpdate, true)
}
// DefaultRelDepth return a hint about DefaultRelDepth
func DefaultRelDepth() *Hint {
return NewHint(KeyRelDepth, true)
}
// RelDepth return a hint about RelDepth
func RelDepth(d int) *Hint {
return NewHint(KeyRelDepth, d)
}
// Limit return a hint about Limit
func Limit(d int64) *Hint {
return NewHint(KeyLimit, d)
}
// Offset return a hint about Offset
func Offset(d int64) *Hint {
return NewHint(KeyOffset, d)
}
// OrderBy return a hint about OrderBy
func OrderBy(s string) *Hint {
return NewHint(KeyOrderBy, s)
} }
// NewHint return a hint // NewHint return a hint

View File

@ -0,0 +1,154 @@
// Copyright 2020 beego-dev
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package hints
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewHint_time(t *testing.T) {
key := "qweqwe"
value := time.Second
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_int(t *testing.T) {
key := "qweqwe"
value := 281230
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestNewHint_float(t *testing.T) {
key := "qweqwe"
value := 21.2459753
hint := NewHint(key, value)
assert.Equal(t, hint.GetKey(), key)
assert.Equal(t, hint.GetValue(), value)
}
func TestMaxOpenConnections(t *testing.T) {
i := 887423
hint := MaxOpenConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections)
}
func TestConnMaxLifetime(t *testing.T) {
i := time.Hour
hint := ConnMaxLifetime(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime)
}
func TestMaxIdleConnections(t *testing.T) {
i := 42316
hint := MaxIdleConnections(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections)
}
func TestMaxStmtCacheSize(t *testing.T) {
i := 94157
hint := MaxStmtCacheSize(i)
assert.Equal(t, hint.GetValue(), i)
assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize)
}
func TestForceIndex(t *testing.T) {
s := []string{`f_index1`, `f_index2`, `f_index3`}
hint := ForceIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyForceIndex)
}
func TestForceIndex_0(t *testing.T) {
var s []string
hint := ForceIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyForceIndex)
}
func TestIgnoreIndex(t *testing.T) {
s := []string{`i_index1`, `i_index2`, `i_index3`}
hint := IgnoreIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyIgnoreIndex)
}
func TestIgnoreIndex_0(t *testing.T) {
var s []string
hint := IgnoreIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyIgnoreIndex)
}
func TestUseIndex(t *testing.T) {
s := []string{`u_index1`, `u_index2`, `u_index3`}
hint := UseIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyUseIndex)
}
func TestUseIndex_0(t *testing.T) {
var s []string
hint := UseIndex(s...)
assert.Equal(t, hint.GetValue(), s)
assert.Equal(t, hint.GetKey(), KeyUseIndex)
}
func TestForUpdate(t *testing.T) {
hint := ForUpdate()
assert.Equal(t, hint.GetValue(), true)
assert.Equal(t, hint.GetKey(), KeyForUpdate)
}
func TestDefaultRelDepth(t *testing.T) {
hint := DefaultRelDepth()
assert.Equal(t, hint.GetValue(), true)
assert.Equal(t, hint.GetKey(), KeyRelDepth)
}
func TestRelDepth(t *testing.T) {
hint := RelDepth(157965)
assert.Equal(t, hint.GetValue(), 157965)
assert.Equal(t, hint.GetKey(), KeyRelDepth)
}
func TestLimit(t *testing.T) {
hint := Limit(1579625)
assert.Equal(t, hint.GetValue(), int64(1579625))
assert.Equal(t, hint.GetKey(), KeyLimit)
}
func TestOffset(t *testing.T) {
hint := Offset(int64(1572123965))
assert.Equal(t, hint.GetValue(), int64(1572123965))
assert.Equal(t, hint.GetKey(), KeyOffset)
}
func TestOrderBy(t *testing.T) {
hint := OrderBy(`-ID`)
assert.Equal(t, hint.GetValue(), `-ID`)
assert.Equal(t, hint.GetKey(), KeyOrderBy)
}

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"time" "time"
) )
@ -28,7 +29,7 @@ type Invocation struct {
mi *modelInfo mi *modelInfo
// f is the Orm operation // f is the Orm operation
f func() f func(ctx context.Context)
// insideTx indicates whether this is inside a transaction // insideTx indicates whether this is inside a transaction
InsideTx bool InsideTx bool
@ -43,6 +44,6 @@ func (inv *Invocation) GetTableName() string {
return "" return ""
} }
func (inv *Invocation) execute() { func (inv *Invocation) execute(ctx context.Context) {
inv.f() inv.f(ctx)
} }

View File

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"os" "os"
"strings" "strings"
"time" "time"
@ -27,7 +28,6 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
// As tidb can't use go get, so disable the tidb testing now // As tidb can't use go get, so disable the tidb testing now
// _ "github.com/pingcap/tidb" // _ "github.com/pingcap/tidb"
) )
// A slice string field. // A slice string field.
@ -381,6 +381,15 @@ type InLine struct {
Email string Email string
} }
type Index struct {
// Common Fields
Id int `orm:"column(id)"`
// Other Fields
F1 int `orm:"column(f1);index"`
F2 int `orm:"column(f2);index"`
}
func NewInLine() *InLine { func NewInLine() *InLine {
return new(InLine) return new(InLine)
} }
@ -493,7 +502,7 @@ func init() {
os.Exit(2) os.Exit(2)
} }
err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20))
if err != nil { if err != nil {
panic(fmt.Sprintf("can not register database: %v", err)) panic(fmt.Sprintf("can not register database: %v", err))

View File

@ -59,6 +59,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/common" "github.com/astaxie/beego/pkg/common"
"github.com/astaxie/beego/pkg/orm/hints"
"os" "os"
"reflect" "reflect"
"time" "time"
@ -99,6 +100,7 @@ type ormBase struct {
var _ DQL = new(ormBase) var _ DQL = new(ormBase)
var _ DML = new(ormBase) var _ DML = new(ormBase)
var _ DriverGetter = new(ormBase)
// get model info and model reflect value // get model info and model reflect value
func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
@ -302,11 +304,10 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
// for _,tag := range post.Tags{...} // for _,tag := range post.Tags{...}
// //
// make sure the relation is defined in model struct tags. // make sure the relation is defined in model struct tags.
func (o *ormBase) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (o *ormBase) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) {
return o.LoadRelatedWithCtx(context.Background(), md, name, args...) return o.LoadRelatedWithCtx(context.Background(), md, name, args...)
} }
func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) {
func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name) _, fi, ind, qseter := o.queryRelated(md, name)
qs := qseter.(*querySet) qs := qseter.(*querySet)
@ -314,24 +315,29 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s
var relDepth int var relDepth int
var limit, offset int64 var limit, offset int64
var order string var order string
for i, arg := range args {
switch i { kvs := common.NewKVs(args...)
case 0: kvs.IfContains(hints.KeyRelDepth, func(value interface{}) {
if v, ok := arg.(bool); ok { if v, ok := value.(bool); ok {
if v { if v {
relDepth = DefaultRelsDepth relDepth = DefaultRelsDepth
} }
} else if v, ok := arg.(int); ok { } else if v, ok := value.(int); ok {
relDepth = v relDepth = v
} }
case 1: }).IfContains(hints.KeyLimit, func(value interface{}) {
limit = ToInt64(arg) if v, ok := value.(int64); ok {
case 2: limit = v
offset = ToInt64(arg)
case 3:
order, _ = arg.(string)
} }
}).IfContains(hints.KeyOffset, func(value interface{}) {
if v, ok := value.(int64); ok {
offset = v
} }
}).IfContains(hints.KeyOrderBy, func(value interface{}) {
if v, ok := value.(string); ok {
order = v
}
})
switch fi.fieldType { switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelReverseOne: case RelOneToOne, RelForeignKey, RelReverseOne:
@ -522,19 +528,24 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO
return taskTxOrm, nil return taskTxOrm, nil
} }
func (o *orm) DoTx(task func(txOrm TxOrmer) error) error { func (o *orm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
return o.DoTxWithCtx(context.Background(), task) return o.DoTxWithCtx(context.Background(), task)
} }
func (o *orm) DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error { func (o *orm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
return o.DoTxWithCtxAndOpts(ctx, nil, task) return o.DoTxWithCtxAndOpts(ctx, nil, task)
} }
func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return o.DoTxWithCtxAndOpts(context.Background(), opts, task) return o.DoTxWithCtxAndOpts(context.Background(), opts, task)
} }
func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error { func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return doTxTemplate(o, ctx, opts, task)
}
func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions,
task func(ctx context.Context, txOrm TxOrmer) error) error {
_txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts)
if err != nil { if err != nil {
return err return err
@ -553,9 +564,8 @@ func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task
} }
} }
}() }()
var taskTxOrm = _txOrm var taskTxOrm = _txOrm
err = task(taskTxOrm) err = task(ctx, taskTxOrm)
panicked = false panicked = false
return err return err
} }
@ -582,18 +592,11 @@ func NewOrm() Ormer {
// NewOrmUsingDB create new orm with the name // NewOrmUsingDB create new orm with the name
func NewOrmUsingDB(aliasName string) Ormer { func NewOrmUsingDB(aliasName string) Ormer {
o := new(orm)
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
o.alias = al return newDBWithAlias(al)
if Debug {
o.db = newDbQueryLog(al, al.DB)
} else {
o.db = al.DB
}
} else { } else {
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName)) panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
} }
return o
} }
// NewOrmWithDB create a new ormer object with specify *sql.DB for query // NewOrmWithDB create a new ormer object with specify *sql.DB for query
@ -603,14 +606,21 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV)
return nil, err return nil, err
} }
return newDBWithAlias(al), nil
}
func newDBWithAlias(al *alias) Ormer {
o := new(orm) o := new(orm)
o.alias = al o.alias = al
if Debug { if Debug {
o.db = newDbQueryLog(o.alias, db) o.db = newDbQueryLog(al, al.DB)
} else { } else {
o.db = db o.db = al.DB
} }
return o, nil if len(globalFilterChains) > 0 {
return NewFilterOrmDecorator(o, globalFilterChains...)
}
return o
} }

View File

@ -127,10 +127,7 @@ var _ txer = new(dbQueryLog)
var _ txEnder = new(dbQueryLog) var _ txEnder = new(dbQueryLog)
func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
a := time.Now() return d.PrepareContext(context.Background(), query)
stmt, err := d.db.Prepare(query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
} }
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
@ -141,10 +138,7 @@ func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stm
} }
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now() return d.ExecContext(context.Background(), query, args...)
res, err := d.db.Exec(query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
} }
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
@ -155,10 +149,7 @@ func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...inte
} }
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now() return d.QueryContext(context.Background(), query, args...)
res, err := d.db.Query(query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
} }
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
@ -169,10 +160,7 @@ func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...int
} }
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now() return d.QueryRowContext(context.Background(), query, args...)
res := d.db.QueryRow(query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
} }
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
@ -183,10 +171,7 @@ func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...
} }
func (d *dbQueryLog) Begin() (*sql.Tx, error) { func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now() return d.BeginTx(context.Background(), nil)
tx, err := d.db.(txer).Begin()
debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err)
return tx, err
} }
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
) )
type colValue struct { type colValue struct {
@ -71,7 +72,9 @@ type querySet struct {
groups []string groups []string
orders []string orders []string
distinct bool distinct bool
forupdate bool forUpdate bool
useIndex int
indexes []string
orm *ormBase orm *ormBase
ctx context.Context ctx context.Context
forContext bool forContext bool
@ -148,7 +151,28 @@ func (o querySet) Distinct() QuerySeter {
// add FOR UPDATE to SELECT // add FOR UPDATE to SELECT
func (o querySet) ForUpdate() QuerySeter { func (o querySet) ForUpdate() QuerySeter {
o.forupdate = true o.forUpdate = true
return &o
}
// ForceIndex force index for query
func (o querySet) ForceIndex(indexes ...string) QuerySeter {
o.useIndex = hints.KeyForceIndex
o.indexes = indexes
return &o
}
// UseIndex use index for query
func (o querySet) UseIndex(indexes ...string) QuerySeter {
o.useIndex = hints.KeyUseIndex
o.indexes = indexes
return &o
}
// IgnoreIndex ignore index for query
func (o querySet) IgnoreIndex(indexes ...string) QuerySeter {
o.useIndex = hints.KeyIgnoreIndex
o.indexes = indexes
return &o return &o
} }

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/astaxie/beego/pkg/orm/hints"
"io/ioutil" "io/ioutil"
"math" "math"
"os" "os"
@ -200,6 +201,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk)) RegisterModel(new(PtrPk))
RegisterModel(new(Index))
RegisterModel(new(StrPk)) RegisterModel(new(StrPk))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
@ -225,6 +227,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk)) RegisterModel(new(PtrPk))
RegisterModel(new(Index))
RegisterModel(new(StrPk)) RegisterModel(new(StrPk))
BootStrap() BootStrap()
@ -795,6 +798,32 @@ func TestExpr(t *testing.T) {
// throwFail(t, AssertIs(num, 3)) // throwFail(t, AssertIs(num, 3))
} }
func TestSpecifyIndex(t *testing.T) {
var index *Index
index = &Index{
F1: 1,
F2: 2,
}
_, _ = dORM.Insert(index)
throwFailNow(t, AssertIs(index.Id, 1))
index = &Index{
F1: 3,
F2: 4,
}
_, _ = dORM.Insert(index)
throwFailNow(t, AssertIs(index.Id, 2))
_ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`index_f1`).One(index)
throwFailNow(t, AssertIs(index.F2, 2))
_ = dORM.QueryTable(&Index{}).Filter(`f2`, `4`).UseIndex(`index_f2`).One(index)
throwFailNow(t, AssertIs(index.F1, 3))
_ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`index_f1`, `index_f2`).One(index)
throwFailNow(t, AssertIs(index.F2, 2))
}
func TestOperators(t *testing.T) { func TestOperators(t *testing.T) {
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")
num, err := qs.Filter("user_name", "slene").Count() num, err := qs.Filter("user_name", "slene").Count()
@ -1281,24 +1310,32 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true) num, err = dORM.LoadRelated(&user, "Posts", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1) num, err = dORM.LoadRelated(&user, "Posts",
hints.DefaultRelDepth(),
hints.Limit(1))
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(len(user.Posts), 1))
num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") num, err = dORM.LoadRelated(&user, "Posts",
hints.DefaultRelDepth(),
hints.OrderBy("-Id"))
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") num, err = dORM.LoadRelated(&user, "Posts",
hints.DefaultRelDepth(),
hints.Limit(1),
hints.Offset(1),
hints.OrderBy("Id"))
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(len(user.Posts), 1))
@ -1320,7 +1357,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User == nil, false))
throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) throwFailNow(t, AssertIs(profile.User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&profile, "User", true) num, err = dORM.LoadRelated(&profile, "User", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User == nil, false))
@ -1337,7 +1374,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile == nil, false))
throwFailNow(t, AssertIs(user.Profile.Age, 30)) throwFailNow(t, AssertIs(user.Profile.Age, 30))
num, err = dORM.LoadRelated(&user, "Profile", true) num, err = dORM.LoadRelated(&user, "Profile", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile == nil, false))
@ -1357,7 +1394,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User == nil, false))
throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) throwFailNow(t, AssertIs(post.User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&post, "User", true) num, err = dORM.LoadRelated(&post, "User", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User == nil, false))
@ -1377,7 +1414,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(len(post.Tags), 2))
throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) throwFailNow(t, AssertIs(post.Tags[0].Name, "golang"))
num, err = dORM.LoadRelated(&post, "Tags", true) num, err = dORM.LoadRelated(&post, "Tags", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(len(post.Tags), 2))
@ -1398,7 +1435,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true) num, err = dORM.LoadRelated(&tag, "Posts", hints.DefaultRelDepth())
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))

View File

@ -17,6 +17,7 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/astaxie/beego/pkg/common"
"reflect" "reflect"
"time" "time"
) )
@ -95,10 +96,10 @@ type TxBeginner interface {
BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error)
//closure control transaction //closure control transaction
DoTx(task func(txOrm TxOrmer) error) error DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithCtx(ctx context.Context, task func(txOrm TxOrmer) error) error DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithOpts(opts *sql.TxOptions, task func(txOrm TxOrmer) error) error DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(txOrm TxOrmer) error) error DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error
} }
type TxCommitter interface { type TxCommitter interface {
@ -175,14 +176,14 @@ type DQL interface {
// example: // example:
// Ormer.LoadRelated(post,"Tags") // Ormer.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...} // for _,tag := range post.Tags{...}
// args[0] bool true useDefaultRelsDepth ; false depth 0 // hints.DefaultRelDepth useDefaultRelsDepth ; or depth 0
// args[0] int loadRelationDepth // hints.RelDepth loadRelationDepth
// args[1] int limit default limit 1000 // hints.Limit limit default limit 1000
// args[2] int offset default offset 0 // hints.Offset int offset default offset 0
// args[3] string order for example : "-Id" // hints.OrderBy string order for example : "-Id"
// make sure the relation is defined in model struct tags. // make sure the relation is defined in model struct tags.
LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error)
LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error)
// create a models to models queryer // create a models to models queryer
// for example: // for example:
@ -282,6 +283,21 @@ type QuerySeter interface {
// for example: // for example:
// qs.OrderBy("-status") // qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter OrderBy(exprs ...string) QuerySeter
// add FORCE INDEX expression.
// for example:
// qs.ForceIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
ForceIndex(indexes ...string) QuerySeter
// add USE INDEX expression.
// for example:
// qs.UseIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
UseIndex(indexes ...string) QuerySeter
// add IGNORE INDEX expression.
// for example:
// qs.IgnoreIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
IgnoreIndex(indexes ...string) QuerySeter
// set relation model to query together. // set relation model to query together.
// it will query relation models and assign to parent model. // it will query relation models and assign to parent model.
// for example: // for example:
@ -527,24 +543,27 @@ type txEnder interface {
// base database struct // base database struct
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool
OperatorSQL(string) string OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)
@ -559,4 +578,6 @@ type dbBaser interface {
IndexExists(dbQuerier, string, string) bool IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error setval(dbQuerier, *modelInfo, []string) error
GenerateSpecifyIndex(tableName string,useIndex int ,indexes []string) string
} }

View File

@ -472,8 +472,9 @@ func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain
root := p.chainRoot root := p.chainRoot
filterFunc := chain(root.filterFunc) filterFunc := chain(root.filterFunc)
p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...)
} p.chainRoot.next = root
}
// add Filter into // add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {

View File

@ -31,7 +31,6 @@ type FilterChainBuilder struct {
CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context) CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context)
} }
func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc {
return func(ctx *beegoCtx.Context) { return func(ctx *beegoCtx.Context) {
var ( var (
@ -55,9 +54,21 @@ func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.Filt
next(ctx) next(ctx)
// if you think we need to do more things, feel free to create an issue to tell us // if you think we need to do more things, feel free to create an issue to tell us
span.SetTag("status", ctx.Output.Status) span.SetTag("http.status_code", ctx.ResponseWriter.Status)
span.SetTag("method", ctx.Input.Method()) span.SetTag("http.method", ctx.Input.Method())
span.SetTag("route", ctx.Input.GetData("RouterPattern")) span.SetTag("peer.hostname", ctx.Request.Host)
span.SetTag("http.url", ctx.Request.URL.String())
span.SetTag("http.scheme", ctx.Request.URL.Scheme)
span.SetTag("span.kind", "server")
span.SetTag("component", "beego")
if ctx.Output.IsServerError() || ctx.Output.IsClientError() {
span.SetTag("error", true)
}
span.SetTag("peer.address", ctx.Request.RemoteAddr)
span.SetTag("http.proto", ctx.Request.Proto)
span.SetTag("beego.route", ctx.Input.GetData("RouterPattern"))
if builder.CustomSpanFunc != nil { if builder.CustomSpanFunc != nil {
builder.CustomSpanFunc(span, ctx) builder.CustomSpanFunc(span, ctx)
} }
@ -70,7 +81,7 @@ func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string {
// TODO, if we support multiple servers, this need to be changed // TODO, if we support multiple servers, this need to be changed
route, found := beego.BeeApp.Handlers.FindRouter(ctx) route, found := beego.BeeApp.Handlers.FindRouter(ctx)
if found { if found {
operationName = route.GetPattern() operationName = ctx.Input.Method() + "#" + route.GetPattern()
} }
return operationName return operationName
} }