diff --git a/router.go b/router.go index b3384473..ed3287f9 100644 --- a/router.go +++ b/router.go @@ -22,6 +22,7 @@ type controllerInfo struct { params map[int]string controllerType reflect.Type methods map[string]string + hasMethod bool } type userHandler struct { @@ -34,8 +35,12 @@ type userHandler struct { type ControllerRegistor struct { routers []*controllerInfo fixrouters []*controllerInfo + enableFilter bool filters []http.HandlerFunc + afterFilters []http.HandlerFunc + enableUser bool userHandlers map[string]*userHandler + enableAuto bool autoRouter map[string]map[string]reflect.Type //key:controller key:method value:reflect.type } @@ -130,6 +135,9 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM route.pattern = pattern route.controllerType = t route.methods = methods + if len(methods) > 0 { + route.hasMethod = true + } p.fixrouters = append(p.fixrouters, route) } else { // add regexp routers //recreate the url pattern, with parameters replaced @@ -149,12 +157,16 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM route.params = params route.pattern = pattern route.methods = methods + if len(methods) > 0 { + route.hasMethod = true + } route.controllerType = t p.routers = append(p.routers, route) } } func (p *ControllerRegistor) AddAuto(c ControllerInterface) { + p.enableAuto = true reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() ct := reflect.Indirect(reflectVal).Type() @@ -170,6 +182,7 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) { } func (p *ControllerRegistor) AddHandler(pattern string, c http.Handler) { + p.enableUser = true parts := strings.Split(pattern, "/") j := 0 @@ -217,6 +230,7 @@ func (p *ControllerRegistor) AddHandler(pattern string, c http.Handler) { // Filter adds the middleware filter. func (p *ControllerRegistor) Filter(filter http.HandlerFunc) { + p.enableFilter = true p.filters = append(p.filters, filter) } @@ -332,52 +346,48 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) r.ParseMultipartForm(MaxMemory) //user defined Handler - for pattern, c := range p.userHandlers { - if c.regex == nil && pattern == requestPath { + if p.enableUser { + for pattern, c := range p.userHandlers { + if c.regex == nil && pattern == requestPath { + c.h.ServeHTTP(rw, r) + return + } else if c.regex == nil { + continue + } + + //check if Route pattern matches url + if !c.regex.MatchString(requestPath) { + continue + } + + //get submatches (params) + matches := c.regex.FindStringSubmatch(requestPath) + + //double check that the Route matches the URL pattern. + if len(matches[0]) != len(requestPath) { + continue + } + + if len(c.params) > 0 { + //add url parameters to the query param map + values := r.URL.Query() + for i, match := range matches[1:] { + values.Add(c.params[i], match) + r.Form.Add(c.params[i], match) + params[c.params[i]] = match + } + //reassemble query params and add to RawQuery + r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery + //r.URL.RawQuery = url.Values(values).Encode() + } c.h.ServeHTTP(rw, r) return - } else if c.regex == nil { - continue } - - //check if Route pattern matches url - if !c.regex.MatchString(requestPath) { - continue - } - - //get submatches (params) - matches := c.regex.FindStringSubmatch(requestPath) - - //double check that the Route matches the URL pattern. - if len(matches[0]) != len(requestPath) { - continue - } - - if len(c.params) > 0 { - //add url parameters to the query param map - values := r.URL.Query() - for i, match := range matches[1:] { - values.Add(c.params[i], match) - r.Form.Add(c.params[i], match) - params[c.params[i]] = match - } - //reassemble query params and add to RawQuery - r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery - //r.URL.RawQuery = url.Values(values).Encode() - } - c.h.ServeHTTP(rw, r) - return } //first find path from the fixrouters to Improve Performance for _, route := range p.fixrouters { n := len(requestPath) - //route like "/" - //if n == 1 { - // else { - // continue - // } - //} if requestPath == route.pattern { runrouter = route findrouter = true @@ -392,6 +402,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } } + //find regex's router if !findrouter { //find a matching Route for _, route := range p.routers { @@ -466,64 +477,93 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) //if response has written,yes don't run next if !w.started { if r.Method == "GET" { - if m, ok := runrouter.methods["get"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["get"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Get") + } } else { method = vc.MethodByName("Get") } method.Call(in) } else if r.Method == "HEAD" { - if m, ok := runrouter.methods["head"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["head"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Head") + } } else { method = vc.MethodByName("Head") } + method.Call(in) } else if r.Method == "DELETE" || (r.Method == "POST" && r.Form.Get("_method") == "delete") { - if m, ok := runrouter.methods["delete"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["delete"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Delete") + } } else { method = vc.MethodByName("Delete") } method.Call(in) } else if r.Method == "PUT" || (r.Method == "POST" && r.Form.Get("_method") == "put") { - if m, ok := runrouter.methods["put"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["put"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Put") + } } else { method = vc.MethodByName("Put") } method.Call(in) } else if r.Method == "POST" { - if m, ok := runrouter.methods["post"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["post"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Post") + } } else { method = vc.MethodByName("Post") } method.Call(in) } else if r.Method == "PATCH" { - if m, ok := runrouter.methods["patch"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["patch"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Patch") + } } else { method = vc.MethodByName("Patch") } method.Call(in) } else if r.Method == "OPTIONS" { - if m, ok := runrouter.methods["options"]; ok { - method = vc.MethodByName(m) - } else if m, ok = runrouter.methods["*"]; ok { - method = vc.MethodByName(m) + if runrouter.hasMethod { + if m, ok := runrouter.methods["options"]; ok { + method = vc.MethodByName(m) + } else if m, ok = runrouter.methods["*"]; ok { + method = vc.MethodByName(m) + } else { + method = vc.MethodByName("Options") + } } else { method = vc.MethodByName("Options") } @@ -553,72 +593,75 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) //start autorouter - if !findrouter { - for cName, methodmap := range p.autoRouter { + if p.enableAuto { + if !findrouter { + for cName, methodmap := range p.autoRouter { - if strings.ToLower(requestPath) == "/"+cName { - http.Redirect(w, r, requestPath+"/", 301) - return - } + if strings.ToLower(requestPath) == "/"+cName { + http.Redirect(w, r, requestPath+"/", 301) + return + } - if strings.ToLower(requestPath) == "/"+cName+"/" { - requestPath = requestPath + "index" - } - if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/") { - for mName, controllerType := range methodmap { - if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/"+strings.ToLower(mName)) { - //parse params - otherurl := requestPath[len("/"+cName+"/"+strings.ToLower(mName)):] - if len(otherurl) > 1 { - plist := strings.Split(otherurl, "/") - for k, v := range plist[1:] { - params[strconv.Itoa(k)] = v + if strings.ToLower(requestPath) == "/"+cName+"/" { + requestPath = requestPath + "index" + } + if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/") { + for mName, controllerType := range methodmap { + if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/"+strings.ToLower(mName)) { + //parse params + otherurl := requestPath[len("/"+cName+"/"+strings.ToLower(mName)):] + if len(otherurl) > 1 { + plist := strings.Split(otherurl, "/") + for k, v := range plist[1:] { + params[strconv.Itoa(k)] = v + } } - } - //Invoke the request handler - vc := reflect.New(controllerType) + //Invoke the request handler + vc := reflect.New(controllerType) - //call the controller init function - init := vc.MethodByName("Init") - in := make([]reflect.Value, 2) - ct := &Context{ResponseWriter: w, Request: r, Params: params, RequestBody: requestbody} + //call the controller init function + init := vc.MethodByName("Init") + in := make([]reflect.Value, 2) + ct := &Context{ResponseWriter: w, Request: r, Params: params, RequestBody: requestbody} - in[0] = reflect.ValueOf(ct) - in[1] = reflect.ValueOf(controllerType.Name()) - init.Call(in) - //call prepare function - in = make([]reflect.Value, 0) - method := vc.MethodByName("Prepare") - method.Call(in) - method = vc.MethodByName(mName) - method.Call(in) - //if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf - if EnableXSRF { - method = vc.MethodByName("XsrfToken") + in[0] = reflect.ValueOf(ct) + in[1] = reflect.ValueOf(controllerType.Name()) + init.Call(in) + //call prepare function + in = make([]reflect.Value, 0) + method := vc.MethodByName("Prepare") method.Call(in) - if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" || - (r.Method == "POST" && (r.Form.Get("_method") == "delete" || r.Form.Get("_method") == "put")) { - method = vc.MethodByName("CheckXsrfCookie") + method = vc.MethodByName(mName) + method.Call(in) + //if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf + if EnableXSRF { + method = vc.MethodByName("XsrfToken") method.Call(in) + if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" || + (r.Method == "POST" && (r.Form.Get("_method") == "delete" || r.Form.Get("_method") == "put")) { + method = vc.MethodByName("CheckXsrfCookie") + method.Call(in) + } } - } - if !w.started { - if AutoRender { - method = vc.MethodByName("Render") - method.Call(in) + if !w.started { + if AutoRender { + method = vc.MethodByName("Render") + method.Call(in) + } } + method = vc.MethodByName("Finish") + method.Call(in) + method = vc.MethodByName("Destructor") + method.Call(in) + // set find + findrouter = true } - method = vc.MethodByName("Finish") - method.Call(in) - method = vc.MethodByName("Destructor") - method.Call(in) - // set find - findrouter = true } } } } } + //if no matches to url, throw a not found exception if !findrouter { if h, ok := ErrorMaps["404"]; ok {