1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 03:10:58 +00:00

support filter chain

This commit is contained in:
Ming Deng 2020-07-31 21:43:11 +08:00
parent 9e1346ef4d
commit 79ffef90e3
8 changed files with 261 additions and 108 deletions

3
.gitignore vendored
View File

@ -4,3 +4,6 @@
*.swp *.swp
*.swo *.swo
beego.iml beego.iml
_beeTmp
_beeTmp2

View File

@ -6,10 +6,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
) )
@ -230,10 +231,19 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) {
t.Errorf("invalid response map length: got %d want %d", t.Errorf("invalid response map length: got %d want %d",
len(decodedResponseBody), len(expectedResponseBody)) len(decodedResponseBody), len(expectedResponseBody))
} }
assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody))
assert.Equal(t, 2, len(decodedResponseBody))
if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { var database, cache map[string]interface{}
t.Errorf("handler returned unexpected body: got %v want %v", if decodedResponseBody[0]["message"] == "database" {
decodedResponseBody, expectedResponseBody) database = decodedResponseBody[0]
cache = decodedResponseBody[1]
} else {
database = decodedResponseBody[1]
cache = decodedResponseBody[0]
} }
assert.Equal(t, expectedResponseBody[0], database)
assert.Equal(t, expectedResponseBody[1], cache)
} }

View File

@ -495,3 +495,10 @@ func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *A
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
return BeeApp return BeeApp
} }
// InsertFilterChain adds a FilterFunc built by filterChain.
// This filter will be executed before all filters.
func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App {
BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...)
return BeeApp
}

View File

@ -19,6 +19,8 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/context" "github.com/astaxie/beego/pkg/context"
"os" "os"
"path/filepath" "path/filepath"
@ -125,8 +127,10 @@ func TestGetUint64(t *testing.T) {
} }
func TestAdditionalViewPaths(t *testing.T) { func TestAdditionalViewPaths(t *testing.T) {
dir1 := "_beeTmp" wkdir, err := os.Getwd()
dir2 := "_beeTmp2" assert.Nil(t, err)
dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths")
dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths")
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer os.RemoveAll(dir2) defer os.RemoveAll(dir2)

View File

@ -14,10 +14,19 @@
package beego package beego
import "github.com/astaxie/beego/pkg/context" import (
"strings"
"github.com/astaxie/beego/pkg/context"
)
// FilterChain is different from pure FilterFunc
// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned
// And all those FilterChain will be invoked before other FilterFunc
type FilterChain func(next FilterFunc) FilterFunc
// FilterFunc defines a filter function which is invoked before the controller handler is executed. // FilterFunc defines a filter function which is invoked before the controller handler is executed.
type FilterFunc func(*context.Context) type FilterFunc func(ctx *context.Context)
// FilterRouter defines a filter operation which is invoked before the controller handler is executed. // FilterRouter defines a filter operation which is invoked before the controller handler is executed.
// It can match the URL against a pattern, and execute a filter function // It can match the URL against a pattern, and execute a filter function
@ -30,6 +39,55 @@ type FilterRouter struct {
resetParams bool resetParams bool
} }
// params is for:
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
// 2. determining whether or not params need to be reset.
func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter {
mr := &FilterRouter{
tree: NewTree(),
pattern: pattern,
filterFunc: filter,
returnOnOutput: true,
}
if !routerCaseSensitive {
mr.pattern = strings.ToLower(pattern)
}
paramsLen := len(params)
if paramsLen > 0 {
mr.returnOnOutput = params[0]
}
if paramsLen > 1 {
mr.resetParams = params[1]
}
mr.tree.AddRouter(pattern, true)
return mr
}
// filter will check whether we need to execute the filter logic
// return (started, done)
func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) {
if f.returnOnOutput && ctx.ResponseWriter.Started {
return true, true
}
if f.resetParams {
preFilterParams = ctx.Input.Params()
}
if ok := f.ValidRouter(urlPath, ctx); ok {
f.filterFunc(ctx)
if f.resetParams {
ctx.Input.ResetParams()
for k, v := range preFilterParams {
ctx.Input.SetParam(k, v)
}
}
}
if f.returnOnOutput && ctx.ResponseWriter.Started {
return true, true
}
return false, false
}
// ValidRouter checks if the current request is matched by this filter. // ValidRouter checks if the current request is matched by this filter.
// If the request is matched, the values of the URL parameters defined // If the request is matched, the values of the URL parameters defined
// by the filter pattern are also returned. // by the filter pattern are also returned.

