diff --git a/controller.go b/controller.go index 7dc5dde3..9224a9c8 100644 --- a/controller.go +++ b/controller.go @@ -35,9 +35,16 @@ const ( var ( // custom error when user stop request handler manually. USERSTOPRUN = errors.New("User stop run") - GlobalControllerRouter map[string]map[string]*Tree //pkgpath+controller:method:routertree + GlobalControllerRouter map[string]*ControllerComments //pkgpath+controller:comments ) +// store the comment for the controller method +type ControllerComments struct { + method string + router string + allowHTTPMethods []string +} + // Controller defines some basic http request handler operations, such as // http context, template and view, session and xsrf. type Controller struct { @@ -56,7 +63,7 @@ type Controller struct { AppController interface{} EnableRender bool EnableXSRF bool - Routers map[string]*Tree //method:routertree + methodMapping map[string]func() //method:routertree } // ControllerInterface is an interface to uniform all controller handler. @@ -74,7 +81,7 @@ type ControllerInterface interface { Render() error XsrfToken() string CheckXsrfCookie() bool - HandlerFunc(fn interface{}) + HandlerFunc(fn string) URLMapping() } @@ -90,7 +97,7 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin c.EnableRender = true c.EnableXSRF = true c.Data = ctx.Input.Data - c.Routers = make(map[string]*Tree) + c.methodMapping = make(map[string]func()) } // Prepare runs after Init before request function execution. @@ -139,9 +146,11 @@ func (c *Controller) Options() { } // call function fn -func (c *Controller) HandlerFunc(fn interface{}) { - if v, ok := fn.(func()); ok { +func (c *Controller) HandlerFunc(fnname string) { + if v, ok := c.methodMapping[fnname]; ok { v() + } else { + Error("call funcname not exist in the methodMapping: " + fnname) } } @@ -149,19 +158,8 @@ func (c *Controller) HandlerFunc(fn interface{}) { func (c *Controller) URLMapping() { } -func (c *Controller) Mapping(method, pattern string, fn func()) { - method = strings.ToLower(method) - if !utils.InSlice(method, HTTPMETHOD) && method != "*" { - Critical("add mapping method:" + method + " is a valid method") - return - } - if t, ok := c.Routers[method]; ok { - t.AddRouter(pattern, fn) - } else { - t = NewTree() - t.AddRouter(pattern, fn) - c.Routers[method] = t - } +func (c *Controller) Mapping(method string, fn func()) { + c.methodMapping[method] = fn } // Render sends the response with rendered template bytes as text/html type. diff --git a/namespace.go b/namespace.go index c0f10aab..722ec0e8 100644 --- a/namespace.go +++ b/namespace.go @@ -182,7 +182,15 @@ func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { //) func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { for _, ni := range ns { - n.handlers.routers.AddTree(ni.prefix, ni.handlers.routers) + for k, v := range ni.handlers.routers { + if t, ok := n.handlers.routers[k]; ok { + n.handlers.routers[k].AddTree(ni.prefix, v) + } else { + t = NewTree() + t.AddTree(ni.prefix, v) + n.handlers.routers[k] = t + } + } if n.handlers.enableFilter { for pos, filterList := range ni.handlers.filters { for _, mr := range filterList { @@ -201,7 +209,15 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { // support multi Namespace func AddNamespace(nl ...*Namespace) { for _, n := range nl { - BeeApp.Handlers.routers.AddTree(n.prefix, n.handlers.routers) + for k, v := range n.handlers.routers { + if t, ok := BeeApp.Handlers.routers[k]; ok { + BeeApp.Handlers.routers[k].AddTree(n.prefix, v) + } else { + t = NewTree() + t.AddTree(n.prefix, v) + BeeApp.Handlers.routers[k] = t + } + } if n.handlers.enableFilter { for pos, filterList := range n.handlers.filters { for _, mr := range filterList { diff --git a/parser.go b/parser.go index f02d1188..8afd6749 100644 --- a/parser.go +++ b/parser.go @@ -4,3 +4,43 @@ // @license http://github.com/astaxie/beego/blob/master/LICENSE // @authors astaxie package beego + +import ( + "os" + "path/filepath" +) + +var globalControllerRouter = `package routers + +import ( + "github.com/astaxie/beego" +) + +func init() { + {{.globalinfo}} +} +` + +func parserPkg(pkgpath string) error { + err := filepath.Walk(pkgpath, func(path string, info os.FileInfo, err error) error { + if err != nil { + Error("error scan app Controller source:", err) + return err + } + //if is normal file or name is temp skip + //directory is needed + if !info.IsDir() || info.Name() == "tmp" { + return nil + } + + //fileSet := token.NewFileSet() + //astPkgs, err := parser.ParseDir(fileSet, path, func(info os.FileInfo) bool { + // name := info.Name() + // return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") + //}, parser.ParseComments) + + return nil + }) + + return err +} diff --git a/router.go b/router.go index 372fef26..e1235b30 100644 --- a/router.go +++ b/router.go @@ -12,7 +12,9 @@ import ( "fmt" "net" "net/http" + "os" "path" + "path/filepath" "reflect" "runtime" "strconv" @@ -67,7 +69,7 @@ type controllerInfo struct { // ControllerRegistor containers registered router rules, controller handlers and filters. type ControllerRegistor struct { - routers *Tree + routers map[string]*Tree enableFilter bool filters map[int][]*FilterRouter } @@ -75,7 +77,7 @@ type ControllerRegistor struct { // NewControllerRegistor returns a new ControllerRegistor. func NewControllerRegistor() *ControllerRegistor { return &ControllerRegistor{ - routers: NewTree(), + routers: make(map[string]*Tree), filters: make(map[int][]*FilterRouter), } } @@ -120,17 +122,69 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM route.methods = methods route.routerType = routerTypeBeego route.controllerType = t - p.routers.AddRouter(pattern, route) + if len(methods) == 0 { + for _, m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + for k, _ := range methods { + if k == "*" { + for _, m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } + } +} + +func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) { + if t, ok := p.routers[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.routers[method] = t + } } // only when the Runmode is dev will generate router file in the router/auto.go from the controller // Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) func (p *ControllerRegistor) Include(cList ...ControllerInterface) { if RunMode == "dev" { + skip := make(map[string]bool, 10) for _, c := range cList { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() - t.PkgPath() + gopath := os.Getenv("GOPATH") + if gopath == "" { + panic("you are in dev mode. So please set gopath") + } + pkgpath := "" + + wgopath := filepath.SplitList(gopath) + for _, wg := range wgopath { + wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) + if utils.FileExists(wg) { + pkgpath = wg + break + } + } + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath) + } + } + } + } + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + key := t.PkgPath() + ":" + t.Name() + if comm, ok := GlobalControllerRouter[key]; ok { + p.Add(comm.router, c, strings.Join(comm.allowHTTPMethods, ",")+":"+comm.method) } } } @@ -228,7 +282,15 @@ func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) { methods[method] = method } route.methods = methods - p.routers.AddRouter(pattern, route) + for k, _ := range methods { + if k == "*" { + for _, m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } } // add user defined Handler @@ -241,7 +303,9 @@ func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ... pattern = path.Join(pattern, "?:all") } } - p.routers.AddRouter(pattern, route) + for _, m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } } // Add auto router to ControllerRegistor. @@ -270,7 +334,9 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) route.methods = map[string]string{"*": rt.Method(i).Name} route.controllerType = ct pattern := path.Join(prefix, controllerName, strings.ToLower(rt.Method(i).Name), "*") - p.routers.AddRouter(pattern, route) + for _, m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } } } } @@ -317,12 +383,13 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...string) string { } controllName := strings.Join(paths[:len(paths)-1], ".") methodName := paths[len(paths)-1] - ok, url := p.geturl(p.routers, "/", controllName, methodName, params) - if ok { - return url - } else { - return "" + for _, t := range p.routers { + ok, url := p.geturl(t, "/", controllName, methodName, params) + if ok { + return url + } } + return "" } func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string) (bool, string) { @@ -436,6 +503,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) starttime := time.Now() requestPath := r.URL.Path + method := strings.ToLower(r.Method) var runrouter reflect.Type var findrouter bool var runMethod string @@ -485,7 +553,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) }() } - if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) { + if !utils.InSlice(method, HTTPMETHOD) { http.Error(w, "Method Not Allowed", 405) goto Admin } @@ -512,18 +580,21 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } if !findrouter { - runObject, p := p.routers.Match(requestPath) - if r, ok := runObject.(*controllerInfo); ok { - routerInfo = r - findrouter = true - if splat, ok := p[":splat"]; ok { - splatlist := strings.Split(splat, "/") - for k, v := range splatlist { - p[strconv.Itoa(k)] = v + if t, ok := p.routers[method]; ok { + runObject, p := t.Match(requestPath) + if r, ok := runObject.(*controllerInfo); ok { + routerInfo = r + findrouter = true + if splat, ok := p[":splat"]; ok { + splatlist := strings.Split(splat, "/") + for k, v := range splatlist { + p[strconv.Itoa(k)] = v + } } + context.Input.Params = p } - context.Input.Params = p } + } //if no matches to url, throw a not found exception diff --git a/router_test.go b/router_test.go index e0fa70f5..267e8569 100644 --- a/router_test.go +++ b/router_test.go @@ -291,6 +291,17 @@ func (a *AdminController) Get() { a.Ctx.WriteString("hello") } +func TestRouterFunc(t *testing.T) { + mux := NewControllerRegistor() + mux.Get("/action", beegoFilterFunc) + mux.Post("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + mux.ServeHTTP(rw, r) + if rw.Body.String() != "hello" { + t.Errorf("TestRouterFunc can't run") + } +} + func BenchmarkFunc(b *testing.B) { mux := NewControllerRegistor() mux.Get("/action", beegoFilterFunc)