diff --git a/.gitignore b/.gitignore index e1b65291..43adebd5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ *.swp *.swo beego.iml + +_beeTmp +_beeTmp2 diff --git a/admin_test.go b/admin_test.go index 3f3612e4..205c76c2 100644 --- a/admin_test.go +++ b/admin_test.go @@ -6,10 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/toolbox" ) @@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { t.Errorf("invalid response map length: got %d want %d", len(decodedResponseBody), len(expectedResponseBody)) } + assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) + assert.Equal(t, 2, len(decodedResponseBody)) - if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { - t.Errorf("handler returned unexpected body: got %v want %v", - decodedResponseBody, expectedResponseBody) + var database, cache map[string]interface{} + if decodedResponseBody[0]["message"] == "database" { + database = decodedResponseBody[0] + cache = decodedResponseBody[1] + } else { + database = decodedResponseBody[1] + cache = decodedResponseBody[0] } + assert.Equal(t, expectedResponseBody[0], database) + assert.Equal(t, expectedResponseBody[1], cache) + } diff --git a/pkg/app.go b/pkg/app.go index eb672b1f..d94d56b5 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -495,3 +495,10 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) return BeeApp } + +// InsertFilterChain adds a FilterFunc built by filterChain. +// This filter will be executed before all filters. +func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) + return BeeApp +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go index f51cc109..e30f7211 100644 --- a/pkg/controller_test.go +++ b/pkg/controller_test.go @@ -19,6 +19,8 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/context" "os" "path/filepath" @@ -125,8 +127,10 @@ func TestGetUint64(t *testing.T) { } func TestAdditionalViewPaths(t *testing.T) { - dir1 := "_beeTmp" - dir2 := "_beeTmp2" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths") + dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths") defer os.RemoveAll(dir1) defer os.RemoveAll(dir2) diff --git a/pkg/filter.go b/pkg/filter.go index 4e212e06..543d7901 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -14,10 +14,19 @@ package beego -import "github.com/astaxie/beego/pkg/context" +import ( + "strings" + + "github.com/astaxie/beego/pkg/context" +) + +// FilterChain is different from pure FilterFunc +// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned +// And all those FilterChain will be invoked before other FilterFunc +type FilterChain func(next FilterFunc) FilterFunc // FilterFunc defines a filter function which is invoked before the controller handler is executed. -type FilterFunc func(*context.Context) +type FilterFunc func(ctx *context.Context) // FilterRouter defines a filter operation which is invoked before the controller handler is executed. // It can match the URL against a pattern, and execute a filter function @@ -30,6 +39,55 @@ type FilterRouter struct { resetParams bool } +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + returnOnOutput: true, + } + if !routerCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + paramsLen := len(params) + if paramsLen > 0 { + mr.returnOnOutput = params[0] + } + if paramsLen > 1 { + mr.resetParams = params[1] + } + mr.tree.AddRouter(pattern, true) + return mr +} + +// filter will check whether we need to execute the filter logic +// return (started, done) +func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) { + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + if f.resetParams { + preFilterParams = ctx.Input.Params() + } + if ok := f.ValidRouter(urlPath, ctx); ok { + f.filterFunc(ctx) + if f.resetParams { + ctx.Input.ResetParams() + for k, v := range preFilterParams { + ctx.Input.SetParam(k, v) + } + } + } + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + return false, false +} + // ValidRouter checks if the current request is matched by this filter. // If the request is matched, the values of the URL parameters defined // by the filter pattern are also returned. diff --git a/pkg/filter_chain_test.go b/pkg/filter_chain_test.go new file mode 100644 index 00000000..42397a60 --- /dev/null +++ b/pkg/filter_chain_test.go @@ -0,0 +1,49 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/context" +) + +func TestControllerRegister_InsertFilterChain(t *testing.T) { + + InsertFilterChain("/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", "filter-chain") + next(ctx) + } + }) + + ns := NewNamespace("/chain") + + ns.Get("/*", func(ctx *context.Context) { + ctx.Output.Body([]byte("hello")) + }) + + + r, _ := http.NewRequest("GET", "/chain/user", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.ServeHTTP(w, r) + + assert.Equal(t, "filter-chain", w.Header().Get("filter")) +} diff --git a/pkg/router.go b/pkg/router.go index 995fb767..b0c23003 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -134,11 +134,14 @@ type ControllerRegister struct { enableFilter bool filters [FinishRouter + 1][]*FilterRouter pool sync.Pool + + // the filter created by FilterChain + chainRoot *FilterRouter } // NewControllerRegister returns a new ControllerRegister. func NewControllerRegister() *ControllerRegister { - return &ControllerRegister{ + res := &ControllerRegister{ routers: make(map[string]*Tree), policies: make(map[string]*Tree), pool: sync.Pool{ @@ -147,6 +150,8 @@ func NewControllerRegister() *ControllerRegister { }, }, } + res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + return res } // Add controller handler and pattern rules to ControllerRegister. @@ -489,27 +494,28 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, - } - if !BConfig.RouterCaseSensitive { - mr.pattern = strings.ToLower(pattern) - } - - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } - mr.tree.AddRouter(pattern, true) + mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) return p.insertFilterRouter(pos, mr) } +// InsertFilterChain is similar to InsertFilter, +// but it will using chainRoot.filterFunc as input to build a new filterFunc +// for example, assume that chainRoot is funcA +// and we add new FilterChain +// fc := func(next) { +// return func(ctx) { +// // do something +// next(ctx) +// // do something +// } +// } +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) { + root := p.chainRoot + filterFunc := chain(root.filterFunc) + p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) +} + + // add Filter into func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { if pos < BeforeStatic || pos > FinishRouter { @@ -668,23 +674,9 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { var preFilterParams map[string]string for _, filterR := range p.filters[pos] { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if filterR.resetParams { - preFilterParams = context.Input.Params() - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - if filterR.resetParams { - context.Input.ResetParams() - for k, v := range preFilterParams { - context.Input.SetParam(k, v) - } - } - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true + b, done := filterR.filter(context, urlPath, preFilterParams) + if done { + return b } } return false @@ -692,7 +684,20 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str // Implement http.Handler interface. func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + + ctx := p.GetContext() + + ctx.Reset(rw, r) + defer p.GiveBackContext(ctx) + + var preFilterParams map[string]string + p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams) +} + +func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { startTime := time.Now() + r := ctx.Request + rw := ctx.ResponseWriter.ResponseWriter var ( runRouter reflect.Type findRouter bool @@ -701,108 +706,100 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) routerInfo *ControllerInfo isRunnable bool ) - context := p.GetContext() - context.Reset(rw, r) - - defer p.GiveBackContext(context) if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(context) + defer BConfig.RecoverFunc(ctx) } - context.Output.EnableGzip = BConfig.EnableGzip + ctx.Output.EnableGzip = BConfig.EnableGzip if BConfig.RunMode == DEV { - context.Output.Header("Server", BConfig.ServerName) + ctx.Output.Header("Server", BConfig.ServerName) } - var urlPath = r.URL.Path - - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } + urlPath := p.getUrlPath(ctx) // filter wrong http method if !HTTPMETHOD[r.Method] { - exception("405", context) + exception("405", ctx) goto Admin } // filter for static file - if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) { goto Admin } - serverStaticRouter(context) + serverStaticRouter(ctx) - if context.ResponseWriter.Started { + if ctx.ResponseWriter.Started { findRouter = true goto Admin } if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !context.Input.IsUpload() { + if BConfig.CopyRequestBody && !ctx.Input.IsUpload() { // connection will close if the incoming data are larger (RFC 7231, 6.5.11) if r.ContentLength > BConfig.MaxMemory { logs.Error(errors.New("payload too large")) - exception("413", context) + exception("413", ctx) goto Admin } - context.Input.CopyBody(BConfig.MaxMemory) + ctx.Input.CopyBody(BConfig.MaxMemory) } - context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + ctx.Input.ParseFormOrMulitForm(BConfig.MaxMemory) } // session init if BConfig.WebConfig.Session.SessionOn { var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) - exception("503", context) + exception("503", ctx) goto Admin } defer func() { - if context.Input.CruSession != nil { - context.Input.CruSession.SessionRelease(rw) + if ctx.Input.CruSession != nil { + ctx.Input.CruSession.SessionRelease(rw) } }() } - if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) { goto Admin } // User can define RunController and RunMethod in filter - if context.Input.RunController != nil && context.Input.RunMethod != "" { + if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" { findRouter = true - runMethod = context.Input.RunMethod - runRouter = context.Input.RunController + runMethod = ctx.Input.RunMethod + runRouter = ctx.Input.RunController } else { - routerInfo, findRouter = p.FindRouter(context) + routerInfo, findRouter = p.FindRouter(ctx) } // if no matches to url, throw a not found exception if !findRouter { - exception("404", context) + exception("404", ctx) goto Admin } - if splat := context.Input.Param(":splat"); splat != "" { + if splat := ctx.Input.Param(":splat"); splat != "" { for k, v := range strings.Split(splat, "/") { - context.Input.SetParam(strconv.Itoa(k), v) + ctx.Input.SetParam(strconv.Itoa(k), v) } } if routerInfo != nil { // store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) + ctx.Input.SetData("RouterPattern", routerInfo.pattern) } // execute middleware filters - if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { + if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) { goto Admin } // check policies - if p.execPolicy(context, urlPath) { + if p.execPolicy(ctx, urlPath) { goto Admin } @@ -810,22 +807,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if routerInfo.routerType == routerTypeRESTFul { if _, ok := routerInfo.methods[r.Method]; ok { isRunnable = true - routerInfo.runFunction(context) + routerInfo.runFunction(ctx) } else { - exception("405", context) + exception("405", ctx) goto Admin } } else if routerInfo.routerType == routerTypeHandler { isRunnable = true - routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) + routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request) } else { runRouter = routerInfo.controllerType methodParams = routerInfo.methodParams method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut { method = http.MethodPut } - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete { method = http.MethodDelete } if m, ok := routerInfo.methods[method]; ok { @@ -854,7 +851,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // call the controller init function - execController.Init(context, runRouter.Name(), runMethod, execController) + execController.Init(ctx, runRouter.Name(), runMethod, execController) // call prepare function execController.Prepare() @@ -863,14 +860,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if BConfig.WebConfig.EnableXSRF { execController.XSRFToken() if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || - (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { + (r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) { execController.CheckXSRFCookie() } } execController.URLMapping() - if !context.ResponseWriter.Started { + if !ctx.ResponseWriter.Started { // exec main logic switch runMethod { case http.MethodGet: @@ -893,18 +890,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !execController.HandlerFunc(runMethod) { vc := reflect.ValueOf(execController) method := vc.MethodByName(runMethod) - in := param.ConvertParams(methodParams, method.Type(), context) + in := param.ConvertParams(methodParams, method.Type(), ctx) out := method.Call(in) // For backward compatibility we only handle response if we had incoming methodParams if methodParams != nil { - p.handleParamResponse(context, execController, out) + p.handleParamResponse(ctx, execController, out) } } } // render template - if !context.ResponseWriter.Started && context.Output.Status == 0 { + if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 { if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { logs.Error(err) @@ -918,26 +915,26 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // execute middleware filters - if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { + if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) { goto Admin } - if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) { goto Admin } Admin: // admin module record QPS - statusCode := context.ResponseWriter.Status + statusCode := ctx.ResponseWriter.Status if statusCode == 0 { statusCode = 200 } - LogAccess(context, &startTime, statusCode) + LogAccess(ctx, &startTime, statusCode) timeDur := time.Since(startTime) - context.ResponseWriter.Elapsed = timeDur + ctx.ResponseWriter.Elapsed = timeDur if BConfig.Listen.EnableAdmin { pattern := "" if routerInfo != nil { @@ -956,7 +953,7 @@ Admin: if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { match := map[bool]string{true: "match", false: "nomatch"} devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", - context.Input.IP(), + ctx.Input.IP(), logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), timeDur.String(), match[findRouter], @@ -969,11 +966,19 @@ Admin: logs.Debug(devInfo) } // Call WriteHeader if status code has been set changed - if context.Output.Status != 0 { - context.ResponseWriter.WriteHeader(context.Output.Status) + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) } } +func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string { + urlPath := ctx.Request.URL.Path + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + return urlPath +} + func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { // looping in reverse order for the case when both error and value are returned and error sets the response status code for i := len(results) - 1; i >= 0; i-- { diff --git a/pkg/template_test.go b/pkg/template_test.go index 590a7bd6..af948190 100644 --- a/pkg/template_test.go +++ b/pkg/template_test.go @@ -16,12 +16,15 @@ package beego import ( "bytes" - "github.com/astaxie/beego/pkg/testdata" - "github.com/elazarl/go-bindata-assetfs" "net/http" "os" "path/filepath" "testing" + + "github.com/elazarl/go-bindata-assetfs" + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/testdata" ) var header = `{{define "header"}} @@ -46,7 +49,9 @@ var block = `{{define "block"}} {{end}}` func TestTemplate(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate") files := []string{ "header.tpl", "index.tpl", @@ -56,7 +61,8 @@ func TestTemplate(t *testing.T) { t.Fatal(err) } for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -107,7 +113,9 @@ var user = ` ` func TestRelativeTemplate(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir := filepath.Join(wkdir, "_beeTmp") //Just add dir to known viewPaths if err := AddViewPath(dir); err != nil { @@ -218,7 +226,10 @@ var output = ` ` func TestTemplateLayout(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + + dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout") files := []string{ "add.tpl", "layout_blog.tpl", @@ -226,17 +237,22 @@ func TestTemplateLayout(t *testing.T) { if err := os.MkdirAll(dir, 0777); err != nil { t.Fatal(err) } + for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { if k == 0 { - f.WriteString(add) + _, writeErr := f.WriteString(add) + assert.Nil(t, writeErr) } else if k == 1 { - f.WriteString(layoutBlog) + _, writeErr := f.WriteString(layoutBlog) + assert.Nil(t, writeErr) } - f.Close() + clErr := f.Close() + assert.Nil(t, clErr) } } if err := AddViewPath(dir); err != nil { @@ -247,6 +263,7 @@ func TestTemplateLayout(t *testing.T) { t.Fatalf("should be 2 but got %v", len(beeTemplates)) } out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { t.Fatal(err) }