From 15f04b8da467a5dc4015e58ce569eea725cd892d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=B1=AA=E8=B4=B5?= Date: Wed, 29 Jul 2020 21:57:16 +0800 Subject: [PATCH 01/10] add env BEEGO_CONFIG_PATH --- config.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config.go b/config.go index b6c9a99c..92aa3bbd 100644 --- a/config.go +++ b/config.go @@ -150,6 +150,9 @@ func init() { filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" } appConfigPath = filepath.Join(WorkPath, "conf", filename) + if configPath := os.Getenv("BEEGO_CONFIG_PATH"); configPath != "" { + appConfigPath = configPath + } if !utils.FileExists(appConfigPath) { appConfigPath = filepath.Join(AppPath, "conf", filename) if !utils.FileExists(appConfigPath) { From 15e11931fcd85128cada95daba8b72b789d34a97 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 10:53:30 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9=20BConfig.Lis?= =?UTF-8?q?ten.ClientAuth=20=E5=AD=97=E6=AE=B5=E7=9A=84=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E5=A4=84=E7=90=86=E3=80=82=E5=BD=93=E6=8C=87=E5=AE=9A=E4=BA=86?= =?UTF-8?q?=E8=AF=A5=E9=85=8D=E7=BD=AE=E6=97=B6=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=9A=84=E5=80=BC=E6=9D=A5=E4=BD=9C=E4=B8=BA?= =?UTF-8?q?=E9=AA=8C=E8=AF=81=E5=AE=A2=E6=88=B7=E7=AB=AF=E7=9A=84=E6=96=B9?= =?UTF-8?q?=E5=BC=8F=E3=80=82=E5=A6=82=E6=9E=9C=E6=B2=A1=E6=8C=87=E5=AE=9A?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E7=94=A8=E9=BB=98=E8=AE=A4=E5=80=BC=20tls.Re?= =?UTF-8?q?quireAndVerifyClientCert?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/app.go b/app.go index f3fe6f7b..5e595bb6 100644 --- a/app.go +++ b/app.go @@ -195,10 +195,15 @@ func (app *App) Run(mws ...MiddleWare) { return } pool.AppendCertsFromPEM(data) - app.Server.TLSConfig = &tls.Config{ + tlsConfig := tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, } + if string(BConfig.Listen.ClientAuth) != "" { + tslConfig.ClientAuth = BConfig.Listen.ClientAuth + } else { + tslConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + app.Server.TLSConfig = &tslConfig } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) From 513a4afff14c056f1fe5a844d8a60123987491f3 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 10:59:32 +0800 Subject: [PATCH 03/10] =?UTF-8?q?=E5=AF=B9=20Listen=20=E7=BB=93=E6=9E=84?= =?UTF-8?q?=E4=BD=93=E5=A2=9E=E5=8A=A0=20ClientAuth=20=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 对 Listen 结构体增加 ClientAuth 字段,赋予默认配置对象该字段值为 tls.VerifyClientCertIfGiven,与原代码逻辑的默认值保持一致 --- config.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config.go b/config.go index 92aa3bbd..fef6c482 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,7 @@ import ( "reflect" "runtime" "strings" + "crypto/tls" "github.com/astaxie/beego/config" "github.com/astaxie/beego/context" @@ -65,6 +66,7 @@ type Listen struct { HTTPSCertFile string HTTPSKeyFile string TrustCaFile string + ClientAuth tls.ClientAuthType EnableAdmin bool AdminAddr string AdminPort int @@ -234,6 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, + ClientAuth: tls.VerifyClientCertIfGiven, }, WebConfig: WebConfig{ AutoRender: true, From 9d23e5a3fb23df1ac647660c13f566fa17e81c1e Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 11:03:32 +0800 Subject: [PATCH 04/10] =?UTF-8?q?=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=86=99=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app.go b/app.go index 5e595bb6..85fe7e5d 100644 --- a/app.go +++ b/app.go @@ -197,11 +197,10 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) tlsConfig := tls.Config{ ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, } if string(BConfig.Listen.ClientAuth) != "" { tslConfig.ClientAuth = BConfig.Listen.ClientAuth - } else { - tslConfig.ClientAuth = tls.RequireAndVerifyClientCert } app.Server.TLSConfig = &tslConfig } From c46ba862157b9b6d5d39561985fa7f4d8a8bdd18 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 11:18:14 +0800 Subject: [PATCH 05/10] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=AC=94=E8=AF=AF?= =?UTF-8?q?=E4=BA=A7=E7=94=9F=E7=9A=84=E6=8B=BC=E5=86=99=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.go b/app.go index 85fe7e5d..20af4ce8 100644 --- a/app.go +++ b/app.go @@ -200,7 +200,7 @@ func (app *App) Run(mws ...MiddleWare) { ClientAuth: tls.RequireAndVerifyClientCert, } if string(BConfig.Listen.ClientAuth) != "" { - tslConfig.ClientAuth = BConfig.Listen.ClientAuth + tlsConfig.ClientAuth = BConfig.Listen.ClientAuth } app.Server.TLSConfig = &tslConfig } From 0815e77f9af9336b17a50d6a4226aad9a1969731 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 11:20:22 +0800 Subject: [PATCH 06/10] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=AC=94=E8=AF=AF?= =?UTF-8?q?=E4=BA=A7=E7=94=9F=E7=9A=84=E6=8B=BC=E5=86=99=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.go b/app.go index 20af4ce8..49cef256 100644 --- a/app.go +++ b/app.go @@ -202,7 +202,7 @@ func (app *App) Run(mws ...MiddleWare) { if string(BConfig.Listen.ClientAuth) != "" { tlsConfig.ClientAuth = BConfig.Listen.ClientAuth } - app.Server.TLSConfig = &tslConfig + app.Server.TLSConfig = &tlsConfig } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) From 520380416557c76a59a2c30398e027c27ad70d36 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 14:46:17 +0800 Subject: [PATCH 07/10] =?UTF-8?q?=E8=B0=83=E6=95=B4=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E4=B8=AD=E7=9A=84=20ClientAuth=20=E5=80=BC?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E4=B9=8B=E4=B8=8E=E5=8E=9F=E6=9D=A5=E7=9A=84?= =?UTF-8?q?=E8=A1=8C=E4=B8=BA=E4=BF=9D=E6=8C=81=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index fef6c482..0c995293 100644 --- a/config.go +++ b/config.go @@ -236,7 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, - ClientAuth: tls.VerifyClientCertIfGiven, + ClientAuth: tls.RequireAndVerifyClientCert, }, WebConfig: WebConfig{ AutoRender: true, From 7831638f3793d1b3058ffa0cbc1e8d0c818bd3e8 Mon Sep 17 00:00:00 2001 From: "Mr. Myy" <1135038815@qq.com> Date: Thu, 30 Jul 2020 14:48:46 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=9A=84=E6=9D=A1=E4=BB=B6=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/app.go b/app.go index 49cef256..3dee8999 100644 --- a/app.go +++ b/app.go @@ -195,14 +195,10 @@ func (app *App) Run(mws ...MiddleWare) { return } pool.AppendCertsFromPEM(data) - tlsConfig := tls.Config{ + app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: BConfig.Listen.ClientAuth, } - if string(BConfig.Listen.ClientAuth) != "" { - tlsConfig.ClientAuth = BConfig.Listen.ClientAuth - } - app.Server.TLSConfig = &tlsConfig } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err) From a0d1c42daca7af6cbf3a6c73a89793a06cdbd4c7 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 3 Aug 2020 21:03:08 +0800 Subject: [PATCH 09/10] XSRF add secure and http only flag --- context/context.go | 2 +- context/context_test.go | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/context/context.go b/context/context.go index de248ed2..7c161ac0 100644 --- a/context/context.go +++ b/context/context.go @@ -150,7 +150,7 @@ func (ctx *Context) XSRFToken(key string, expire int64) string { token, ok := ctx.GetSecureCookie(key, "_xsrf") if !ok { token = string(utils.RandomCreateBytes(32)) - ctx.SetSecureCookie(key, "_xsrf", token, expire) + ctx.SetSecureCookie(key, "_xsrf", token, expire, "", "", true, true) } ctx._xsrfToken = token } diff --git a/context/context_test.go b/context/context_test.go index 7c0535e0..e81e8191 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -17,7 +17,10 @@ package context import ( "net/http" "net/http/httptest" + "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestXsrfReset_01(t *testing.T) { @@ -44,4 +47,8 @@ func TestXsrfReset_01(t *testing.T) { if token == c._xsrfToken { t.FailNow() } + + ck := c.ResponseWriter.Header().Get("Set-Cookie") + assert.True(t, strings.Contains(ck, "Secure")) + assert.True(t, strings.Contains(ck, "HttpOnly")) } From 79ffef90e37ac2692fe694b8ebae01ce2fbe2794 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 31 Jul 2020 21:43:11 +0800 Subject: [PATCH 10/10] support filter chain --- .gitignore | 3 + admin_test.go | 18 +++- pkg/app.go | 7 ++ pkg/controller_test.go | 8 +- pkg/filter.go | 62 ++++++++++++- pkg/filter_chain_test.go | 49 +++++++++++ pkg/router.go | 185 ++++++++++++++++++++------------------- pkg/template_test.go | 37 +++++--- 8 files changed, 261 insertions(+), 108 deletions(-) create mode 100644 pkg/filter_chain_test.go diff --git a/.gitignore b/.gitignore index e1b65291..43adebd5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ *.swp *.swo beego.iml + +_beeTmp +_beeTmp2 diff --git a/admin_test.go b/admin_test.go index 3f3612e4..205c76c2 100644 --- a/admin_test.go +++ b/admin_test.go @@ -6,10 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/toolbox" ) @@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { t.Errorf("invalid response map length: got %d want %d", len(decodedResponseBody), len(expectedResponseBody)) } + assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) + assert.Equal(t, 2, len(decodedResponseBody)) - if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { - t.Errorf("handler returned unexpected body: got %v want %v", - decodedResponseBody, expectedResponseBody) + var database, cache map[string]interface{} + if decodedResponseBody[0]["message"] == "database" { + database = decodedResponseBody[0] + cache = decodedResponseBody[1] + } else { + database = decodedResponseBody[1] + cache = decodedResponseBody[0] } + assert.Equal(t, expectedResponseBody[0], database) + assert.Equal(t, expectedResponseBody[1], cache) + } diff --git a/pkg/app.go b/pkg/app.go index eb672b1f..d94d56b5 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -495,3 +495,10 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) return BeeApp } + +// InsertFilterChain adds a FilterFunc built by filterChain. +// This filter will be executed before all filters. +func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) + return BeeApp +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go index f51cc109..e30f7211 100644 --- a/pkg/controller_test.go +++ b/pkg/controller_test.go @@ -19,6 +19,8 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/astaxie/beego/pkg/context" "os" "path/filepath" @@ -125,8 +127,10 @@ func TestGetUint64(t *testing.T) { } func TestAdditionalViewPaths(t *testing.T) { - dir1 := "_beeTmp" - dir2 := "_beeTmp2" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths") + dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths") defer os.RemoveAll(dir1) defer os.RemoveAll(dir2) diff --git a/pkg/filter.go b/pkg/filter.go index 4e212e06..543d7901 100644 --- a/pkg/filter.go +++ b/pkg/filter.go @@ -14,10 +14,19 @@ package beego -import "github.com/astaxie/beego/pkg/context" +import ( + "strings" + + "github.com/astaxie/beego/pkg/context" +) + +// FilterChain is different from pure FilterFunc +// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned +// And all those FilterChain will be invoked before other FilterFunc +type FilterChain func(next FilterFunc) FilterFunc // FilterFunc defines a filter function which is invoked before the controller handler is executed. -type FilterFunc func(*context.Context) +type FilterFunc func(ctx *context.Context) // FilterRouter defines a filter operation which is invoked before the controller handler is executed. // It can match the URL against a pattern, and execute a filter function @@ -30,6 +39,55 @@ type FilterRouter struct { resetParams bool } +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + returnOnOutput: true, + } + if !routerCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + paramsLen := len(params) + if paramsLen > 0 { + mr.returnOnOutput = params[0] + } + if paramsLen > 1 { + mr.resetParams = params[1] + } + mr.tree.AddRouter(pattern, true) + return mr +} + +// filter will check whether we need to execute the filter logic +// return (started, done) +func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) { + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + if f.resetParams { + preFilterParams = ctx.Input.Params() + } + if ok := f.ValidRouter(urlPath, ctx); ok { + f.filterFunc(ctx) + if f.resetParams { + ctx.Input.ResetParams() + for k, v := range preFilterParams { + ctx.Input.SetParam(k, v) + } + } + } + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + return false, false +} + // ValidRouter checks if the current request is matched by this filter. // If the request is matched, the values of the URL parameters defined // by the filter pattern are also returned. diff --git a/pkg/filter_chain_test.go b/pkg/filter_chain_test.go new file mode 100644 index 00000000..42397a60 --- /dev/null +++ b/pkg/filter_chain_test.go @@ -0,0 +1,49 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/context" +) + +func TestControllerRegister_InsertFilterChain(t *testing.T) { + + InsertFilterChain("/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", "filter-chain") + next(ctx) + } + }) + + ns := NewNamespace("/chain") + + ns.Get("/*", func(ctx *context.Context) { + ctx.Output.Body([]byte("hello")) + }) + + + r, _ := http.NewRequest("GET", "/chain/user", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.ServeHTTP(w, r) + + assert.Equal(t, "filter-chain", w.Header().Get("filter")) +} diff --git a/pkg/router.go b/pkg/router.go index 995fb767..b0c23003 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -134,11 +134,14 @@ type ControllerRegister struct { enableFilter bool filters [FinishRouter + 1][]*FilterRouter pool sync.Pool + + // the filter created by FilterChain + chainRoot *FilterRouter } // NewControllerRegister returns a new ControllerRegister. func NewControllerRegister() *ControllerRegister { - return &ControllerRegister{ + res := &ControllerRegister{ routers: make(map[string]*Tree), policies: make(map[string]*Tree), pool: sync.Pool{ @@ -147,6 +150,8 @@ func NewControllerRegister() *ControllerRegister { }, }, } + res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + return res } // Add controller handler and pattern rules to ControllerRegister. @@ -489,27 +494,28 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, - } - if !BConfig.RouterCaseSensitive { - mr.pattern = strings.ToLower(pattern) - } - - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } - mr.tree.AddRouter(pattern, true) + mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) return p.insertFilterRouter(pos, mr) } +// InsertFilterChain is similar to InsertFilter, +// but it will using chainRoot.filterFunc as input to build a new filterFunc +// for example, assume that chainRoot is funcA +// and we add new FilterChain +// fc := func(next) { +// return func(ctx) { +// // do something +// next(ctx) +// // do something +// } +// } +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) { + root := p.chainRoot + filterFunc := chain(root.filterFunc) + p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) +} + + // add Filter into func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { if pos < BeforeStatic || pos > FinishRouter { @@ -668,23 +674,9 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { var preFilterParams map[string]string for _, filterR := range p.filters[pos] { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if filterR.resetParams { - preFilterParams = context.Input.Params() - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - if filterR.resetParams { - context.Input.ResetParams() - for k, v := range preFilterParams { - context.Input.SetParam(k, v) - } - } - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true + b, done := filterR.filter(context, urlPath, preFilterParams) + if done { + return b } } return false @@ -692,7 +684,20 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str // Implement http.Handler interface. func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + + ctx := p.GetContext() + + ctx.Reset(rw, r) + defer p.GiveBackContext(ctx) + + var preFilterParams map[string]string + p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams) +} + +func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { startTime := time.Now() + r := ctx.Request + rw := ctx.ResponseWriter.ResponseWriter var ( runRouter reflect.Type findRouter bool @@ -701,108 +706,100 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) routerInfo *ControllerInfo isRunnable bool ) - context := p.GetContext() - context.Reset(rw, r) - - defer p.GiveBackContext(context) if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(context) + defer BConfig.RecoverFunc(ctx) } - context.Output.EnableGzip = BConfig.EnableGzip + ctx.Output.EnableGzip = BConfig.EnableGzip if BConfig.RunMode == DEV { - context.Output.Header("Server", BConfig.ServerName) + ctx.Output.Header("Server", BConfig.ServerName) } - var urlPath = r.URL.Path - - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } + urlPath := p.getUrlPath(ctx) // filter wrong http method if !HTTPMETHOD[r.Method] { - exception("405", context) + exception("405", ctx) goto Admin } // filter for static file - if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) { goto Admin } - serverStaticRouter(context) + serverStaticRouter(ctx) - if context.ResponseWriter.Started { + if ctx.ResponseWriter.Started { findRouter = true goto Admin } if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !context.Input.IsUpload() { + if BConfig.CopyRequestBody && !ctx.Input.IsUpload() { // connection will close if the incoming data are larger (RFC 7231, 6.5.11) if r.ContentLength > BConfig.MaxMemory { logs.Error(errors.New("payload too large")) - exception("413", context) + exception("413", ctx) goto Admin } - context.Input.CopyBody(BConfig.MaxMemory) + ctx.Input.CopyBody(BConfig.MaxMemory) } - context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + ctx.Input.ParseFormOrMulitForm(BConfig.MaxMemory) } // session init if BConfig.WebConfig.Session.SessionOn { var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) - exception("503", context) + exception("503", ctx) goto Admin } defer func() { - if context.Input.CruSession != nil { - context.Input.CruSession.SessionRelease(rw) + if ctx.Input.CruSession != nil { + ctx.Input.CruSession.SessionRelease(rw) } }() } - if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) { goto Admin } // User can define RunController and RunMethod in filter - if context.Input.RunController != nil && context.Input.RunMethod != "" { + if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" { findRouter = true - runMethod = context.Input.RunMethod - runRouter = context.Input.RunController + runMethod = ctx.Input.RunMethod + runRouter = ctx.Input.RunController } else { - routerInfo, findRouter = p.FindRouter(context) + routerInfo, findRouter = p.FindRouter(ctx) } // if no matches to url, throw a not found exception if !findRouter { - exception("404", context) + exception("404", ctx) goto Admin } - if splat := context.Input.Param(":splat"); splat != "" { + if splat := ctx.Input.Param(":splat"); splat != "" { for k, v := range strings.Split(splat, "/") { - context.Input.SetParam(strconv.Itoa(k), v) + ctx.Input.SetParam(strconv.Itoa(k), v) } } if routerInfo != nil { // store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) + ctx.Input.SetData("RouterPattern", routerInfo.pattern) } // execute middleware filters - if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { + if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) { goto Admin } // check policies - if p.execPolicy(context, urlPath) { + if p.execPolicy(ctx, urlPath) { goto Admin } @@ -810,22 +807,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if routerInfo.routerType == routerTypeRESTFul { if _, ok := routerInfo.methods[r.Method]; ok { isRunnable = true - routerInfo.runFunction(context) + routerInfo.runFunction(ctx) } else { - exception("405", context) + exception("405", ctx) goto Admin } } else if routerInfo.routerType == routerTypeHandler { isRunnable = true - routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) + routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request) } else { runRouter = routerInfo.controllerType methodParams = routerInfo.methodParams method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut { method = http.MethodPut } - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete { method = http.MethodDelete } if m, ok := routerInfo.methods[method]; ok { @@ -854,7 +851,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // call the controller init function - execController.Init(context, runRouter.Name(), runMethod, execController) + execController.Init(ctx, runRouter.Name(), runMethod, execController) // call prepare function execController.Prepare() @@ -863,14 +860,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if BConfig.WebConfig.EnableXSRF { execController.XSRFToken() if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || - (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { + (r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) { execController.CheckXSRFCookie() } } execController.URLMapping() - if !context.ResponseWriter.Started { + if !ctx.ResponseWriter.Started { // exec main logic switch runMethod { case http.MethodGet: @@ -893,18 +890,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !execController.HandlerFunc(runMethod) { vc := reflect.ValueOf(execController) method := vc.MethodByName(runMethod) - in := param.ConvertParams(methodParams, method.Type(), context) + in := param.ConvertParams(methodParams, method.Type(), ctx) out := method.Call(in) // For backward compatibility we only handle response if we had incoming methodParams if methodParams != nil { - p.handleParamResponse(context, execController, out) + p.handleParamResponse(ctx, execController, out) } } } // render template - if !context.ResponseWriter.Started && context.Output.Status == 0 { + if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 { if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { logs.Error(err) @@ -918,26 +915,26 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // execute middleware filters - if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { + if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) { goto Admin } - if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) { goto Admin } Admin: // admin module record QPS - statusCode := context.ResponseWriter.Status + statusCode := ctx.ResponseWriter.Status if statusCode == 0 { statusCode = 200 } - LogAccess(context, &startTime, statusCode) + LogAccess(ctx, &startTime, statusCode) timeDur := time.Since(startTime) - context.ResponseWriter.Elapsed = timeDur + ctx.ResponseWriter.Elapsed = timeDur if BConfig.Listen.EnableAdmin { pattern := "" if routerInfo != nil { @@ -956,7 +953,7 @@ Admin: if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { match := map[bool]string{true: "match", false: "nomatch"} devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", - context.Input.IP(), + ctx.Input.IP(), logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), timeDur.String(), match[findRouter], @@ -969,11 +966,19 @@ Admin: logs.Debug(devInfo) } // Call WriteHeader if status code has been set changed - if context.Output.Status != 0 { - context.ResponseWriter.WriteHeader(context.Output.Status) + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) } } +func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string { + urlPath := ctx.Request.URL.Path + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + return urlPath +} + func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { // looping in reverse order for the case when both error and value are returned and error sets the response status code for i := len(results) - 1; i >= 0; i-- { diff --git a/pkg/template_test.go b/pkg/template_test.go index 590a7bd6..af948190 100644 --- a/pkg/template_test.go +++ b/pkg/template_test.go @@ -16,12 +16,15 @@ package beego import ( "bytes" - "github.com/astaxie/beego/pkg/testdata" - "github.com/elazarl/go-bindata-assetfs" "net/http" "os" "path/filepath" "testing" + + "github.com/elazarl/go-bindata-assetfs" + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/testdata" ) var header = `{{define "header"}} @@ -46,7 +49,9 @@ var block = `{{define "block"}} {{end}}` func TestTemplate(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate") files := []string{ "header.tpl", "index.tpl", @@ -56,7 +61,8 @@ func TestTemplate(t *testing.T) { t.Fatal(err) } for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -107,7 +113,9 @@ var user = ` ` func TestRelativeTemplate(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + dir := filepath.Join(wkdir, "_beeTmp") //Just add dir to known viewPaths if err := AddViewPath(dir); err != nil { @@ -218,7 +226,10 @@ var output = ` ` func TestTemplateLayout(t *testing.T) { - dir := "_beeTmp" + wkdir, err := os.Getwd() + assert.Nil(t, err) + + dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout") files := []string{ "add.tpl", "layout_blog.tpl", @@ -226,17 +237,22 @@ func TestTemplateLayout(t *testing.T) { if err := os.MkdirAll(dir, 0777); err != nil { t.Fatal(err) } + for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { if k == 0 { - f.WriteString(add) + _, writeErr := f.WriteString(add) + assert.Nil(t, writeErr) } else if k == 1 { - f.WriteString(layoutBlog) + _, writeErr := f.WriteString(layoutBlog) + assert.Nil(t, writeErr) } - f.Close() + clErr := f.Close() + assert.Nil(t, clErr) } } if err := AddViewPath(dir); err != nil { @@ -247,6 +263,7 @@ func TestTemplateLayout(t *testing.T) { t.Fatalf("should be 2 but got %v", len(beeTemplates)) } out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { t.Fatal(err) }