diff --git a/controller.go b/controller.go index 8be43a33..5ad715ff 100644 --- a/controller.go +++ b/controller.go @@ -47,10 +47,37 @@ var ( GlobalControllerRouter = make(map[string][]ControllerComments) ) +// ControllerFilter store the filter for controller +type ControllerFilter struct { + Pattern string + Pos int + Filter FilterFunc + ReturnOnOutput bool + ResetParams bool +} + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments struct { + Pattern string + Pos int + Filter string // NOQA + ReturnOnOutput bool + ResetParams bool +} + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments struct { + ImportPath string + ImportAlias string +} + // ControllerComments store the comment for the controller method type ControllerComments struct { Method string Router string + Filters []*ControllerFilter + ImportComments []*ControllerImportComments + FilterComments []*ControllerFilterComments AllowHTTPMethods []string Params []map[string]string MethodParams []*param.MethodParam diff --git a/parser.go b/parser.go index 2ac48b85..a8690274 100644 --- a/parser.go +++ b/parser.go @@ -39,7 +39,7 @@ var globalRouterTemplate = `package routers import ( "github.com/astaxie/beego" - "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/context/param"{{.globalimport}} ) func init() { @@ -52,6 +52,22 @@ var ( commentFilename string pkgLastupdate map[string]int64 genInfoList map[string][]ControllerComments + + routerHooks = map[string]int{ + "beego.BeforeStatic": BeforeStatic, + "beego.BeforeRouter": BeforeRouter, + "beego.BeforeExec": BeforeExec, + "beego.AfterExec": AfterExec, + "beego.FinishRouter": FinishRouter, + } + + routerHooksMapping = map[int]string{ + BeforeStatic: "beego.BeforeStatic", + BeforeRouter: "beego.BeforeRouter", + BeforeExec: "beego.BeforeExec", + AfterExec: "beego.AfterExec", + FinishRouter: "beego.FinishRouter", + } ) const commentPrefix = "commentsRouter_" @@ -102,6 +118,20 @@ type parsedComment struct { routerPath string methods []string params map[string]parsedParam + filters []parsedFilter + imports []parsedImport +} + +type parsedImport struct { + importPath string + importAlias string +} + +type parsedFilter struct { + pattern string + pos int + filter string + params []bool } type parsedParam struct { @@ -126,6 +156,8 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { cc.Router = parsedComment.routerPath cc.AllowHTTPMethods = parsedComment.methods cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + cc.FilterComments = buildFilters(parsedComment.filters) + cc.ImportComments = buildImports(parsedComment.imports) genInfoList[key] = append(genInfoList[key], cc) } } @@ -133,6 +165,48 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { return nil } +func buildImports(pis []parsedImport) []*ControllerImportComments { + var importComments []*ControllerImportComments + + for _, pi := range pis { + importComments = append(importComments, &ControllerImportComments{ + ImportPath: pi.importPath, + ImportAlias: pi.importAlias, + }) + } + + return importComments +} + +func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { + var filterComments []*ControllerFilterComments + + for _, pf := range pfs { + var ( + returnOnOutput bool + resetParams bool + ) + + if len(pf.params) >= 1 { + returnOnOutput = pf.params[0] + } + + if len(pf.params) >= 2 { + resetParams = pf.params[1] + } + + filterComments = append(filterComments, &ControllerFilterComments{ + Filter: pf.filter, + Pattern: pf.pattern, + Pos: pf.pos, + ReturnOnOutput: returnOnOutput, + ResetParams: resetParams, + }) + } + + return filterComments +} + func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { result := make([]*param.MethodParam, 0, len(funcParams)) for _, fparam := range funcParams { @@ -181,6 +255,8 @@ var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { pcs = []*parsedComment{} params := map[string]parsedParam{} + filters := []parsedFilter{} + imports := []parsedImport{} for _, c := range lines { t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) @@ -209,9 +285,69 @@ func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { } } + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Import") { + iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) + if len(iv) == 0 || len(iv) > 2 { + logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") + continue + } + + p := parsedImport{} + p.importPath = iv[0] + + if len(iv) == 2 { + p.importAlias = iv[1] + } + + imports = append(imports, p) + } + } + +filterLoop: + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Filter") { + fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) + if len(fv) < 3 { + logs.Error("Invalid @Filter format. Needs at least 3 parameters") + continue filterLoop + } + + p := parsedFilter{} + p.pattern = fv[0] + posName := fv[1] + if pos, exists := routerHooks[posName]; exists { + p.pos = pos + } else { + logs.Error("Invalid @Filter pos: ", posName) + continue filterLoop + } + + p.filter = fv[2] + fvParams := fv[3:] + for _, fvParam := range fvParams { + switch fvParam { + case "true": + p.params = append(p.params, true) + case "false": + p.params = append(p.params, false) + default: + logs.Error("Invalid @Filter param: ", fvParam) + continue filterLoop + } + } + + filters = append(filters, p) + } + } + for _, c := range lines { var pc = &parsedComment{} pc.params = params + pc.filters = filters + pc.imports = imports t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) if strings.HasPrefix(t, "@router") { @@ -276,8 +412,9 @@ func genRouterCode(pkgRealpath string) { os.Mkdir(getRouterDir(pkgRealpath), 0755) logs.Info("generate router from comments") var ( - globalinfo string - sortKey []string + globalinfo string + globalimport string + sortKey []string ) for k := range genInfoList { sortKey = append(sortKey, k) @@ -295,6 +432,7 @@ func genRouterCode(pkgRealpath string) { } allmethod = strings.TrimRight(allmethod, ",") + "}" } + params := "nil" if len(c.Params) > 0 { params = "[]map[string]string{" @@ -305,6 +443,7 @@ func genRouterCode(pkgRealpath string) { } params = strings.TrimRight(params, ",") + "}" } + methodParams := "param.Make(" if len(c.MethodParams) > 0 { lines := make([]string, 0, len(c.MethodParams)) @@ -316,24 +455,66 @@ func genRouterCode(pkgRealpath string) { ",\n " } methodParams += ")" + + imports := "" + if len(c.ImportComments) > 0 { + for _, i := range c.ImportComments { + if i.ImportAlias != "" { + imports += fmt.Sprintf(` + %s "%s"`, i.ImportAlias, i.ImportPath) + } else { + imports += fmt.Sprintf(` + "%s"`, i.ImportPath) + } + } + } + + filters := "" + if len(c.FilterComments) > 0 { + for _, f := range c.FilterComments { + filters += fmt.Sprintf(` &beego.ControllerFilter{ + Pattern: "%s", + Pos: %s, + Filter: %s, + ReturnOnOutput: %v, + ResetParams: %v, + },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) + } + } + + if filters == "" { + filters = "nil" + } else { + filters = fmt.Sprintf(`[]*beego.ControllerFilter{ +%s + }`, filters) + } + + globalimport = imports + globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + "Router: `" + c.Router + "`" + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Params: ` + params + `}) + beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], + beego.ControllerComments{ + Method: "` + strings.TrimSpace(c.Method) + `", + ` + "Router: `" + c.Router + "`" + `, + AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, + Filters: ` + filters + `, + Params: ` + params + `}) ` } } + if globalinfo != "" { f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) if err != nil { panic(err) } defer f.Close() - f.WriteString(strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)) + + content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) + content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) + f.WriteString(content) } } diff --git a/router.go b/router.go index 2e0cea25..7b449885 100644 --- a/router.go +++ b/router.go @@ -43,7 +43,7 @@ const ( ) const ( - routerTypeBeego = iota + routerTypeBeego = iota routerTypeRESTFul routerTypeHandler ) @@ -277,6 +277,10 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { key := t.PkgPath() + ":" + t.Name() if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { + for _, f := range a.Filters { + p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + } + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) } } @@ -877,7 +881,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } Admin: -//admin module record QPS + //admin module record QPS statusCode := context.ResponseWriter.Status if statusCode == 0 {