49
pkg/filter_chain_test.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/context"
)
func TestControllerRegister_InsertFilterChain(t *testing.T) {
InsertFilterChain("/*", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("filter", "filter-chain")
next(ctx)
}
})
ns := NewNamespace("/chain")
ns.Get("/*", func(ctx *context.Context) {
ctx.Output.Body([]byte("hello"))
})
r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.ServeHTTP(w, r)
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
}

View File

@ -134,11 +134,14 @@ type ControllerRegister struct {
enableFilter bool enableFilter bool
filters [FinishRouter + 1][]*FilterRouter filters [FinishRouter + 1][]*FilterRouter
pool sync.Pool pool sync.Pool
// the filter created by FilterChain
chainRoot *FilterRouter
} }
// NewControllerRegister returns a new ControllerRegister. // NewControllerRegister returns a new ControllerRegister.
func NewControllerRegister() *ControllerRegister { func NewControllerRegister() *ControllerRegister {
return &ControllerRegister{ res := &ControllerRegister{
routers: make(map[string]*Tree), routers: make(map[string]*Tree),
policies: make(map[string]*Tree), policies: make(map[string]*Tree),
pool: sync.Pool{ pool: sync.Pool{
@ -147,6 +150,8 @@ func NewControllerRegister() *ControllerRegister {
}, },
}, },
} }
res.chainRoot = newFilterRouter("/*", false, res.serveHttp)
return res
} }
// Add controller handler and pattern rules to ControllerRegister. // Add controller handler and pattern rules to ControllerRegister.
@ -489,27 +494,28 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
// 1. setting the returnOnOutput value (false allows multiple filters to execute) // 1. setting the returnOnOutput value (false allows multiple filters to execute)
// 2. determining whether or not params need to be reset. // 2. determining whether or not params need to be reset.
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
mr := &FilterRouter{ mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...)
tree: NewTree(),
pattern: pattern,
filterFunc: filter,
returnOnOutput: true,
}
if !BConfig.RouterCaseSensitive {
mr.pattern = strings.ToLower(pattern)
}
paramsLen := len(params)
if paramsLen > 0 {
mr.returnOnOutput = params[0]
}
if paramsLen > 1 {
mr.resetParams = params[1]
}
mr.tree.AddRouter(pattern, true)
return p.insertFilterRouter(pos, mr) return p.insertFilterRouter(pos, mr)
} }
// InsertFilterChain is similar to InsertFilter,
// but it will using chainRoot.filterFunc as input to build a new filterFunc
// for example, assume that chainRoot is funcA
// and we add new FilterChain
// fc := func(next) {
// return func(ctx) {
// // do something
// next(ctx)
// // do something
// }
// }
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params...bool) {
root := p.chainRoot
filterFunc := chain(root.filterFunc)
p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...)
}
// add Filter into // add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
if pos < BeforeStatic || pos > FinishRouter { if pos < BeforeStatic || pos > FinishRouter {
@ -668,23 +674,9 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str
func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
var preFilterParams map[string]string var preFilterParams map[string]string
for _, filterR := range p.filters[pos] { for _, filterR := range p.filters[pos] {
if filterR.returnOnOutput && context.ResponseWriter.Started { b, done := filterR.filter(context, urlPath, preFilterParams)
return true if done {
} return b
if filterR.resetParams {
preFilterParams = context.Input.Params()
}
if ok := filterR.ValidRouter(urlPath, context); ok {
filterR.filterFunc(context)
if filterR.resetParams {
context.Input.ResetParams()
for k, v := range preFilterParams {
context.Input.SetParam(k, v)
}
}
}
if filterR.returnOnOutput && context.ResponseWriter.Started {
return true
} }
} }
return false return false
@ -692,7 +684,20 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
// Implement http.Handler interface. // Implement http.Handler interface.
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ctx := p.GetContext()
ctx.Reset(rw, r)
defer p.GiveBackContext(ctx)
var preFilterParams map[string]string
p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams)
}
func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
startTime := time.Now() startTime := time.Now()
r := ctx.Request
rw := ctx.ResponseWriter.ResponseWriter
var ( var (
runRouter reflect.Type runRouter reflect.Type
findRouter bool findRouter bool
@ -701,108 +706,100 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo *ControllerInfo routerInfo *ControllerInfo
isRunnable bool isRunnable bool
) )
context := p.GetContext()
context.Reset(rw, r)
defer p.GiveBackContext(context)
if BConfig.RecoverFunc != nil { if BConfig.RecoverFunc != nil {
defer BConfig.RecoverFunc(context) defer BConfig.RecoverFunc(ctx)
} }
context.Output.EnableGzip = BConfig.EnableGzip ctx.Output.EnableGzip = BConfig.EnableGzip
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
context.Output.Header("Server", BConfig.ServerName) ctx.Output.Header("Server", BConfig.ServerName)
} }
var urlPath = r.URL.Path urlPath := p.getUrlPath(ctx)
if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath)
}
// filter wrong http method // filter wrong http method
if !HTTPMETHOD[r.Method] { if !HTTPMETHOD[r.Method] {
exception("405", context) exception("405", ctx)
goto Admin goto Admin
} }
// filter for static file // filter for static file
if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) {
goto Admin goto Admin
} }
serverStaticRouter(context) serverStaticRouter(ctx)
if context.ResponseWriter.Started { if ctx.ResponseWriter.Started {
findRouter = true findRouter = true
goto Admin goto Admin
} }
if r.Method != http.MethodGet && r.Method != http.MethodHead { if r.Method != http.MethodGet && r.Method != http.MethodHead {
if BConfig.CopyRequestBody && !context.Input.IsUpload() { if BConfig.CopyRequestBody && !ctx.Input.IsUpload() {
// connection will close if the incoming data are larger (RFC 7231, 6.5.11) // connection will close if the incoming data are larger (RFC 7231, 6.5.11)
if r.ContentLength > BConfig.MaxMemory { if r.ContentLength > BConfig.MaxMemory {
logs.Error(errors.New("payload too large")) logs.Error(errors.New("payload too large"))
exception("413", context) exception("413", ctx)
goto Admin goto Admin
} }
context.Input.CopyBody(BConfig.MaxMemory) ctx.Input.CopyBody(BConfig.MaxMemory)
} }
context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) ctx.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
} }
// session init // session init
if BConfig.WebConfig.Session.SessionOn { if BConfig.WebConfig.Session.SessionOn {
var err error var err error
context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil { if err != nil {
logs.Error(err) logs.Error(err)
exception("503", context) exception("503", ctx)
goto Admin goto Admin
} }
defer func() { defer func() {
if context.Input.CruSession != nil { if ctx.Input.CruSession != nil {
context.Input.CruSession.SessionRelease(rw) ctx.Input.CruSession.SessionRelease(rw)
} }
}() }()
} }
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) {
goto Admin goto Admin
} }
// User can define RunController and RunMethod in filter // User can define RunController and RunMethod in filter
if context.Input.RunController != nil && context.Input.RunMethod != "" { if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" {
findRouter = true findRouter = true
runMethod = context.Input.RunMethod runMethod = ctx.Input.RunMethod
runRouter = context.Input.RunController runRouter = ctx.Input.RunController
} else { } else {
routerInfo, findRouter = p.FindRouter(context) routerInfo, findRouter = p.FindRouter(ctx)
} }
// if no matches to url, throw a not found exception // if no matches to url, throw a not found exception
if !findRouter { if !findRouter {
exception("404", context) exception("404", ctx)
goto Admin goto Admin
} }
if splat := context.Input.Param(":splat"); splat != "" { if splat := ctx.Input.Param(":splat"); splat != "" {
for k, v := range strings.Split(splat, "/") { for k, v := range strings.Split(splat, "/") {
context.Input.SetParam(strconv.Itoa(k), v) ctx.Input.SetParam(strconv.Itoa(k), v)
} }
} }
if routerInfo != nil { if routerInfo != nil {
// store router pattern into context // store router pattern into context
context.Input.SetData("RouterPattern", routerInfo.pattern) ctx.Input.SetData("RouterPattern", routerInfo.pattern)
} }
// execute middleware filters // execute middleware filters
if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) {
goto Admin goto Admin
} }
// check policies // check policies
if p.execPolicy(context, urlPath) { if p.execPolicy(ctx, urlPath) {
goto Admin goto Admin
} }
@ -810,22 +807,22 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if routerInfo.routerType == routerTypeRESTFul { if routerInfo.routerType == routerTypeRESTFul {
if _, ok := routerInfo.methods[r.Method]; ok { if _, ok := routerInfo.methods[r.Method]; ok {
isRunnable = true isRunnable = true
routerInfo.runFunction(context) routerInfo.runFunction(ctx)
} else { } else {
exception("405", context) exception("405", ctx)
goto Admin goto Admin
} }
} else if routerInfo.routerType == routerTypeHandler { } else if routerInfo.routerType == routerTypeHandler {
isRunnable = true isRunnable = true
routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request)
} else { } else {
runRouter = routerInfo.controllerType runRouter = routerInfo.controllerType
methodParams = routerInfo.methodParams methodParams = routerInfo.methodParams
method := r.Method method := r.Method
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut {
method = http.MethodPut method = http.MethodPut
} }
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete {
method = http.MethodDelete method = http.MethodDelete
} }
if m, ok := routerInfo.methods[method]; ok { if m, ok := routerInfo.methods[method]; ok {
@ -854,7 +851,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// call the controller init function // call the controller init function
execController.Init(context, runRouter.Name(), runMethod, execController) execController.Init(ctx, runRouter.Name(), runMethod, execController)
// call prepare function // call prepare function
execController.Prepare() execController.Prepare()
@ -863,14 +860,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if BConfig.WebConfig.EnableXSRF { if BConfig.WebConfig.EnableXSRF {
execController.XSRFToken() execController.XSRFToken()
if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
(r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { (r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) {
execController.CheckXSRFCookie() execController.CheckXSRFCookie()
} }
} }
execController.URLMapping() execController.URLMapping()
if !context.ResponseWriter.Started { if !ctx.ResponseWriter.Started {
// exec main logic // exec main logic
switch runMethod { switch runMethod {
case http.MethodGet: case http.MethodGet:
@ -893,18 +890,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !execController.HandlerFunc(runMethod) { if !execController.HandlerFunc(runMethod) {
vc := reflect.ValueOf(execController) vc := reflect.ValueOf(execController)
method := vc.MethodByName(runMethod) method := vc.MethodByName(runMethod)
in := param.ConvertParams(methodParams, method.Type(), context) in := param.ConvertParams(methodParams, method.Type(), ctx)
out := method.Call(in) out := method.Call(in)
// For backward compatibility we only handle response if we had incoming methodParams // For backward compatibility we only handle response if we had incoming methodParams
if methodParams != nil { if methodParams != nil {
p.handleParamResponse(context, execController, out) p.handleParamResponse(ctx, execController, out)
} }
} }
} }
// render template // render template
if !context.ResponseWriter.Started && context.Output.Status == 0 { if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 {
if BConfig.WebConfig.AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
logs.Error(err) logs.Error(err)
@ -918,26 +915,26 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// execute middleware filters // execute middleware filters
if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) {
goto Admin goto Admin
} }
if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) {
goto Admin goto Admin
} }
Admin: Admin:
// admin module record QPS // admin module record QPS
statusCode := context.ResponseWriter.Status statusCode := ctx.ResponseWriter.Status
if statusCode == 0 { if statusCode == 0 {
statusCode = 200 statusCode = 200
} }
LogAccess(context, &startTime, statusCode) LogAccess(ctx, &startTime, statusCode)
timeDur := time.Since(startTime) timeDur := time.Since(startTime)
context.ResponseWriter.Elapsed = timeDur ctx.ResponseWriter.Elapsed = timeDur
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
pattern := "" pattern := ""
if routerInfo != nil { if routerInfo != nil {
@ -956,7 +953,7 @@ Admin:
if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
match := map[bool]string{true: "match", false: "nomatch"} match := map[bool]string{true: "match", false: "nomatch"}
devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
context.Input.IP(), ctx.Input.IP(),
logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
timeDur.String(), timeDur.String(),
match[findRouter], match[findRouter],
@ -969,11 +966,19 @@ Admin:
logs.Debug(devInfo) logs.Debug(devInfo)
} }
// Call WriteHeader if status code has been set changed // Call WriteHeader if status code has been set changed
if context.Output.Status != 0 { if ctx.Output.Status != 0 {
context.ResponseWriter.WriteHeader(context.Output.Status) ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
} }
} }
func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string {
urlPath := ctx.Request.URL.Path
if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath)
}
return urlPath
}
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
// looping in reverse order for the case when both error and value are returned and error sets the response status code // looping in reverse order for the case when both error and value are returned and error sets the response status code
for i := len(results) - 1; i >= 0; i-- { for i := len(results) - 1; i >= 0; i-- {

View File

@ -16,12 +16,15 @@ package beego
import ( import (
"bytes" "bytes"
"github.com/astaxie/beego/pkg/testdata"
"github.com/elazarl/go-bindata-assetfs"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/elazarl/go-bindata-assetfs"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/pkg/testdata"
) )
var header = `{{define "header"}} var header = `{{define "header"}}
@ -46,7 +49,9 @@ var block = `{{define "block"}}
{{end}}` {{end}}`
func TestTemplate(t *testing.T) { func TestTemplate(t *testing.T) {
dir := "_beeTmp" wkdir, err := os.Getwd()
assert.Nil(t, err)
dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate")
files := []string{ files := []string{
"header.tpl", "header.tpl",
"index.tpl", "index.tpl",
@ -56,7 +61,8 @@ func TestTemplate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
for k, name := range files { for k, name := range files {
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
assert.Nil(t, dirErr)
if f, err := os.Create(filepath.Join(dir, name)); err != nil { if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
@ -107,7 +113,9 @@ var user = `<!DOCTYPE html>
` `
func TestRelativeTemplate(t *testing.T) { func TestRelativeTemplate(t *testing.T) {
dir := "_beeTmp" wkdir, err := os.Getwd()
assert.Nil(t, err)
dir := filepath.Join(wkdir, "_beeTmp")
//Just add dir to known viewPaths //Just add dir to known viewPaths
if err := AddViewPath(dir); err != nil { if err := AddViewPath(dir); err != nil {
@ -218,7 +226,10 @@ var output = `<!DOCTYPE html>
` `
func TestTemplateLayout(t *testing.T) { func TestTemplateLayout(t *testing.T) {
dir := "_beeTmp" wkdir, err := os.Getwd()
assert.Nil(t, err)
dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout")
files := []string{ files := []string{
"add.tpl", "add.tpl",
"layout_blog.tpl", "layout_blog.tpl",
@ -226,17 +237,22 @@ func TestTemplateLayout(t *testing.T) {
if err := os.MkdirAll(dir, 0777); err != nil { if err := os.MkdirAll(dir, 0777); err != nil {
t.Fatal(err) t.Fatal(err)
} }
for k, name := range files { for k, name := range files {
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
assert.Nil(t, dirErr)
if f, err := os.Create(filepath.Join(dir, name)); err != nil { if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
if k == 0 { if k == 0 {
f.WriteString(add) _, writeErr := f.WriteString(add)
assert.Nil(t, writeErr)
} else if k == 1 { } else if k == 1 {
f.WriteString(layoutBlog) _, writeErr := f.WriteString(layoutBlog)
assert.Nil(t, writeErr)
} }
f.Close() clErr := f.Close()
assert.Nil(t, clErr)
} }
} }
if err := AddViewPath(dir); err != nil { if err := AddViewPath(dir); err != nil {
@ -247,6 +263,7 @@ func TestTemplateLayout(t *testing.T) {
t.Fatalf("should be 2 but got %v", len(beeTemplates)) t.Fatalf("should be 2 but got %v", len(beeTemplates))
} }
out := bytes.NewBufferString("") out := bytes.NewBufferString("")
if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }