mirror of
https://github.com/astaxie/beego.git
synced 2024-11-21 21:40:55 +00:00
support filter chain
This commit is contained in:
parent
9e1346ef4d
commit
79ffef90e3
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,3 +4,6 @@
|
|||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
beego.iml
|
beego.iml
|
||||||
|
|
||||||
|
_beeTmp
|
||||||
|
_beeTmp2
|
||||||
|
@ -6,10 +6,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/astaxie/beego/toolbox"
|
"github.com/astaxie/beego/toolbox"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) {
|
|||||||
t.Errorf("invalid response map length: got %d want %d",
|
t.Errorf("invalid response map length: got %d want %d",
|
||||||
len(decodedResponseBody), len(expectedResponseBody))
|
len(decodedResponseBody), len(expectedResponseBody))
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody))
|
||||||
|
assert.Equal(t, 2, len(decodedResponseBody))
|
||||||
|
|
||||||
if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) {
|
var database, cache map[string]interface{}
|
||||||
t.Errorf("handler returned unexpected body: got %v want %v",
|
if decodedResponseBody[0]["message"] == "database" {
|
||||||
decodedResponseBody, expectedResponseBody)
|
database = decodedResponseBody[0]
|
||||||
|
cache = decodedResponseBody[1]
|
||||||
|
} else {
|
||||||
|
database = decodedResponseBody[1]
|
||||||
|
cache = decodedResponseBody[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedResponseBody[0], database)
|
||||||
|
assert.Equal(t, expectedResponseBody[1], cache)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -495,3 +495,10 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A
|
|||||||
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
|
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
|
||||||
return BeeApp
|
return BeeApp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertFilterChain adds a FilterFunc built by filterChain.
|
||||||
|
// This filter will be executed before all filters.
|
||||||
|
func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App {
|
||||||
|
BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...)
|
||||||
|
return BeeApp
|
||||||
|
}
|
||||||
|
@ -19,6 +19,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/astaxie/beego/pkg/context"
|
"github.com/astaxie/beego/pkg/context"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -125,8 +127,10 @@ func TestGetUint64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAdditionalViewPaths(t *testing.T) {
|
func TestAdditionalViewPaths(t *testing.T) {
|
||||||
dir1 := "_beeTmp"
|
wkdir, err := os.Getwd()
|
||||||
dir2 := "_beeTmp2"
|
assert.Nil(t, err)
|
||||||
|
dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths")
|
||||||
|
dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths")
|
||||||
defer os.RemoveAll(dir1)
|
defer os.RemoveAll(dir1)
|
||||||
defer os.RemoveAll(dir2)
|
defer os.RemoveAll(dir2)
|
||||||
|
|
||||||
|
@ -14,10 +14,19 @@
|
|||||||
|
|
||||||
package beego
|
package beego
|
||||||
|
|
||||||
import "github.com/astaxie/beego/pkg/context"
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/pkg/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FilterChain is different from pure FilterFunc
|
||||||
|
// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned
|
||||||
|
// And all those FilterChain will be invoked before other FilterFunc
|
||||||
|
type FilterChain func(next FilterFunc) FilterFunc
|
||||||
|
|
||||||
// FilterFunc defines a filter function which is invoked before the controller handler is executed.
|
// FilterFunc defines a filter function which is invoked before the controller handler is executed.
|
||||||
type FilterFunc func(*context.Context)
|
type FilterFunc func(ctx *context.Context)
|
||||||
|
|
||||||
// FilterRouter defines a filter operation which is invoked before the controller handler is executed.
|
// FilterRouter defines a filter operation which is invoked before the controller handler is executed.
|
||||||
// It can match the URL against a pattern, and execute a filter function
|
// It can match the URL against a pattern, and execute a filter function
|
||||||
@ -30,6 +39,55 @@ type FilterRouter struct {
|
|||||||
resetParams bool
|
resetParams bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// params is for:
|
||||||
|
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
|
||||||
|
// 2. determining whether or not params need to be reset.
|
||||||
|
func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter {
|
||||||
|
mr := &FilterRouter{
|
||||||
|
tree: NewTree(),
|
||||||
|
pattern: pattern,
|
||||||
|
filterFunc: filter,
|
||||||
|
returnOnOutput: true,
|
||||||
|
}
|
||||||
|
if !routerCaseSensitive {
|
||||||
|
mr.pattern = strings.ToLower(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
paramsLen := len(params)
|
||||||
|
if paramsLen > 0 {
|
||||||
|
mr.returnOnOutput = params[0]
|
||||||
|
}
|
||||||
|
if paramsLen > 1 {
|
||||||
|
mr.resetParams = params[1]
|
||||||
|
}
|
||||||
|
mr.tree.AddRouter(pattern, true)
|
||||||
|
return mr
|
||||||
|
}
|
||||||
|
|
||||||
|
// filter will check whether we need to execute the filter logic
|
||||||
|
// return (started, done)
|
||||||
|
func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) {
|
||||||
|
if f.returnOnOutput && ctx.ResponseWriter.Started {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
if f.resetParams {
|
||||||
|
preFilterParams = ctx.Input.Params()
|
||||||
|
}
|
||||||
|
if ok := f.ValidRouter(urlPath, ctx); ok {
|
||||||
|
f.filterFunc(ctx)
|
||||||
|
if f.resetParams {
|
||||||
|
ctx.Input.ResetParams()
|
||||||
|
for k, v := range preFilterParams {
|
||||||
|
ctx.Input.SetParam(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if f.returnOnOutput && ctx.ResponseWriter.Started {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
// ValidRouter checks if the current request is matched by this filter.
|
// ValidRouter checks if the current request is matched by this filter.
|
||||||
// If the request is matched, the values of the URL parameters defined
|
// If the request is matched, the values of the URL parameters defined
|
||||||
// by the filter pattern are also returned.
|
// by the filter pattern are also returned.
|
||||||
|
49
pkg/filter_chain_test.go
Normal file
49
pkg/filter_chain_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// Copyright 2020
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package beego
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/pkg/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestControllerRegister_InsertFilterChain(t *testing.T) {
|
||||||
|
|
||||||
|
InsertFilterChain("/*", func(next FilterFunc) FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
ctx.Output.Header("filter", "filter-chain")
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ns := NewNamespace("/chain")
|
||||||
|
|
||||||
|
ns.Get("/*", func(ctx *context.Context) {
|
||||||
|
ctx.Output.Body([]byte("hello"))
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
r, _ := http.NewRequest("GET", "/chain/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
|
||||||
|
}
|
185
pkg/router.go
185
pkg/router.go
@ -134,11 +134,14 @@ type ControllerRegister struct {
|
|||||||
enableFilter bool
|
enableFilter bool
|
||||||
filters [FinishRouter + 1][]*FilterRouter
|
filters [FinishRouter + 1][]*FilterRouter
|
||||||
pool sync.Pool
|
pool sync.Pool
|
||||||
|
|
||||||
|
// the filter created by FilterChain
|
||||||
|
chainRoot *FilterRouter
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewControllerRegister returns a new ControllerRegister.
|
// NewControllerRegister returns a new ControllerRegister.
|
||||||
func NewControllerRegister() *ControllerRegister {
|
func NewControllerRegister() *ControllerRegister {
|
||||||
return &ControllerRegister{
|
res := &ControllerRegister{
|
||||||
routers: make(map[string]*Tree),
|
routers: make(map[string]*Tree),
|
||||||
policies: make(map[string]*Tree),
|
policies: make(map[string]*Tree),
|
||||||
pool: sync.Pool{
|
pool: sync.Pool{
|
||||||
@ -147,6 +150,8 @@ func NewControllerRegister() *ControllerRegister {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
res.chainRoot = newFilterRouter("/*", false, res.serveHttp)
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add controller handler and pattern rules to ControllerRegister.
|
// Add controller handler and pattern rules to ControllerRegister.
|
||||||
@ -489,27 +494,28 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
|
|||||||
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
|
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
|
||||||
// 2. determining whether or not params need to be reset.
|
// 2. determining whether or not params need to be reset.
|
||||||
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
|
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
|
||||||
mr := &FilterRouter{
|
mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...)
|
||||||
tree: NewTree(),
|
|
||||||
pattern: pattern,
|
|
||||||
filterFunc: filter,
|
|
||||||
returnOnOutput: true,
|
|
||||||
}
|
|
||||||
if !BConfig.RouterCaseSensitive {
|
|
||||||
mr.pattern = strings.ToLower(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
paramsLen := len(params)
|
|
||||||
if paramsLen > 0 {
|
|
||||||
mr.returnOnOutput = params[0]
|
|
||||||
}
|
|
||||||
if paramsLen > 1 {
|
|
||||||
mr.resetParams = params[1]
|
|
||||||
}
|
|
||||||
mr.tree.AddRouter(pattern, true)
|
|
||||||
return p.insertFilterRouter(pos, mr)
|
return p.insertFilterRouter(pos, mr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertFilterChain is similar to InsertFilter,
|
||||||
|
// but it will using chainRoot.filterFunc as input to build a new filterFunc
|
||||||
|
// for example, assume that chainRoot is funcA
|
||||||
|
// and we add new FilterChain
|
||||||
|
// fc := func(next) {
|
||||||
|
// return func(ctx) {
|
||||||
|
// // do something
|
||||||
|
// next(ctx)
|
||||||
|
// // do something
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) {
|
||||||
|
root := p.chainRoot
|
||||||
|
filterFunc := chain(root.filterFunc)
|
||||||
|
p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// add Filter into
|
// add Filter into
|
||||||
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
|
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
|
||||||
if pos < BeforeStatic || pos > FinishRouter {
|
if pos < BeforeStatic || pos > FinishRouter {
|
||||||
@ -668,23 +674,9 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str
|
|||||||
func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
|
func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
|
||||||
var preFilterParams map[string]string
|
var preFilterParams map[string]string
|
||||||
for _, filterR := range p.filters[pos] {
|
for _, filterR := range p.filters[pos] {
|
||||||
if filterR.returnOnOutput && context.ResponseWriter.Started {
|
b, done := filterR.filter(context, urlPath, preFilterParams)
|
||||||
return true
|
if done {
|
||||||
}
|
return b
|
||||||
if filterR.resetParams {
|
|
||||||
preFilterParams = context.Input.Params()
|
|
||||||
}
|
|
||||||
if ok := filterR.ValidRouter(urlPath, context); ok {
|
|
||||||
filterR.filterFunc(context)
|
|
||||||
if filterR.resetParams {
|
|
||||||
context.Input.ResetParams()
|
|
||||||
for k, v := range preFilterParams {
|
|
||||||
context.Input.SetParam(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if filterR.returnOnOutput && context.ResponseWriter.Started {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@ -692,7 +684,20 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
|
|||||||
|
|
||||||
// Implement http.Handler interface.
|
// Implement http.Handler interface.
|
||||||
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
ctx := p.GetContext()
|
||||||
|
|
||||||
|
ctx.Reset(rw, r)
|
||||||
|
defer p.GiveBackContext(ctx)
|
||||||
|
|
||||||
|
var preFilterParams map[string]string
|
||||||
|
p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
r := ctx.Request
|
||||||
|
rw := ctx.ResponseWriter.ResponseWriter
|
||||||
var (
|
var (
|
||||||
runRouter reflect.Type
|
runRouter reflect.Type
|
||||||
findRouter bool
|
findRouter bool
|
||||||
@ -701,108 +706,100 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
routerInfo *ControllerInfo
|
routerInfo *ControllerInfo
|
||||||
isRunnable bool
|
isRunnable bool
|
||||||
)
|
)
|
||||||
context := p.GetContext()
|
|
||||||
|
|
||||||
context.Reset(rw, r)
|
|
||||||
|
|
||||||
defer p.GiveBackContext(context)
|
|
||||||
if BConfig.RecoverFunc != nil {
|
if BConfig.RecoverFunc != nil {
|
||||||
defer BConfig.RecoverFunc(context)
|
defer BConfig.RecoverFunc(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
context.Output.EnableGzip = BConfig.EnableGzip
|
ctx.Output.EnableGzip = BConfig.EnableGzip
|
||||||
|
|
||||||
if BConfig.RunMode == DEV {
|
if BConfig.RunMode == DEV {
|
||||||
context.Output.Header("Server", BConfig.ServerName)
|
ctx.Output.Header("Server", BConfig.ServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
var urlPath = r.URL.Path
|
urlPath := p.getUrlPath(ctx)
|
||||||
|
|
||||||
if !BConfig.RouterCaseSensitive {
|
|
||||||
urlPath = strings.ToLower(urlPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// filter wrong http method
|
// filter wrong http method
|
||||||
if !HTTPMETHOD[r.Method] {
|
if !HTTPMETHOD[r.Method] {
|
||||||
exception("405", context)
|
exception("405", ctx)
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
// filter for static file
|
// filter for static file
|
||||||
if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) {
|
if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) {
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
serverStaticRouter(context)
|
serverStaticRouter(ctx)
|
||||||
|
|
||||||
if context.ResponseWriter.Started {
|
if ctx.ResponseWriter.Started {
|
||||||
findRouter = true
|
findRouter = true
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||||
if BConfig.CopyRequestBody && !context.Input.IsUpload() {
|
if BConfig.CopyRequestBody && !ctx.Input.IsUpload() {
|
||||||
// connection will close if the incoming data are larger (RFC 7231, 6.5.11)
|
// connection will close if the incoming data are larger (RFC 7231, 6.5.11)
|
||||||
if r.ContentLength > BConfig.MaxMemory {
|
if r.ContentLength > BConfig.MaxMemory {
|
||||||
logs.Error(errors.New("payload too large"))
|
logs.Error(errors.New("payload too large"))
|
||||||
exception("413", context)
|
exception("413", ctx)
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
context.Input.CopyBody(BConfig.MaxMemory)
|
ctx.Input.CopyBody(BConfig.MaxMemory)
|
||||||
}
|
}
|
||||||
context.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
|
ctx.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
|
||||||
}
|
}
|
||||||
|
|
||||||
// session init
|
// session init
|
||||||
if BConfig.WebConfig.Session.SessionOn {
|
if BConfig.WebConfig.Session.SessionOn {
|
||||||
var err error
|
var err error
|
||||||
context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
|
ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
exception("503", context)
|
exception("503", ctx)
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if context.Input.CruSession != nil {
|
if ctx.Input.CruSession != nil {
|
||||||
context.Input.CruSession.SessionRelease(rw)
|
ctx.Input.CruSession.SessionRelease(rw)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) {
|
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) {
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
// User can define RunController and RunMethod in filter
|
// User can define RunController and RunMethod in filter
|
||||||
if context.Input.RunController != nil && context.Input.RunMethod != "" {
|
if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" {
|
||||||
findRouter = true
|
findRouter = true
|
||||||
runMethod = context.Input.RunMethod
|
runMethod = ctx.Input.RunMethod
|
||||||
runRouter = context.Input.RunController
|
runRouter = ctx.Input.RunController
|
||||||
} else {
|
} else {
|
||||||
routerInfo, findRouter = p.FindRouter(context)
|
routerInfo, findRouter = p.FindRouter(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if no matches to url, throw a not found exception
|
// if no matches to url, throw a not found exception
|
||||||
if !findRouter {
|
if !findRouter {
|
||||||
exception("404", context)
|
exception("404", ctx)
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
if splat := context.Input.Param(":splat"); splat != "" {
|
if splat := ctx.Input.Param(":splat"); splat != "" {
|
||||||
for k, v := range strings.Split(splat, "/") {
|
for k, v := range strings.Split(splat, "/") {
|
||||||
context.Input.SetParam(strconv.Itoa(k), v)
|
ctx.Input.SetParam(strconv.Itoa(k), v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if routerInfo != nil {
|
if routerInfo != nil {
|
||||||
// store router pattern into context
|
// store router pattern into context
|
||||||
context.Input.SetData("RouterPattern", routerInfo.pattern)
|
ctx.Input.SetData("RouterPattern", routerInfo.pattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute middleware filters
|
// execute middleware filters
|
||||||
if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
|
if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) {
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
// check policies
|
// check policies
|
||||||
if p.execPolicy(context, urlPath) {
|
if p.execPolicy(ctx, urlPath) {
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -810,22 +807,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
if routerInfo.routerType == routerTypeRESTFul {
|
if routerInfo.routerType == routerTypeRESTFul {
|
||||||
if _, ok := routerInfo.methods[r.Method]; ok {
|
if _, ok := routerInfo.methods[r.Method]; ok {
|
||||||
isRunnable = true
|
isRunnable = true
|
||||||
routerInfo.runFunction(context)
|
routerInfo.runFunction(ctx)
|
||||||
} else {
|
} else {
|
||||||
exception("405", context)
|
exception("405", ctx)
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
} else if routerInfo.routerType == routerTypeHandler {
|
} else if routerInfo.routerType == routerTypeHandler {
|
||||||
isRunnable = true
|
isRunnable = true
|
||||||
routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request)
|
routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request)
|
||||||
} else {
|
} else {
|
||||||
runRouter = routerInfo.controllerType
|
runRouter = routerInfo.controllerType
|
||||||
methodParams = routerInfo.methodParams
|
methodParams = routerInfo.methodParams
|
||||||
method := r.Method
|
method := r.Method
|
||||||
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut {
|
if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut {
|
||||||
method = http.MethodPut
|
method = http.MethodPut
|
||||||
}
|
}
|
||||||
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
|
if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete {
|
||||||
method = http.MethodDelete
|
method = http.MethodDelete
|
||||||
}
|
}
|
||||||
if m, ok := routerInfo.methods[method]; ok {
|
if m, ok := routerInfo.methods[method]; ok {
|
||||||
@ -854,7 +851,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// call the controller init function
|
// call the controller init function
|
||||||
execController.Init(context, runRouter.Name(), runMethod, execController)
|
execController.Init(ctx, runRouter.Name(), runMethod, execController)
|
||||||
|
|
||||||
// call prepare function
|
// call prepare function
|
||||||
execController.Prepare()
|
execController.Prepare()
|
||||||
@ -863,14 +860,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
if BConfig.WebConfig.EnableXSRF {
|
if BConfig.WebConfig.EnableXSRF {
|
||||||
execController.XSRFToken()
|
execController.XSRFToken()
|
||||||
if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
|
if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
|
||||||
(r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) {
|
(r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) {
|
||||||
execController.CheckXSRFCookie()
|
execController.CheckXSRFCookie()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
execController.URLMapping()
|
execController.URLMapping()
|
||||||
|
|
||||||
if !context.ResponseWriter.Started {
|
if !ctx.ResponseWriter.Started {
|
||||||
// exec main logic
|
// exec main logic
|
||||||
switch runMethod {
|
switch runMethod {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
@ -893,18 +890,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
if !execController.HandlerFunc(runMethod) {
|
if !execController.HandlerFunc(runMethod) {
|
||||||
vc := reflect.ValueOf(execController)
|
vc := reflect.ValueOf(execController)
|
||||||
method := vc.MethodByName(runMethod)
|
method := vc.MethodByName(runMethod)
|
||||||
in := param.ConvertParams(methodParams, method.Type(), context)
|
in := param.ConvertParams(methodParams, method.Type(), ctx)
|
||||||
out := method.Call(in)
|
out := method.Call(in)
|
||||||
|
|
||||||
// For backward compatibility we only handle response if we had incoming methodParams
|
// For backward compatibility we only handle response if we had incoming methodParams
|
||||||
if methodParams != nil {
|
if methodParams != nil {
|
||||||
p.handleParamResponse(context, execController, out)
|
p.handleParamResponse(ctx, execController, out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// render template
|
// render template
|
||||||
if !context.ResponseWriter.Started && context.Output.Status == 0 {
|
if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 {
|
||||||
if BConfig.WebConfig.AutoRender {
|
if BConfig.WebConfig.AutoRender {
|
||||||
if err := execController.Render(); err != nil {
|
if err := execController.Render(); err != nil {
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
@ -918,26 +915,26 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute middleware filters
|
// execute middleware filters
|
||||||
if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) {
|
if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) {
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) {
|
if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) {
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
Admin:
|
Admin:
|
||||||
// admin module record QPS
|
// admin module record QPS
|
||||||
|
|
||||||
statusCode := context.ResponseWriter.Status
|
statusCode := ctx.ResponseWriter.Status
|
||||||
if statusCode == 0 {
|
if statusCode == 0 {
|
||||||
statusCode = 200
|
statusCode = 200
|
||||||
}
|
}
|
||||||
|
|
||||||
LogAccess(context, &startTime, statusCode)
|
LogAccess(ctx, &startTime, statusCode)
|
||||||
|
|
||||||
timeDur := time.Since(startTime)
|
timeDur := time.Since(startTime)
|
||||||
context.ResponseWriter.Elapsed = timeDur
|
ctx.ResponseWriter.Elapsed = timeDur
|
||||||
if BConfig.Listen.EnableAdmin {
|
if BConfig.Listen.EnableAdmin {
|
||||||
pattern := ""
|
pattern := ""
|
||||||
if routerInfo != nil {
|
if routerInfo != nil {
|
||||||
@ -956,7 +953,7 @@ Admin:
|
|||||||
if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
|
if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
|
||||||
match := map[bool]string{true: "match", false: "nomatch"}
|
match := map[bool]string{true: "match", false: "nomatch"}
|
||||||
devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
|
devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
|
||||||
context.Input.IP(),
|
ctx.Input.IP(),
|
||||||
logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
|
logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
|
||||||
timeDur.String(),
|
timeDur.String(),
|
||||||
match[findRouter],
|
match[findRouter],
|
||||||
@ -969,11 +966,19 @@ Admin:
|
|||||||
logs.Debug(devInfo)
|
logs.Debug(devInfo)
|
||||||
}
|
}
|
||||||
// Call WriteHeader if status code has been set changed
|
// Call WriteHeader if status code has been set changed
|
||||||
if context.Output.Status != 0 {
|
if ctx.Output.Status != 0 {
|
||||||
context.ResponseWriter.WriteHeader(context.Output.Status)
|
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string {
|
||||||
|
urlPath := ctx.Request.URL.Path
|
||||||
|
if !BConfig.RouterCaseSensitive {
|
||||||
|
urlPath = strings.ToLower(urlPath)
|
||||||
|
}
|
||||||
|
return urlPath
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
|
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
|
||||||
// looping in reverse order for the case when both error and value are returned and error sets the response status code
|
// looping in reverse order for the case when both error and value are returned and error sets the response status code
|
||||||
for i := len(results) - 1; i >= 0; i-- {
|
for i := len(results) - 1; i >= 0; i-- {
|
||||||
|
@ -16,12 +16,15 @@ package beego
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/astaxie/beego/pkg/testdata"
|
|
||||||
"github.com/elazarl/go-bindata-assetfs"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/elazarl/go-bindata-assetfs"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/pkg/testdata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var header = `{{define "header"}}
|
var header = `{{define "header"}}
|
||||||
@ -46,7 +49,9 @@ var block = `{{define "block"}}
|
|||||||
{{end}}`
|
{{end}}`
|
||||||
|
|
||||||
func TestTemplate(t *testing.T) {
|
func TestTemplate(t *testing.T) {
|
||||||
dir := "_beeTmp"
|
wkdir, err := os.Getwd()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate")
|
||||||
files := []string{
|
files := []string{
|
||||||
"header.tpl",
|
"header.tpl",
|
||||||
"index.tpl",
|
"index.tpl",
|
||||||
@ -56,7 +61,8 @@ func TestTemplate(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
for k, name := range files {
|
for k, name := range files {
|
||||||
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
|
dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
|
||||||
|
assert.Nil(t, dirErr)
|
||||||
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
|
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else {
|
} else {
|
||||||
@ -107,7 +113,9 @@ var user = `<!DOCTYPE html>
|
|||||||
`
|
`
|
||||||
|
|
||||||
func TestRelativeTemplate(t *testing.T) {
|
func TestRelativeTemplate(t *testing.T) {
|
||||||
dir := "_beeTmp"
|
wkdir, err := os.Getwd()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
dir := filepath.Join(wkdir, "_beeTmp")
|
||||||
|
|
||||||
//Just add dir to known viewPaths
|
//Just add dir to known viewPaths
|
||||||
if err := AddViewPath(dir); err != nil {
|
if err := AddViewPath(dir); err != nil {
|
||||||
@ -218,7 +226,10 @@ var output = `<!DOCTYPE html>
|
|||||||
`
|
`
|
||||||
|
|
||||||
func TestTemplateLayout(t *testing.T) {
|
func TestTemplateLayout(t *testing.T) {
|
||||||
dir := "_beeTmp"
|
wkdir, err := os.Getwd()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout")
|
||||||
files := []string{
|
files := []string{
|
||||||
"add.tpl",
|
"add.tpl",
|
||||||
"layout_blog.tpl",
|
"layout_blog.tpl",
|
||||||
@ -226,17 +237,22 @@ func TestTemplateLayout(t *testing.T) {
|
|||||||
if err := os.MkdirAll(dir, 0777); err != nil {
|
if err := os.MkdirAll(dir, 0777); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, name := range files {
|
for k, name := range files {
|
||||||
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
|
dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
|
||||||
|
assert.Nil(t, dirErr)
|
||||||
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
|
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else {
|
} else {
|
||||||
if k == 0 {
|
if k == 0 {
|
||||||
f.WriteString(add)
|
_, writeErr := f.WriteString(add)
|
||||||
|
assert.Nil(t, writeErr)
|
||||||
} else if k == 1 {
|
} else if k == 1 {
|
||||||
f.WriteString(layoutBlog)
|
_, writeErr := f.WriteString(layoutBlog)
|
||||||
|
assert.Nil(t, writeErr)
|
||||||
}
|
}
|
||||||
f.Close()
|
clErr := f.Close()
|
||||||
|
assert.Nil(t, clErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := AddViewPath(dir); err != nil {
|
if err := AddViewPath(dir); err != nil {
|
||||||
@ -247,6 +263,7 @@ func TestTemplateLayout(t *testing.T) {
|
|||||||
t.Fatalf("should be 2 but got %v", len(beeTemplates))
|
t.Fatalf("should be 2 but got %v", len(beeTemplates))
|
||||||
}
|
}
|
||||||
out := bytes.NewBufferString("")
|
out := bytes.NewBufferString("")
|
||||||
|
|
||||||
if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
|
if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user