diff --git a/context/input.go b/context/input.go index c20c8b14..6c60e72c 100644 --- a/context/input.go +++ b/context/input.go @@ -14,28 +14,28 @@ type BeegoInput struct { CruSession session.SessionStore Params map[string]string Data map[interface{}]interface{} - req *http.Request + Request *http.Request RequestBody []byte } func NewInput(req *http.Request) *BeegoInput { return &BeegoInput{ - Params: make(map[string]string), - Data: make(map[interface{}]interface{}), - req: req, + Params: make(map[string]string), + Data: make(map[interface{}]interface{}), + Request: req, } } func (input *BeegoInput) Protocol() string { - return input.req.Proto + return input.Request.Proto } func (input *BeegoInput) Uri() string { - return input.req.RequestURI + return input.Request.RequestURI } func (input *BeegoInput) Url() string { - return input.req.URL.String() + return input.Request.URL.String() } func (input *BeegoInput) Site() string { @@ -43,9 +43,9 @@ func (input *BeegoInput) Site() string { } func (input *BeegoInput) Scheme() string { - if input.req.URL.Scheme != "" { - return input.req.URL.Scheme - } else if input.req.TLS == nil { + if input.Request.URL.Scheme != "" { + return input.Request.URL.Scheme + } else if input.Request.TLS == nil { return "http" } else { return "https" @@ -57,18 +57,18 @@ func (input *BeegoInput) Domain() string { } func (input *BeegoInput) Host() string { - if input.req.Host != "" { - hostParts := strings.Split(input.req.Host, ":") + if input.Request.Host != "" { + hostParts := strings.Split(input.Request.Host, ":") if len(hostParts) > 0 { return hostParts[0] } - return input.req.Host + return input.Request.Host } return "localhost" } func (input *BeegoInput) Method() string { - return input.req.Method + return input.Request.Method } func (input *BeegoInput) Is(method string) bool { @@ -88,7 +88,7 @@ func (input *BeegoInput) IsWebsocket() bool { } func (input *BeegoInput) IsUpload() bool { - return input.req.MultipartForm != nil + return input.Request.MultipartForm != nil } func (input *BeegoInput) IP() string { @@ -96,7 +96,7 @@ func (input *BeegoInput) IP() string { if len(ips) > 0 && ips[0] != "" { return ips[0] } - ip := strings.Split(input.req.RemoteAddr, ":") + ip := strings.Split(input.Request.RemoteAddr, ":") if len(ip) > 0 { return ip[0] } @@ -120,7 +120,7 @@ func (input *BeegoInput) SubDomains() string { } func (input *BeegoInput) Port() int { - parts := strings.Split(input.req.Host, ":") + parts := strings.Split(input.Request.Host, ":") if len(parts) == 2 { port, _ := strconv.Atoi(parts[1]) return port @@ -140,16 +140,16 @@ func (input *BeegoInput) Param(key string) string { } func (input *BeegoInput) Query(key string) string { - input.req.ParseForm() - return input.req.Form.Get(key) + input.Request.ParseForm() + return input.Request.Form.Get(key) } func (input *BeegoInput) Header(key string) string { - return input.req.Header.Get(key) + return input.Request.Header.Get(key) } func (input *BeegoInput) Cookie(key string) string { - ck, err := input.req.Cookie(key) + ck, err := input.Request.Cookie(key) if err != nil { return "" } @@ -161,10 +161,10 @@ func (input *BeegoInput) Session(key interface{}) interface{} { } func (input *BeegoInput) Body() []byte { - requestbody, _ := ioutil.ReadAll(input.req.Body) - input.req.Body.Close() + requestbody, _ := ioutil.ReadAll(input.Request.Body) + input.Request.Body.Close() bf := bytes.NewBuffer(requestbody) - input.req.Body = ioutil.NopCloser(bf) + input.Request.Body = ioutil.NopCloser(bf) input.RequestBody = requestbody return requestbody } diff --git a/context/output.go b/context/output.go index ce40cf89..60e04ebe 100644 --- a/context/output.go +++ b/context/output.go @@ -21,21 +21,18 @@ type BeegoOutput struct { Context *Context Status int EnableGzip bool - res http.ResponseWriter } -func NewOutput(res http.ResponseWriter) *BeegoOutput { - return &BeegoOutput{ - res: res, - } +func NewOutput() *BeegoOutput { + return &BeegoOutput{} } func (output *BeegoOutput) Header(key, val string) { - output.res.Header().Set(key, val) + output.Context.ResponseWriter.Header().Set(key, val) } func (output *BeegoOutput) Body(content []byte) { - output_writer := output.res.(io.Writer) + output_writer := output.Context.ResponseWriter.(io.Writer) if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" { splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1) encodings := make([]string, len(splitted)) @@ -46,12 +43,12 @@ func (output *BeegoOutput) Body(content []byte) { for _, val := range encodings { if val == "gzip" { output.Header("Content-Encoding", "gzip") - output_writer, _ = gzip.NewWriterLevel(output.res, gzip.BestSpeed) + output_writer, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed) break } else if val == "deflate" { output.Header("Content-Encoding", "deflate") - output_writer, _ = flate.NewWriter(output.res, flate.BestSpeed) + output_writer, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed) break } } @@ -104,7 +101,7 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface if len(others) > 4 { fmt.Fprintf(&b, "; HttpOnly") } - output.res.Header().Add("Set-Cookie", b.String()) + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) } var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") @@ -129,7 +126,7 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e content, err = json.Marshal(data) } if err != nil { - http.Error(output.res, err.Error(), http.StatusInternalServerError) + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) return err } if coding { @@ -149,7 +146,7 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error { content, err = json.Marshal(data) } if err != nil { - http.Error(output.res, err.Error(), http.StatusInternalServerError) + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) return err } callback := output.Context.Input.Query("callback") @@ -174,7 +171,7 @@ func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error { content, err = xml.Marshal(data) } if err != nil { - http.Error(output.res, err.Error(), http.StatusInternalServerError) + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) return err } output.Body(content) @@ -189,7 +186,7 @@ func (output *BeegoOutput) Download(file string) { output.Header("Expires", "0") output.Header("Cache-Control", "must-revalidate") output.Header("Pragma", "public") - http.ServeFile(output.res, output.Context.Request, file) + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) } func (output *BeegoOutput) ContentType(ext string) { @@ -203,7 +200,7 @@ func (output *BeegoOutput) ContentType(ext string) { } func (output *BeegoOutput) SetStatus(status int) { - output.res.WriteHeader(status) + output.Context.ResponseWriter.WriteHeader(status) output.Status = status } diff --git a/router.go b/router.go index 248fdaaa..cd2583bc 100644 --- a/router.go +++ b/router.go @@ -28,7 +28,6 @@ const ( var ( HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"} - errorType = reflect.TypeOf((*error)(nil)).Elem() ) type controllerInfo struct { @@ -41,19 +40,21 @@ type controllerInfo struct { } type ControllerRegistor struct { - routers []*controllerInfo - fixrouters []*controllerInfo - enableFilter bool - filters map[int][]*FilterRouter - enableAuto bool - autoRouter map[string]map[string]reflect.Type //key:controller key:method value:reflect.type + routers []*controllerInfo + fixrouters []*controllerInfo + enableFilter bool + filters map[int][]*FilterRouter + enableAuto bool + autoRouter map[string]map[string]reflect.Type //key:controller key:method value:reflect.type + contextBuffer chan *beecontext.Context } func NewControllerRegistor() *ControllerRegistor { return &ControllerRegistor{ - routers: make([]*controllerInfo, 0), - autoRouter: make(map[string]map[string]reflect.Type), - filters: make(map[int][]*FilterRouter), + routers: make([]*controllerInfo, 0), + autoRouter: make(map[string]map[string]reflect.Type), + filters: make(map[int][]*FilterRouter), + contextBuffer: make(chan *beecontext.Context, 100), } } @@ -433,15 +434,40 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) w := &responseWriter{writer: rw} w.Header().Set("Server", BeegoServerName) - context := &beecontext.Context{ - ResponseWriter: w, - Request: r, - Input: beecontext.NewInput(r), - Output: beecontext.NewOutput(w), - } - context.Output.Context = context - context.Output.EnableGzip = EnableGzip + // init context + var context *beecontext.Context + select { + case context = <-p.contextBuffer: + context.ResponseWriter = w + context.Request = r + context.Input.Request = r + default: + context = &beecontext.Context{ + ResponseWriter: w, + Request: r, + Input: beecontext.NewInput(r), + Output: beecontext.NewOutput(), + } + context.Output.Context = context + context.Output.EnableGzip = EnableGzip + } + + defer func() { + if context != nil { + select { + case p.contextBuffer <- context: + default: + } + } + }() + + if context.Input.IsWebsocket() { + context.ResponseWriter = rw + context.Output = beecontext.NewOutput(rw) + } + + // defined filter function do_filter := func(pos int) (started bool) { if p.enableFilter { if l, ok := p.filters[pos]; ok { @@ -460,11 +486,6 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) return false } - if context.Input.IsWebsocket() { - context.ResponseWriter = rw - context.Output = beecontext.NewOutput(rw) - } - // session init if SessionOn { context.Input.CruSession = GlobalSessions.SessionStart(w, r)