1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 17:10:54 +00:00
This commit is contained in:
astaxie 2015-12-11 00:20:17 +08:00
parent f70f338025
commit 80bc372f17
2 changed files with 69 additions and 69 deletions

View File

@ -49,14 +49,14 @@ type Context struct {
Input *BeegoInput Input *BeegoInput
Output *BeegoOutput Output *BeegoOutput
Request *http.Request Request *http.Request
ResponseWriter http.ResponseWriter ResponseWriter *Response
_xsrfToken string _xsrfToken string
} }
// Reset init Context, BeegoInput and BeegoOutput // Reset init Context, BeegoInput and BeegoOutput
func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.Request = r ctx.Request = r
ctx.ResponseWriter = rw ctx.ResponseWriter = &Response{rw, false, 0}
ctx.Input.Reset(ctx) ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx) ctx.Output.Reset(ctx)
} }
@ -164,3 +164,27 @@ func (ctx *Context) CheckXSRFCookie() bool {
} }
return true return true
} }
//Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type Response struct {
http.ResponseWriter
Started bool
Status int
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (w *Response) Write(p []byte) (int, error) {
w.Started = true
return w.ResponseWriter.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (w *Response) WriteHeader(code int) {
w.Status = code
w.Started = true
w.ResponseWriter.WriteHeader(code)
}

110
router.go
View File

@ -576,6 +576,31 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return false, "" return false, ""
} }
func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) {
if p.enableFilter {
if l, ok := p.filters[pos]; ok {
for _, filterR := range l {
if filterR.returnOnOutput && context.ResponseWriter.Started {
return true
}
if ok, params := filterR.ValidRouter(urlPath); ok {
for k, v := range params {
if context.Input.Params == nil {
context.Input.Params = make(map[string]string)
}
context.Input.Params[k] = v
}
filterR.filterFunc(context)
}
if filterR.returnOnOutput && context.ResponseWriter.Started {
return true
}
}
}
}
return false
}
// Implement http.Handler interface. // Implement http.Handler interface.
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
starttime := time.Now() starttime := time.Now()
@ -584,61 +609,36 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
findrouter bool findrouter bool
runMethod string runMethod string
routerInfo *controllerInfo routerInfo *controllerInfo
w = &responseWriter{rw, false, 0}
) )
if BConfig.RunMode == "dev" {
w.Header().Set("Server", BConfig.ServerName)
}
context := p.pool.Get().(*beecontext.Context) context := p.pool.Get().(*beecontext.Context)
context.Reset(w, r) context.Reset(rw, r)
defer p.pool.Put(context)
defer p.recoverPanic(context) defer p.recoverPanic(context)
if BConfig.RunMode == "dev" {
context.Output.Header("Server", BConfig.ServerName)
}
var urlPath string var urlPath string
if !BConfig.RouterCaseSensitive { if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path) urlPath = strings.ToLower(r.URL.Path)
} else { } else {
urlPath = r.URL.Path urlPath = r.URL.Path
} }
// defined filter function
doFilter := func(pos int) (started bool) {
if p.enableFilter {
if l, ok := p.filters[pos]; ok {
for _, filterR := range l {
if filterR.returnOnOutput && w.started {
return true
}
if ok, params := filterR.ValidRouter(urlPath); ok {
for k, v := range params {
if context.Input.Params == nil {
context.Input.Params = make(map[string]string)
}
context.Input.Params[k] = v
}
filterR.filterFunc(context)
}
if filterR.returnOnOutput && w.started {
return true
}
}
}
}
return false
}
// filter wrong httpmethod // filter wrong httpmethod
if _, ok := HTTPMETHOD[r.Method]; !ok { if _, ok := HTTPMETHOD[r.Method]; !ok {
http.Error(w, "Method Not Allowed", 405) http.Error(rw, "Method Not Allowed", 405)
goto Admin goto Admin
} }
// filter for static file // filter for static file
if doFilter(BeforeStatic) { if p.execFilter(context, BeforeStatic, urlPath) {
goto Admin goto Admin
} }
serverStaticRouter(context) serverStaticRouter(context)
if w.started { if context.ResponseWriter.Started {
findrouter = true findrouter = true
goto Admin goto Admin
} }
@ -646,14 +646,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// session init // session init
if BConfig.WebConfig.Session.SessionOn { if BConfig.WebConfig.Session.SessionOn {
var err error var err error
context.Input.CruSession, err = GlobalSessions.SessionStart(w, r) context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil { if err != nil {
Error(err) Error(err)
exception("503", context) exception("503", context)
return return
} }
defer func() { defer func() {
context.Input.CruSession.SessionRelease(w) context.Input.CruSession.SessionRelease(rw)
}() }()
} }
@ -664,7 +664,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) context.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
} }
if doFilter(BeforeRouter) { if p.execFilter(context, BeforeRouter, urlPath) {
goto Admin goto Admin
} }
@ -703,7 +703,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if findrouter { if findrouter {
//execute middleware filters //execute middleware filters
if doFilter(BeforeExec) { if p.execFilter(context, BeforeExec, urlPath) {
goto Admin goto Admin
} }
isRunnable := false isRunnable := false
@ -764,7 +764,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.URLMapping() execController.URLMapping()
if !w.started { if !context.ResponseWriter.Started {
//exec main logic //exec main logic
switch runMethod { switch runMethod {
case "GET": case "GET":
@ -790,7 +790,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
//render template //render template
if !w.started && context.Output.Status == 0 { if !context.ResponseWriter.Started && context.Output.Status == 0 {
if BConfig.WebConfig.AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
panic(err) panic(err)
@ -804,12 +804,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
//execute middleware filters //execute middleware filters
if doFilter(AfterExec) { if p.execFilter(context, AfterExec, urlPath) {
goto Admin goto Admin
} }
} }
doFilter(FinishRouter) p.execFilter(context, FinishRouter, urlPath)
Admin: Admin:
timeend := time.Since(starttime) timeend := time.Since(starttime)
@ -842,7 +842,7 @@ Admin:
// Call WriteHeader if status code has been set changed // Call WriteHeader if status code has been set changed
if context.Output.Status != 0 { if context.Output.Status != 0 {
w.WriteHeader(context.Output.Status) context.ResponseWriter.WriteHeader(context.Output.Status)
} }
} }
@ -878,30 +878,6 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
} }
} }
//responseWriter is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type responseWriter struct {
http.ResponseWriter
started bool
status int
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (w *responseWriter) Write(p []byte) (int, error) {
w.started = true
return w.ResponseWriter.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (w *responseWriter) WriteHeader(code int) {
w.status = code
w.started = true
w.ResponseWriter.WriteHeader(code)
}
func tourl(params map[string]string) string { func tourl(params map[string]string) string {
if len(params) == 0 { if len(params) == 0 {
return "" return ""