1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-26 07:51:30 +00:00

Merge remote-tracking branch 'upstream/develop' into develop

This commit is contained in:
nkbai 2015-12-14 15:09:06 +08:00
commit d99c62df1f
6 changed files with 150 additions and 141 deletions

View File

@ -35,16 +35,32 @@ import (
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
// NewContext return the Context with Input and Output
func NewContext() *Context {
return &Context{
Input: NewInput(),
Output: NewOutput(),
}
}
// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. // Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides some api to operate request and response more easily. // BeegoInput and BeegoOutput provides some api to operate request and response more easily.
type Context struct { type Context struct {
Input *BeegoInput Input *BeegoInput
Output *BeegoOutput Output *BeegoOutput
Request *http.Request Request *http.Request
ResponseWriter http.ResponseWriter ResponseWriter *Response
_xsrfToken string _xsrfToken string
} }
// Reset init Context, BeegoInput and BeegoOutput
func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.Request = r
ctx.ResponseWriter = &Response{rw, false, 0}
ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx)
}
// Redirect does redirection to localurl with http header status code. // Redirect does redirection to localurl with http header status code.
// It sends http response header directly. // It sends http response header directly.
func (ctx *Context) Redirect(status int, localurl string) { func (ctx *Context) Redirect(status int, localurl string) {
@ -148,3 +164,27 @@ func (ctx *Context) CheckXSRFCookie() bool {
} }
return true return true
} }
//Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type Response struct {
http.ResponseWriter
Started bool
Status int
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (w *Response) Write(p []byte) (int, error) {
w.Started = true
return w.ResponseWriter.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (w *Response) WriteHeader(code int) {
w.Status = code
w.Started = true
w.ResponseWriter.WriteHeader(code)
}

View File

@ -18,7 +18,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"regexp" "regexp"
@ -39,37 +38,43 @@ var (
// BeegoInput operates the http request header, data, cookie and body. // BeegoInput operates the http request header, data, cookie and body.
// it also contains router params and current session. // it also contains router params and current session.
type BeegoInput struct { type BeegoInput struct {
Context *Context
CruSession session.Store CruSession session.Store
Params map[string]string Params map[string]string
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
Request *http.Request
RequestBody []byte RequestBody []byte
RunController reflect.Type
RunMethod string
} }
// NewInput return BeegoInput generated by http.Request. // NewInput return BeegoInput generated by Context.
func NewInput(req *http.Request) *BeegoInput { func NewInput() *BeegoInput {
return &BeegoInput{ return &BeegoInput{
Params: make(map[string]string), Params: make(map[string]string),
Data: make(map[interface{}]interface{}), Data: make(map[interface{}]interface{}),
Request: req,
} }
} }
// Reset init the BeegoInput
func (input *BeegoInput) Reset(ctx *Context) {
input.Context = ctx
input.CruSession = nil
input.Params = make(map[string]string)
input.Data = make(map[interface{}]interface{})
input.RequestBody = []byte{}
}
// Protocol returns request protocol name, such as HTTP/1.1 . // Protocol returns request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string { func (input *BeegoInput) Protocol() string {
return input.Request.Proto return input.Context.Request.Proto
} }
// URI returns full request url with query string, fragment. // URI returns full request url with query string, fragment.
func (input *BeegoInput) URI() string { func (input *BeegoInput) URI() string {
return input.Request.RequestURI return input.Context.Request.RequestURI
} }
// URL returns request url path (without query string, fragment). // URL returns request url path (without query string, fragment).
func (input *BeegoInput) URL() string { func (input *BeegoInput) URL() string {
return input.Request.URL.Path return input.Context.Request.URL.Path
} }
// Site returns base site url as scheme://domain type. // Site returns base site url as scheme://domain type.
@ -79,10 +84,10 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https". // Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string { func (input *BeegoInput) Scheme() string {
if input.Request.URL.Scheme != "" { if input.Context.Request.URL.Scheme != "" {
return input.Request.URL.Scheme return input.Context.Request.URL.Scheme
} }
if input.Request.TLS == nil { if input.Context.Request.TLS == nil {
return "http" return "http"
} }
return "https" return "https"
@ -97,19 +102,19 @@ func (input *BeegoInput) Domain() string {
// Host returns host name. // Host returns host name.
// if no host info in request, return localhost. // if no host info in request, return localhost.
func (input *BeegoInput) Host() string { func (input *BeegoInput) Host() string {
if input.Request.Host != "" { if input.Context.Request.Host != "" {
hostParts := strings.Split(input.Request.Host, ":") hostParts := strings.Split(input.Context.Request.Host, ":")
if len(hostParts) > 0 { if len(hostParts) > 0 {
return hostParts[0] return hostParts[0]
} }
return input.Request.Host return input.Context.Request.Host
} }
return "localhost" return "localhost"
} }
// Method returns http request method. // Method returns http request method.
func (input *BeegoInput) Method() string { func (input *BeegoInput) Method() string {
return input.Request.Method return input.Context.Request.Method
} }
// Is returns boolean of this request is on given method, such as Is("POST"). // Is returns boolean of this request is on given method, such as Is("POST").
@ -196,7 +201,7 @@ func (input *BeegoInput) IP() string {
rip := strings.Split(ips[0], ":") rip := strings.Split(ips[0], ":")
return rip[0] return rip[0]
} }
ip := strings.Split(input.Request.RemoteAddr, ":") ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 { if len(ip) > 0 {
if ip[0] != "[" { if ip[0] != "[" {
return ip[0] return ip[0]
@ -236,7 +241,7 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port. // Port returns request client port.
// when error or empty, return 80. // when error or empty, return 80.
func (input *BeegoInput) Port() int { func (input *BeegoInput) Port() int {
parts := strings.Split(input.Request.Host, ":") parts := strings.Split(input.Context.Request.Host, ":")
if len(parts) == 2 { if len(parts) == 2 {
port, _ := strconv.Atoi(parts[1]) port, _ := strconv.Atoi(parts[1])
return port return port
@ -262,22 +267,22 @@ func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" { if val := input.Param(key); val != "" {
return val return val
} }
if input.Request.Form == nil { if input.Context.Request.Form == nil {
input.Request.ParseForm() input.Context.Request.ParseForm()
} }
return input.Request.Form.Get(key) return input.Context.Request.Form.Get(key)
} }
// Header returns request header item string by a given string. // Header returns request header item string by a given string.
// if non-existed, return empty string. // if non-existed, return empty string.
func (input *BeegoInput) Header(key string) string { func (input *BeegoInput) Header(key string) string {
return input.Request.Header.Get(key) return input.Context.Request.Header.Get(key)
} }
// Cookie returns request cookie item string by a given key. // Cookie returns request cookie item string by a given key.
// if non-existed, return empty string. // if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string { func (input *BeegoInput) Cookie(key string) string {
ck, err := input.Request.Cookie(key) ck, err := input.Context.Request.Cookie(key)
if err != nil { if err != nil {
return "" return ""
} }
@ -292,10 +297,10 @@ func (input *BeegoInput) Session(key interface{}) interface{} {
// CopyBody returns the raw request body data as bytes. // CopyBody returns the raw request body data as bytes.
func (input *BeegoInput) CopyBody() []byte { func (input *BeegoInput) CopyBody() []byte {
requestbody, _ := ioutil.ReadAll(input.Request.Body) requestbody, _ := ioutil.ReadAll(input.Context.Request.Body)
input.Request.Body.Close() input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody) bf := bytes.NewBuffer(requestbody)
input.Request.Body = ioutil.NopCloser(bf) input.Context.Request.Body = ioutil.NopCloser(bf)
input.RequestBody = requestbody input.RequestBody = requestbody
return requestbody return requestbody
} }
@ -318,10 +323,10 @@ func (input *BeegoInput) SetData(key, val interface{}) {
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type. // Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
if err := input.Request.ParseMultipartForm(maxMemory); err != nil { if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error()) return errors.New("Error parsing request body:" + err.Error())
} }
} else if err := input.Request.ParseForm(); err != nil { } else if err := input.Context.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error()) return errors.New("Error parsing request body:" + err.Error())
} }
return nil return nil
@ -386,13 +391,13 @@ func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
} }
rv = input.bindBool(val, typ) rv = input.bindBool(val, typ)
case reflect.Slice: case reflect.Slice:
rv = input.bindSlice(&input.Request.Form, key, typ) rv = input.bindSlice(&input.Context.Request.Form, key, typ)
case reflect.Struct: case reflect.Struct:
rv = input.bindStruct(&input.Request.Form, key, typ) rv = input.bindStruct(&input.Context.Request.Form, key, typ)
case reflect.Ptr: case reflect.Ptr:
rv = input.bindPoint(key, typ) rv = input.bindPoint(key, typ)
case reflect.Map: case reflect.Map:
rv = input.bindMap(&input.Request.Form, key, typ) rv = input.bindMap(&input.Context.Request.Form, key, typ)
} }
return rv return rv
} }

