1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 13:30:56 +00:00

Merge pull request #2085 from danielscottt/reset-params

adds ability to reset params after a filter runs
This commit is contained in:
astaxie 2016-08-09 07:24:08 +08:00 committed by GitHub
commit 61c9387edd
4 changed files with 117 additions and 14 deletions

View File

@ -301,6 +301,14 @@ func (input *BeegoInput) SetParam(key, val string) {
input.pnames = append(input.pnames, key) input.pnames = append(input.pnames, key)
} }
// ResetParams clears any of the input's Params
// This function is used to clear parameters so they may be reset between filter
// passes.
func (input *BeegoInput) ResetParams() {
input.pnames = input.pnames[:0]
input.pvalues = input.pvalues[:0]
}
// Query returns input data item string by a given string. // Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string { func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" { if val := input.Param(key); val != "" {

View File

@ -27,6 +27,7 @@ type FilterRouter struct {
tree *Tree tree *Tree
pattern string pattern string
returnOnOutput bool returnOnOutput bool
resetParams bool
} }
// ValidRouter checks if the current request is matched by this filter. // ValidRouter checks if the current request is matched by this filter.

View File

@ -406,20 +406,27 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
} }
// InsertFilter Add a FilterFunc with pattern rule and action constant. // InsertFilter Add a FilterFunc with pattern rule and action constant.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) // params is for:
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
// 2. determining whether or not params need to be reset.
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
mr := new(FilterRouter) mr := &FilterRouter{
mr.tree = NewTree() tree: NewTree(),
mr.pattern = pattern pattern: pattern,
mr.filterFunc = filter filterFunc: filter,
if !BConfig.RouterCaseSensitive { returnOnOutput: true,
pattern = strings.ToLower(pattern)
} }
if len(params) == 0 { if !BConfig.RouterCaseSensitive {
mr.returnOnOutput = true mr.pattern = strings.ToLower(pattern)
} else { }
paramsLen := len(params)
if paramsLen > 0 {
mr.returnOnOutput = params[0] mr.returnOnOutput = params[0]
} }
if paramsLen > 1 {
mr.resetParams = params[1]
}
mr.tree.AddRouter(pattern, true) mr.tree.AddRouter(pattern, true)
return p.insertFilterRouter(pos, mr) return p.insertFilterRouter(pos, mr)
} }
@ -581,12 +588,22 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
} }
func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
var preFilterParams map[string]string
for _, filterR := range p.filters[pos] { for _, filterR := range p.filters[pos] {
if filterR.returnOnOutput && context.ResponseWriter.Started { if filterR.returnOnOutput && context.ResponseWriter.Started {
return true return true
} }
if filterR.resetParams {
preFilterParams = context.Input.Params()
}
if ok := filterR.ValidRouter(urlPath, context); ok { if ok := filterR.ValidRouter(urlPath, context); ok {
filterR.filterFunc(context) filterR.filterFunc(context)
if filterR.resetParams {
context.Input.ResetParams()
for k, v := range preFilterParams {
context.Input.SetParam(k, v)
}
}
} }
if filterR.returnOnOutput && context.ResponseWriter.Started { if filterR.returnOnOutput && context.ResponseWriter.Started {
return true return true
@ -810,7 +827,9 @@ Admin:
var devInfo string var devInfo string
statusCode := context.ResponseWriter.Status statusCode := context.ResponseWriter.Status
if statusCode == 0 { statusCode = 200 } if statusCode == 0 {
statusCode = 200
}
iswin := (runtime.GOOS == "windows") iswin := (runtime.GOOS == "windows")
statusColor := logs.ColorByStatus(iswin, statusCode) statusColor := logs.ColorByStatus(iswin, statusCode)
@ -819,9 +838,9 @@ Admin:
if findRouter { if findRouter {
if routerInfo != nil { if routerInfo != nil {
devInfo = fmt.Sprintf("|%s %3d %s|%13s|%8s|%s %s %-7s %-3s r:%s", statusColor, statusCode, devInfo = fmt.Sprintf("|%s %3d %s|%13s|%8s|%s %s %-7s %-3s r:%s", statusColor, statusCode,
resetColor, timeDur.String(), "match", methodColor, resetColor, r.Method, r.URL.Path, resetColor, timeDur.String(), "match", methodColor, resetColor, r.Method, r.URL.Path,
routerInfo.pattern) routerInfo.pattern)
} else { } else {
devInfo = fmt.Sprintf("|%s %3d %s|%13s|%8s|%s %s %-7s %-3s", statusColor, statusCode, resetColor, devInfo = fmt.Sprintf("|%s %3d %s|%13s|%8s|%s %s %-7s %-3s", statusColor, statusCode, resetColor,
timeDur.String(), "match", methodColor, resetColor, r.Method, r.URL.Path) timeDur.String(), "match", methodColor, resetColor, r.Method, r.URL.Path)

View File

@ -420,6 +420,74 @@ func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request
return recorder, request return recorder, request
} }
// Expectation: A Filter with the correct configuration should be created given
// specific parameters.
func TestInsertFilter(t *testing.T) {
testName := "TestInsertFilter"
mux := NewControllerRegister()
mux.InsertFilter("*", BeforeRouter, func(*context.Context) {})
if !mux.filters[BeforeRouter][0].returnOnOutput {
t.Errorf(
"%s: passing no variadic params should set returnOnOutput to true",
testName)
}
if mux.filters[BeforeRouter][0].resetParams {
t.Errorf(
"%s: passing no variadic params should set resetParams to false",
testName)
}
mux = NewControllerRegister()
mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false)
if mux.filters[BeforeRouter][0].returnOnOutput {
t.Errorf(
"%s: passing false as 1st variadic param should set returnOnOutput to false",
testName)
}
mux = NewControllerRegister()
mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true)
if !mux.filters[BeforeRouter][0].resetParams {
t.Errorf(
"%s: passing true as 2nd variadic param should set resetParams to true",
testName)
}
}
// Expectation: the second variadic arg should cause the execution of the filter
// to preserve the parameters from before its execution.
func TestParamResetFilter(t *testing.T) {
testName := "TestParamResetFilter"
route := "/beego/*" // splat
path := "/beego/routes/routes"
mux := NewControllerRegister()
mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true)
mux.Get(route, beegoHandleResetParams)
rw, r := testRequest("GET", path)
mux.ServeHTTP(rw, r)
// The two functions, `beegoResetParams` and `beegoHandleResetParams` add
// a response header of `Splat`. The expectation here is that that Header
// value should match what the _request's_ router set, not the filter's.
headers := rw.HeaderMap
if len(headers["Splat"]) != 1 {
t.Errorf(
"%s: There was an error in the test. Splat param not set in Header",
testName)
}
if headers["Splat"][0] != "routes/routes" {
t.Errorf(
"%s: expected `:splat` param to be [routes/routes] but it was [%s]",
testName, headers["Splat"][0])
}
}
// Execution point: BeforeRouter // Execution point: BeforeRouter
// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle // expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle
func TestFilterBeforeRouter(t *testing.T) { func TestFilterBeforeRouter(t *testing.T) {
@ -612,3 +680,10 @@ func beegoFinishRouter1(ctx *context.Context) {
func beegoFinishRouter2(ctx *context.Context) { func beegoFinishRouter2(ctx *context.Context) {
ctx.WriteString("|FinishRouter2") ctx.WriteString("|FinishRouter2")
} }
func beegoResetParams(ctx *context.Context) {
ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
}
func beegoHandleResetParams(ctx *context.Context) {
ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
}