From 9b79437778cf2b78f00d71c963cfc4732a01b39e Mon Sep 17 00:00:00 2001 From: eyalpost Date: Tue, 25 Apr 2017 16:00:49 +0300 Subject: [PATCH] all types working + controller comments generation --- param/conv.go | 58 +++++++----- param/methodparams.go | 72 ++++++++------- param/options.go | 6 ++ param/parsers.go | 57 ++++++++++++ parser.go | 209 +++++++++++++++++++++++++++++++++++------- router.go | 3 + 6 files changed, 317 insertions(+), 88 deletions(-) diff --git a/param/conv.go b/param/conv.go index e909c4ea..4938f45f 100644 --- a/param/conv.go +++ b/param/conv.go @@ -9,39 +9,51 @@ import ( ) func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { - var strValue string - var reflectValue reflect.Value - switch param.location { - case body: - strValue = string(ctx.Input.RequestBody) - case header: - strValue = ctx.Input.Header(param.name) - default: - strValue = ctx.Input.Query(param.name) - } - - if strValue == "" { + paramValue := getParamValue(param, ctx) + if paramValue == "" { if param.required { ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) } else { - strValue = param.defValue + paramValue = param.defValue } } - if strValue == "" { - reflectValue = reflect.Zero(paramType) + + reflectValue, err := parseValue(paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.location { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + // if strValue == "" && strings.Contains(param.name, "_") { //magically handle X-Headers? + // strValue = ctx.Input.Header(strings.Replace(param.name, "_", "-", -1)) + // } + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil } else { - value, err := param.parser.parse(strValue, paramType) + value, err := parse(paramValue, paramType) if err != nil { - logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %s, Error: %s", param.name, paramType, strValue, err)) - ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %s to type %s", param.name, strValue, paramType)) + return result, err } - reflectValue, err = safeConvert(reflect.ValueOf(value), paramType) - if err != nil { - panic(err) - } + return safeConvert(reflect.ValueOf(value), paramType) } - return reflectValue } func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { diff --git a/param/methodparams.go b/param/methodparams.go index a3e25a55..23fd6661 100644 --- a/param/methodparams.go +++ b/param/methodparams.go @@ -1,9 +1,13 @@ package param +import ( + "fmt" + "strings" +) + //Keeps param information to be auto passed to controller methods type MethodParam struct { name string - parser paramParser location paramLocation required bool defValue string @@ -13,45 +17,51 @@ type paramLocation byte const ( param paramLocation = iota + path body header ) -type MethodParamOption func(*MethodParam) - -func Bool(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, boolParser{}, opts) -} - -func String(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, stringParser{}, opts) -} - -func Int(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, intParser{}, opts) -} - -func Float(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, floatParser{}, opts) -} - -func Time(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, timeParser{}, opts) -} - -func Json(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, jsonParser{}, opts) -} - -func AsSlice(param *MethodParam) *MethodParam { - param.parser = sliceParser(param.parser) - return param +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) } func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { - param = &MethodParam{name: name, parser: parser} + param = &MethodParam{name: name} for _, option := range opts { option(param) } return } + +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.location { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/param/options.go b/param/options.go index 0692f9d1..0013c31e 100644 --- a/param/options.go +++ b/param/options.go @@ -4,6 +4,8 @@ import ( "fmt" ) +type MethodParamOption func(*MethodParam) + var IsRequired MethodParamOption = func(p *MethodParam) { p.required = true } @@ -12,6 +14,10 @@ var InHeader MethodParamOption = func(p *MethodParam) { p.location = header } +var InPath MethodParamOption = func(p *MethodParam) { + p.location = path +} + var InBody MethodParamOption = func(p *MethodParam) { p.location = body } diff --git a/param/parsers.go b/param/parsers.go index cfa1a981..64cabe49 100644 --- a/param/parsers.go +++ b/param/parsers.go @@ -12,6 +12,45 @@ type paramParser interface { parse(value string, toType reflect.Type) (interface{}, error) } +func parse(value string, t reflect.Type) (interface{}, error) { + parser := getParser(t) + return parser.parse(value, t) +} + +func getParser(t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + elemParser := getParser(t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + type parserFunc func(value string, toType reflect.Type) (interface{}, error) func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { @@ -92,3 +131,21 @@ func sliceParser(elemParser paramParser) paramParser { return result.Interface(), nil }) } + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/parser.go b/parser.go index d40ee3ce..43e052bd 100644 --- a/parser.go +++ b/parser.go @@ -24,10 +24,14 @@ import ( "io/ioutil" "os" "path/filepath" + "regexp" "sort" + "strconv" "strings" + "unicode" "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/param" "github.com/astaxie/beego/utils" ) @@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers import ( "github.com/astaxie/beego" + "github.com/astaxie/beego/param" ) func init() { @@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error { if specDecl.Recv != nil { exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser if ok { - parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath) + parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) } } } @@ -93,44 +98,168 @@ func parserPkg(pkgRealpath, pkgpath string) error { return nil } -func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error { - if comments != nil && comments.List != nil { - for _, c := range comments.List { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - elements := strings.TrimLeft(t, "@router ") - e1 := strings.SplitN(elements, " ", 2) - if len(e1) < 1 { - return errors.New("you should has router information") - } - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = funcName - cc.Router = e1[0] - if len(e1) == 2 && e1[1] != "" { - e1 = strings.SplitN(e1[1], " ", 2) - if len(e1) >= 1 { - cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",") - } else { - cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") - } - } else { - cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") - } - if len(e1) == 2 && e1[1] != "" { - keyval := strings.Split(strings.Trim(e1[1], "[]"), " ") - for _, kv := range keyval { - kk := strings.Split(kv, ":") - cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]}) - } - } - genInfoList[key] = append(genInfoList[key], cc) - } +type parsedComment struct { + routerPath string + methods []string + params map[string]parsedParam +} + +type parsedParam struct { + name string + datatype string + location string + defValue string + required bool +} + +func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { + if f.Doc != nil { + parsedComment, err := parseComment(f.Doc.List) + if err != nil { + return err } + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + genInfoList[key] = append(genInfoList[key], cc) + } return nil } +func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { + result := make([]*param.MethodParam, 0, len(funcParams)) + for _, fparam := range funcParams { + methodParam := buildMethodParam(fparam, pc) + result = append(result, methodParam) + } + return result +} + +func buildMethodParam(fparam *ast.Field, pc *parsedComment) *param.MethodParam { + options := []param.MethodParamOption{} + name := fparam.Names[0].Name + if cparam, ok := pc.params[name]; ok { + //Build param from comment info + name = cparam.name + if cparam.required { + options = append(options, param.IsRequired) + } + switch cparam.location { + case "body": + options = append(options, param.InBody) + case "header": + options = append(options, param.InHeader) + case "path": + options = append(options, param.InPath) + } + if cparam.defValue != "" { + options = append(options, param.Default(cparam.defValue)) + } + } else { + if paramInPath(name, pc.routerPath) { + options = append(options, param.InPath) + } + } + return param.New(name, options...) +} + +func paramInPath(name, route string) bool { + return strings.HasSuffix(route, ":"+name) || + strings.Contains(route, ":"+name+"/") +} + +var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) + +func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) { + pc = &parsedComment{} + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + } else { + return nil, errors.New("Router information is missing") + } + } else if strings.HasPrefix(t, "@Param") { + pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) + if len(pv) < 4 { + logs.Error("Invalid @Param format. Needs at least 4 parameters") + } + p := parsedParam{} + names := strings.Split(pv[0], "=") + funcParamName := names[0] + if len(names) > 1 { + p.name = names[1] + } else { + p.name = funcParamName + } + p.location = pv[1] + p.datatype = pv[2] + switch len(pv) { + case 5: + p.required, _ = strconv.ParseBool(pv[3]) + case 6: + p.defValue = pv[3] + p.required, _ = strconv.ParseBool(pv[4]) + } + if pc.params == nil { + pc.params = map[string]parsedParam{} + } + pc.params[funcParamName] = p + } + } + return +} + +// direct copy from bee\g_docs.go +// analisys params return []string +// @Param query form string true "The email for login" +// [query form string true "The email for login"] +func getparams(str string) []string { + var s []rune + var j int + var start bool + var r []string + var quoted int8 + for _, c := range []rune(str) { + if unicode.IsSpace(c) && quoted == 0 { + if !start { + continue + } else { + start = false + j++ + r = append(r, string(s)) + s = make([]rune, 0) + continue + } + } + + start = true + if c == '"' { + quoted ^= 1 + continue + } + s = append(s, c) + } + if len(s) > 0 { + r = append(r, string(s)) + } + return r +} + func genRouterCode(pkgRealpath string) { os.Mkdir(getRouterDir(pkgRealpath), 0755) logs.Info("generate router from comments") @@ -163,12 +292,24 @@ func genRouterCode(pkgRealpath string) { } params = strings.TrimRight(params, ",") + "}" } + methodParams := "param.Make(" + if len(c.MethodParams) > 0 { + lines := make([]string, 0, len(c.MethodParams)) + for _, m := range c.MethodParams { + lines = append(lines, fmt.Sprint(m)) + } + methodParams += "\n " + + strings.Join(lines, ",\n ") + + ",\n " + } + methodParams += ")" globalinfo = globalinfo + ` beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.ControllerComments{ Method: "` + strings.TrimSpace(c.Method) + `", ` + "Router: `" + c.Router + "`" + `, AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, Params: ` + params + `}) ` } diff --git a/router.go b/router.go index 57479248..546db22d 100644 --- a/router.go +++ b/router.go @@ -908,6 +908,9 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex response.RenderMethodResult(resultValue, context) } } + if !context.ResponseWriter.Started && context.Output.Status == 0 { + context.Output.SetStatus(200) + } } // FindRouter Find Router info for URL