View File

@ -43,6 +43,12 @@ func NewOutput() *BeegoOutput {
return &BeegoOutput{} return &BeegoOutput{}
} }
// Reset init BeegoOutput
func (output *BeegoOutput) Reset(ctx *Context) {
output.Context = ctx
output.Status = 0
}
// Header sets response header item string via given key. // Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) { func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val) output.Context.ResponseWriter.Header().Set(key, val)
@ -55,7 +61,7 @@ func (output *BeegoOutput) Body(content []byte) {
var encoding string var encoding string
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
if output.EnableGzip { if output.EnableGzip {
encoding = ParseEncoding(output.Context.Input.Request) encoding = ParseEncoding(output.Context.Request)
} }
if b, n, _ := WriteBody(encoding, buf, content); b { if b, n, _ := WriteBody(encoding, buf, content); b {
output.Header("Content-Encoding", n) output.Header("Content-Encoding", n)

View File

@ -21,13 +21,7 @@ import (
) )
// GlobalDocAPI store the swagger api documents // GlobalDocAPI store the swagger api documents
var GlobalDocAPI map[string]interface{} var GlobalDocAPI = make(map[string]interface{})
func init() {
if BConfig.WebConfig.EnableDocs {
GlobalDocAPI = make(map[string]interface{})
}
}
func serverDocs(ctx *context.Context) { func serverDocs(ctx *context.Context) {
var obj interface{} var obj interface{}

134
router.go
View File

@ -24,6 +24,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
beecontext "github.com/astaxie/beego/context" beecontext "github.com/astaxie/beego/context"
@ -83,7 +84,7 @@ type logFilter struct {
} }
func (l *logFilter) Filter(ctx *beecontext.Context) bool { func (l *logFilter) Filter(ctx *beecontext.Context) bool {
requestPath := path.Clean(ctx.Input.Request.URL.Path) requestPath := path.Clean(ctx.Request.URL.Path)
if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
return true return true
} }
@ -114,14 +115,19 @@ type ControllerRegister struct {
routers map[string]*Tree routers map[string]*Tree
enableFilter bool enableFilter bool
filters map[int][]*FilterRouter filters map[int][]*FilterRouter
pool sync.Pool
} }
// NewControllerRegister returns a new ControllerRegister. // NewControllerRegister returns a new ControllerRegister.
func NewControllerRegister() *ControllerRegister { func NewControllerRegister() *ControllerRegister {
return &ControllerRegister{ cr := &ControllerRegister{
routers: make(map[string]*Tree), routers: make(map[string]*Tree),
filters: make(map[int][]*FilterRouter), filters: make(map[int][]*FilterRouter),
} }
cr.pool.New = func() interface{} {
return beecontext.NewContext()
}
return cr
} }
// Add controller handler and pattern rules to ControllerRegister. // Add controller handler and pattern rules to ControllerRegister.
@ -132,7 +138,7 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api/create",&RestController{},"post:CreateFood") // Add("/api/create",&RestController{},"post:CreateFood")
// Add("/api/update",&RestController{},"put:UpdateFood") // Add("/api/update",&RestController{},"put:UpdateFood")
// Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api/delete",&RestController{},"delete:DeleteFood")
// Add("/api",&RestController{},"get,post:ApiFunc") // Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
@ -570,44 +576,11 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return false, "" return false, ""
} }
// Implement http.Handler interface. func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) {
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
starttime := time.Now()
var runrouter reflect.Type
var findrouter bool
var runMethod string
var routerInfo *controllerInfo
w := &responseWriter{rw, false, 0}
if BConfig.RunMode == "dev" {
w.Header().Set("Server", BConfig.ServerName)
}
// init context
context := &beecontext.Context{
ResponseWriter: w,
Request: r,
Input: beecontext.NewInput(r),
Output: beecontext.NewOutput(),
}
context.Output.Context = context
context.Output.EnableGzip = BConfig.EnableGzip
defer p.recoverPanic(context)
var urlPath string
if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path)
} else {
urlPath = r.URL.Path
}
// defined filter function
doFilter := func(pos int) (started bool) {
if p.enableFilter { if p.enableFilter {
if l, ok := p.filters[pos]; ok { if l, ok := p.filters[pos]; ok {
for _, filterR := range l { for _, filterR := range l {
if filterR.returnOnOutput && w.started { if filterR.returnOnOutput && context.ResponseWriter.Started {
return true return true
} }
if ok, params := filterR.ValidRouter(urlPath); ok { if ok, params := filterR.ValidRouter(urlPath); ok {
@ -619,28 +592,53 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
filterR.filterFunc(context) filterR.filterFunc(context)
} }
if filterR.returnOnOutput && w.started { if filterR.returnOnOutput && context.ResponseWriter.Started {
return true return true
} }
} }
} }
} }
return false return false
}
// Implement http.Handler interface.
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
starttime := time.Now()
var (
runrouter reflect.Type
findrouter bool
runMethod string
routerInfo *controllerInfo
)
context := p.pool.Get().(*beecontext.Context)
context.Reset(rw, r)
defer p.pool.Put(context)
defer p.recoverPanic(context)
if BConfig.RunMode == "dev" {
context.Output.Header("Server", BConfig.ServerName)
}
var urlPath string
if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path)
} else {
urlPath = r.URL.Path
} }
// filter wrong httpmethod // filter wrong httpmethod
if _, ok := HTTPMETHOD[r.Method]; !ok { if _, ok := HTTPMETHOD[r.Method]; !ok {
http.Error(w, "Method Not Allowed", 405) http.Error(rw, "Method Not Allowed", 405)
goto Admin goto Admin
} }
// filter for static file // filter for static file
if doFilter(BeforeStatic) { if p.execFilter(context, BeforeStatic, urlPath) {
goto Admin goto Admin
} }
serverStaticRouter(context) serverStaticRouter(context)
if w.started { if context.ResponseWriter.Started {
findrouter = true findrouter = true
goto Admin goto Admin
} }
@ -648,14 +646,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// session init // session init
if BConfig.WebConfig.Session.SessionOn { if BConfig.WebConfig.Session.SessionOn {
var err error var err error
context.Input.CruSession, err = GlobalSessions.SessionStart(w, r) context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil { if err != nil {
Error(err) Error(err)
exception("503", context) exception("503", context)
return return
} }
defer func() { defer func() {
context.Input.CruSession.SessionRelease(w) context.Input.CruSession.SessionRelease(rw)
}() }()
} }
@ -666,27 +664,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) context.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
} }
if doFilter(BeforeRouter) { if p.execFilter(context, BeforeRouter, urlPath) {
goto Admin goto Admin
} }
if context.Input.RunController != nil && context.Input.RunMethod != "" {
findrouter = true
runMethod = context.Input.RunMethod
runrouter = context.Input.RunController
}
if !findrouter { if !findrouter {
httpMethod := r.Method httpMethod := r.Method
if httpMethod == "POST" && context.Input.Query("_method") == "PUT" { if httpMethod == "POST" && context.Input.Query("_method") == "PUT" {
httpMethod = "PUT" httpMethod = "PUT"
} }
if httpMethod == "POST" && context.Input.Query("_method") == "DELETE" { if httpMethod == "POST" && context.Input.Query("_method") == "DELETE" {
httpMethod = "DELETE" httpMethod = "DELETE"
} }
if t, ok := p.routers[httpMethod]; ok { if t, ok := p.routers[httpMethod]; ok {
runObject, p := t.Match(urlPath) runObject, p := t.Match(urlPath)
if r, ok := runObject.(*controllerInfo); ok { if r, ok := runObject.(*controllerInfo); ok {
@ -714,7 +703,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if findrouter { if findrouter {
//execute middleware filters //execute middleware filters
if doFilter(BeforeExec) { if p.execFilter(context, BeforeExec, urlPath) {
goto Admin goto Admin
} }
isRunnable := false isRunnable := false
@ -775,7 +764,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.URLMapping() execController.URLMapping()
if !w.started { if !context.ResponseWriter.Started {
//exec main logic //exec main logic
switch runMethod { switch runMethod {
case "GET": case "GET":
@ -801,7 +790,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
//render template //render template
if !w.started && context.Output.Status == 0 { if !context.ResponseWriter.Started && context.Output.Status == 0 {
if BConfig.WebConfig.AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
panic(err) panic(err)
@ -815,12 +804,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
//execute middleware filters //execute middleware filters
if doFilter(AfterExec) { if p.execFilter(context, AfterExec, urlPath) {
goto Admin goto Admin
} }
} }
doFilter(FinishRouter) p.execFilter(context, FinishRouter, urlPath)
Admin: Admin:
timeend := time.Since(starttime) timeend := time.Since(starttime)
@ -853,7 +842,7 @@ Admin:
// Call WriteHeader if status code has been set changed // Call WriteHeader if status code has been set changed
if context.Output.Status != 0 { if context.Output.Status != 0 {
w.WriteHeader(context.Output.Status) context.ResponseWriter.WriteHeader(context.Output.Status)
} }
} }
@ -889,31 +878,6 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
} }
} }
//responseWriter is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type responseWriter struct {
http.ResponseWriter
started bool
status int
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (w *responseWriter) Write(p []byte) (int, error) {
w.started = true
return w.ResponseWriter.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (w *responseWriter) WriteHeader(code int) {
w.status = code
w.started = true
w.ResponseWriter.WriteHeader(code)
}
func tourl(params map[string]string) string { func tourl(params map[string]string) string {
if len(params) == 0 { if len(params) == 0 {
return "" return ""

View File

@ -144,7 +144,7 @@ func isStaticCompress(filePath string) bool {
// searchFile search the file by url path // searchFile search the file by url path
// if none the static file prefix matches ,return notStaticRequestErr // if none the static file prefix matches ,return notStaticRequestErr
func searchFile(ctx *context.Context) (string, os.FileInfo, error) { func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
requestPath := filepath.ToSlash(filepath.Clean(ctx.Input.Request.URL.Path)) requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path))
// special processing : favicon.ico/robots.txt can be in any static dir // special processing : favicon.ico/robots.txt can be in any static dir
if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
file := path.Join(".", requestPath) file := path.Join(".", requestPath)