diff --git a/context/context.go b/context/context.go index a7df6f8d..f3dffdbd 100644 --- a/context/context.go +++ b/context/context.go @@ -49,14 +49,14 @@ type Context struct { Input *BeegoInput Output *BeegoOutput Request *http.Request - ResponseWriter http.ResponseWriter + ResponseWriter *Response _xsrfToken string } // Reset init Context, BeegoInput and BeegoOutput func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { ctx.Request = r - ctx.ResponseWriter = rw + ctx.ResponseWriter = &Response{rw, false, 0} ctx.Input.Reset(ctx) ctx.Output.Reset(ctx) } @@ -164,3 +164,27 @@ func (ctx *Context) CheckXSRFCookie() bool { } return true } + +//Response is a wrapper for the http.ResponseWriter +//started set to true if response was written to then don't execute other handler +type Response struct { + http.ResponseWriter + Started bool + Status int +} + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (w *Response) Write(p []byte) (int, error) { + w.Started = true + return w.ResponseWriter.Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (w *Response) WriteHeader(code int) { + w.Status = code + w.Started = true + w.ResponseWriter.WriteHeader(code) +} diff --git a/router.go b/router.go index a575f9ed..01dae8aa 100644 --- a/router.go +++ b/router.go @@ -576,6 +576,31 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin return false, "" } +func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) { + if p.enableFilter { + if l, ok := p.filters[pos]; ok { + for _, filterR := range l { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if ok, params := filterR.ValidRouter(urlPath); ok { + for k, v := range params { + if context.Input.Params == nil { + context.Input.Params = make(map[string]string) + } + context.Input.Params[k] = v + } + filterR.filterFunc(context) + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + } + } + } + return false +} + // Implement http.Handler interface. func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { starttime := time.Now() @@ -584,61 +609,36 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) findrouter bool runMethod string routerInfo *controllerInfo - w = &responseWriter{rw, false, 0} ) - if BConfig.RunMode == "dev" { - w.Header().Set("Server", BConfig.ServerName) - } context := p.pool.Get().(*beecontext.Context) - context.Reset(w, r) - + context.Reset(rw, r) + defer p.pool.Put(context) defer p.recoverPanic(context) + if BConfig.RunMode == "dev" { + context.Output.Header("Server", BConfig.ServerName) + } + var urlPath string if !BConfig.RouterCaseSensitive { urlPath = strings.ToLower(r.URL.Path) } else { urlPath = r.URL.Path } - // defined filter function - doFilter := func(pos int) (started bool) { - if p.enableFilter { - if l, ok := p.filters[pos]; ok { - for _, filterR := range l { - if filterR.returnOnOutput && w.started { - return true - } - if ok, params := filterR.ValidRouter(urlPath); ok { - for k, v := range params { - if context.Input.Params == nil { - context.Input.Params = make(map[string]string) - } - context.Input.Params[k] = v - } - filterR.filterFunc(context) - } - if filterR.returnOnOutput && w.started { - return true - } - } - } - } - return false - } // filter wrong httpmethod if _, ok := HTTPMETHOD[r.Method]; !ok { - http.Error(w, "Method Not Allowed", 405) + http.Error(rw, "Method Not Allowed", 405) goto Admin } // filter for static file - if doFilter(BeforeStatic) { + if p.execFilter(context, BeforeStatic, urlPath) { goto Admin } serverStaticRouter(context) - if w.started { + if context.ResponseWriter.Started { findrouter = true goto Admin } @@ -646,14 +646,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) // session init if BConfig.WebConfig.Session.SessionOn { var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(w, r) + context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { Error(err) exception("503", context) return } defer func() { - context.Input.CruSession.SessionRelease(w) + context.Input.CruSession.SessionRelease(rw) }() } @@ -664,7 +664,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) } - if doFilter(BeforeRouter) { + if p.execFilter(context, BeforeRouter, urlPath) { goto Admin } @@ -703,7 +703,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if findrouter { //execute middleware filters - if doFilter(BeforeExec) { + if p.execFilter(context, BeforeExec, urlPath) { goto Admin } isRunnable := false @@ -764,7 +764,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) execController.URLMapping() - if !w.started { + if !context.ResponseWriter.Started { //exec main logic switch runMethod { case "GET": @@ -790,7 +790,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } //render template - if !w.started && context.Output.Status == 0 { + if !context.ResponseWriter.Started && context.Output.Status == 0 { if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { panic(err) @@ -804,12 +804,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } //execute middleware filters - if doFilter(AfterExec) { + if p.execFilter(context, AfterExec, urlPath) { goto Admin } } - doFilter(FinishRouter) + p.execFilter(context, FinishRouter, urlPath) Admin: timeend := time.Since(starttime) @@ -842,7 +842,7 @@ Admin: // Call WriteHeader if status code has been set changed if context.Output.Status != 0 { - w.WriteHeader(context.Output.Status) + context.ResponseWriter.WriteHeader(context.Output.Status) } } @@ -878,30 +878,6 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) { } } -//responseWriter is a wrapper for the http.ResponseWriter -//started set to true if response was written to then don't execute other handler -type responseWriter struct { - http.ResponseWriter - started bool - status int -} - -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (w *responseWriter) Write(p []byte) (int, error) { - w.started = true - return w.ResponseWriter.Write(p) -} - -// WriteHeader sends an HTTP response header with status code, -// and sets `started` to true. -func (w *responseWriter) WriteHeader(code int) { - w.status = code - w.started = true - w.ResponseWriter.WriteHeader(code) -} - func tourl(params map[string]string) string { if len(params) == 0 { return ""