From 50f3bd5835cd0c1c2d1ffc879005f7813895acb5 Mon Sep 17 00:00:00 2001 From: astaxie Date: Mon, 12 Aug 2013 00:14:42 +0800 Subject: [PATCH] add filter after --- beego.go | 30 +++++++++++++++++++++++++ router.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/beego.go b/beego.go index f6bf521c..d0e34612 100644 --- a/beego.go +++ b/beego.go @@ -161,6 +161,21 @@ func (app *App) FilterPrefixPath(path string, filter http.HandlerFunc) *App { return app } +func (app *App) FilterAfter(filter http.HandlerFunc) *App { + app.Handlers.FilterAfter(filter) + return app +} + +func (app *App) FilterParamAfter(param string, filter http.HandlerFunc) *App { + app.Handlers.FilterParamAfter(param, filter) + return app +} + +func (app *App) FilterPrefixPathAfter(path string, filter http.HandlerFunc) *App { + app.Handlers.FilterPrefixPathAfter(path, filter) + return app +} + func (app *App) SetViewsPath(path string) *App { ViewsPath = path return app @@ -245,6 +260,21 @@ func FilterPrefixPath(path string, filter http.HandlerFunc) *App { return BeeApp } +func FilterAfter(filter http.HandlerFunc) *App { + BeeApp.FilterAfter(filter) + return BeeApp +} + +func FilterParamAfter(param string, filter http.HandlerFunc) *App { + BeeApp.FilterParamAfter(param, filter) + return BeeApp +} + +func FilterPrefixPathAfter(path string, filter http.HandlerFunc) *App { + BeeApp.FilterPrefixPathAfter(path, filter) + return BeeApp +} + func Run() { if AppConfigPath != path.Join(AppPath, "conf", "app.conf") { err := ParseConfig() diff --git a/router.go b/router.go index ed3287f9..157f0618 100644 --- a/router.go +++ b/router.go @@ -37,6 +37,7 @@ type ControllerRegistor struct { fixrouters []*controllerInfo enableFilter bool filters []http.HandlerFunc + enableAfter bool afterFilters []http.HandlerFunc enableUser bool userHandlers map[string]*userHandler @@ -257,6 +258,35 @@ func (p *ControllerRegistor) FilterPrefixPath(path string, filter http.HandlerFu }) } +// Filter adds the middleware after filter. +func (p *ControllerRegistor) FilterAfter(filter http.HandlerFunc) { + p.enableAfter = true + p.afterFilters = append(p.afterFilters, filter) +} + +// FilterParam adds the middleware filter if the REST URL parameter exists. +func (p *ControllerRegistor) FilterParamAfter(param string, filter http.HandlerFunc) { + if !strings.HasPrefix(param, ":") { + param = ":" + param + } + + p.FilterAfter(func(w http.ResponseWriter, r *http.Request) { + p := r.URL.Query().Get(param) + if len(p) > 0 { + filter(w, r) + } + }) +} + +// FilterPrefixPath adds the middleware filter if the prefix path exists. +func (p *ControllerRegistor) FilterPrefixPathAfter(path string, filter http.HandlerFunc) { + p.FilterAfter(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, path) { + filter(w, r) + } + }) +} + // AutoRoute func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) { defer func() { @@ -440,10 +470,12 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if runrouter != nil { //execute middleware filters - for _, filter := range p.filters { - filter(w, r) - if w.started { - return + if p.enableFilter { + for _, filter := range p.filters { + filter(w, r) + if w.started { + return + } } } @@ -587,6 +619,15 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } method = vc.MethodByName("Finish") method.Call(in) + //execute middleware filters + if p.enableAfter { + for _, filter := range p.afterFilters { + filter(w, r) + if w.started { + return + } + } + } method = vc.MethodByName("Destructor") method.Call(in) } @@ -608,6 +649,15 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/") { for mName, controllerType := range methodmap { if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/"+strings.ToLower(mName)) { + //execute middleware filters + if p.enableFilter { + for _, filter := range p.filters { + filter(w, r) + if w.started { + return + } + } + } //parse params otherurl := requestPath[len("/"+cName+"/"+strings.ToLower(mName)):] if len(otherurl) > 1 { @@ -651,6 +701,15 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } method = vc.MethodByName("Finish") method.Call(in) + //execute middleware filters + if p.enableAfter { + for _, filter := range p.afterFilters { + filter(w, r) + if w.started { + return + } + } + } method = vc.MethodByName("Destructor") method.Call(in) // set find