diff --git a/generate/swaggergen/g_docs.go b/generate/swaggergen/g_docs.go index 642c263..7259a89 100644 --- a/generate/swaggergen/g_docs.go +++ b/generate/swaggergen/g_docs.go @@ -24,8 +24,6 @@ import ( "os" "path" "path/filepath" - "reflect" - "regexp" "runtime" "strconv" "strings" @@ -37,6 +35,8 @@ import ( "github.com/astaxie/beego/utils" beeLogger "github.com/beego/bee/logger" bu "github.com/beego/bee/utils" + "github.com/go-openapi/spec" + "github.com/wy-z/tspec/tspec" ) const ( @@ -51,9 +51,35 @@ var pkgCache map[string]struct{} //pkg:controller:function:comments comments: ke var controllerComments map[string]string var importlist map[string]string var controllerList map[string]map[string]*swagger.Item //controllername Paths items -var modelsList map[string]map[string]swagger.Schema var rootapi swagger.Swagger -var astPkgs []*ast.Package + +var tparser *tspec.Parser +var controllerPkg *ast.Package + +func convertSpecDefinitions(specDefs spec.Definitions) (defs map[string]swagger.Schema, err error) { + bytes, err := json.Marshal(specDefs) + if err != nil { + return + } + defs = make(map[string]swagger.Schema) + err = json.Unmarshal(bytes, &defs) + if err != nil { + return + } + return +} + +func parseModel(pkg *ast.Package, typeStr string) (typeID string, err error) { + if pkg == nil { + panic("pkg can not be nil") + } + schema, err := tparser.Parse(pkg, typeStr) + if err != nil { + return + } + typeID = schema.ID + return +} // refer to builtin.go var basicTypes = map[string]string{ @@ -80,70 +106,15 @@ var basicTypes = map[string]string{ "time.Time": "string:string", } -var stdlibObject = map[string]string{ - "&{time Time}": "time.Time", -} - func init() { pkgCache = make(map[string]struct{}) controllerComments = make(map[string]string) importlist = make(map[string]string) controllerList = make(map[string]map[string]*swagger.Item) - modelsList = make(map[string]map[string]swagger.Schema) - astPkgs = make([]*ast.Package, 0) -} - -func ParsePackagesFromDir(dirpath string) { - c := make(chan error) - - go func() { - filepath.Walk(dirpath, func(fpath string, fileInfo os.FileInfo, err error) error { - if err != nil { - return nil - } - if !fileInfo.IsDir() { - return nil - } - - // 7 is length of 'vendor' (6) + length of file path separator (1) - // so we skip dir 'vendor' which is directly under dirpath - if !(len(fpath) == len(dirpath)+7 && strings.HasSuffix(fpath, "vendor")) && - !strings.Contains(fpath, "tests") && - !(len(fpath) > len(dirpath) && fpath[len(dirpath)+1] == '.') { - err = parsePackageFromDir(fpath) - if err != nil { - // Send the error to through the channel and continue walking - c <- fmt.Errorf("Error while parsing directory: %s", err.Error()) - return nil - } - } - return nil - }) - close(c) - }() - - for err := range c { - beeLogger.Log.Warnf("%s", err) - } -} - -func parsePackageFromDir(path string) error { - fileSet := token.NewFileSet() - folderPkgs, err := parser.ParseDir(fileSet, path, func(info os.FileInfo) bool { - name := info.Name() - return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") - }, parser.ParseComments) - if err != nil { - return err - } - - for _, v := range folderPkgs { - astPkgs = append(astPkgs, v) - } - - return nil + tparser = tspec.NewParser() } +// GenerateDocs ... func GenerateDocs(curpath string) { fset := token.NewFileSet() @@ -307,6 +278,12 @@ func GenerateDocs(curpath string) { } } } + defs, err := convertSpecDefinitions(tparser.Definitions()) + if err != nil { + panic(err) + } + rootapi.Definitions = defs + os.Mkdir(path.Join(curpath, "swagger"), 0755) fd, err := os.Create(path.Join(curpath, "swagger", "swagger.json")) if err != nil { @@ -442,6 +419,10 @@ func analyseControllerPkg(vendorPath, localName, pkgpath string) { beeLogger.Log.Fatalf("Error while parsing dir at '%s': %s", pkgpath, err) } for _, pkg := range astPkgs { + if pkg.Name == "controllers" { + controllerPkg = pkg + } + for _, fl := range pkg.Files { for _, d := range fl.Decls { switch specDecl := d.(type) { @@ -557,13 +538,11 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { schema.Type = typeFormat[0] schema.Format = typeFormat[1] } else { - m, mod, realTypes := getModel(schemaName) - schema.Ref = "#/definitions/" + m - if _, ok := modelsList[pkgpath+controllerName]; !ok { - modelsList[pkgpath+controllerName] = make(map[string]swagger.Schema) + typeID, err := parseModel(controllerPkg, schemaName) + if err != nil { + beeLogger.Log.Fatalf("failed to parse model %s: %s", schemaName, err) } - modelsList[pkgpath+controllerName][schemaName] = mod - appendModels(pkgpath, controllerName, realTypes) + schema.Ref = "#/definitions/" + typeID } if isArray { rs.Schema = &swagger.Schema{ @@ -613,15 +592,13 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { pp := strings.Split(p[2], ".") typ := pp[len(pp)-1] if len(pp) >= 2 { - m, mod, realTypes := getModel(p[2]) + typeID, err := parseModel(controllerPkg, p[2]) + if err != nil { + beeLogger.Log.Fatalf("failed to parse model %s: %s", p[2], err) + } para.Schema = &swagger.Schema{ - Ref: "#/definitions/" + m, + Ref: "#/definitions/" + typeID, } - if _, ok := modelsList[pkgpath+controllerName]; !ok { - modelsList[pkgpath+controllerName] = make(map[string]swagger.Schema) - } - modelsList[pkgpath+controllerName][typ] = mod - appendModels(pkgpath, controllerName, realTypes) } else { if typ == "auto" { typ = paramType @@ -754,15 +731,13 @@ func setParamType(para *swagger.Parameter, typ string, pkgpath, controllerName s paraType = typeFormat[0] paraFormat = typeFormat[1] } else { - m, mod, realTypes := getModel(typ) + typeID, err := parseModel(controllerPkg, typ) + if err != nil { + beeLogger.Log.Fatalf("failed to parse model %s: %s", typ, err) + } para.Schema = &swagger.Schema{ - Ref: "#/definitions/" + m, + Ref: "#/definitions/" + typeID, } - if _, ok := modelsList[pkgpath+controllerName]; !ok { - modelsList[pkgpath+controllerName] = make(map[string]swagger.Schema) - } - modelsList[pkgpath+controllerName][typ] = mod - appendModels(pkgpath, controllerName, realTypes) } if isArray { para.Type = "array" @@ -857,227 +832,6 @@ func getparams(str string) []string { return r } -func getModel(str string) (objectname string, m swagger.Schema, realTypes []string) { - strs := strings.Split(str, ".") - objectname = strs[len(strs)-1] - packageName := "" - m.Type = "object" - for _, pkg := range astPkgs { - if strs[0] == pkg.Name { - for _, fl := range pkg.Files { - for k, d := range fl.Scope.Objects { - if d.Kind == ast.Typ { - if k != objectname { - continue - } - packageName = pkg.Name - parseObject(d, k, &m, &realTypes, astPkgs, pkg.Name) - } - } - } - } - } - if m.Title == "" { - beeLogger.Log.Warnf("Cannot find the object: %s", str) - // TODO remove when all type have been supported - //os.Exit(1) - } - if len(rootapi.Definitions) == 0 { - rootapi.Definitions = make(map[string]swagger.Schema) - } - objectname = packageName + "." + objectname - rootapi.Definitions[objectname] = m - return -} - -func parseObject(d *ast.Object, k string, m *swagger.Schema, realTypes *[]string, astPkgs []*ast.Package, packageName string) { - ts, ok := d.Decl.(*ast.TypeSpec) - if !ok { - beeLogger.Log.Fatalf("Unknown type without TypeSec: %v\n", d) - } - // TODO support other types, such as `ArrayType`, `MapType`, `InterfaceType` etc... - st, ok := ts.Type.(*ast.StructType) - if !ok { - return - } - m.Title = k - if st.Fields.List != nil { - m.Properties = make(map[string]swagger.Propertie) - for _, field := range st.Fields.List { - isSlice, realType, sType := typeAnalyser(field) - if (isSlice && isBasicType(realType)) || sType == "object" { - if len(strings.Split(realType, " ")) > 1 { - realType = strings.Replace(realType, " ", ".", -1) - realType = strings.Replace(realType, "&", "", -1) - realType = strings.Replace(realType, "{", "", -1) - realType = strings.Replace(realType, "}", "", -1) - } else { - realType = packageName + "." + realType - } - } - *realTypes = append(*realTypes, realType) - mp := swagger.Propertie{} - if isSlice { - mp.Type = "array" - if isBasicType(strings.Replace(realType, "[]", "", -1)) { - typeFormat := strings.Split(sType, ":") - mp.Items = &swagger.Propertie{ - Type: typeFormat[0], - Format: typeFormat[1], - } - } else { - mp.Items = &swagger.Propertie{ - Ref: "#/definitions/" + realType, - } - } - } else { - if sType == "object" { - mp.Ref = "#/definitions/" + realType - } else if isBasicType(realType) { - typeFormat := strings.Split(sType, ":") - mp.Type = typeFormat[0] - mp.Format = typeFormat[1] - } else if realType == "map" { - typeFormat := strings.Split(sType, ":") - mp.AdditionalProperties = &swagger.Propertie{ - Type: typeFormat[0], - Format: typeFormat[1], - } - } - } - if field.Names != nil { - - // set property name as field name - var name = field.Names[0].Name - - // if no tag skip tag processing - if field.Tag == nil { - m.Properties[name] = mp - continue - } - - var tagValues []string - - stag := reflect.StructTag(strings.Trim(field.Tag.Value, "`")) - - defaultValue := stag.Get("doc") - if defaultValue != "" { - r, _ := regexp.Compile(`default\((.*)\)`) - if r.MatchString(defaultValue) { - res := r.FindStringSubmatch(defaultValue) - mp.Default = str2RealType(res[1], realType) - - } else { - beeLogger.Log.Warnf("Invalid default value: %s", defaultValue) - } - } - - tag := stag.Get("json") - - if tag != "" { - tagValues = strings.Split(tag, ",") - } - - // dont add property if json tag first value is "-" - if len(tagValues) == 0 || tagValues[0] != "-" { - - // set property name to the left most json tag value only if is not omitempty - if len(tagValues) > 0 && tagValues[0] != "omitempty" { - name = tagValues[0] - } - - if thrifttag := stag.Get("thrift"); thrifttag != "" { - ts := strings.Split(thrifttag, ",") - if ts[0] != "" { - name = ts[0] - } - } - if required := stag.Get("required"); required != "" { - m.Required = append(m.Required, name) - } - if desc := stag.Get("description"); desc != "" { - mp.Description = desc - } - - m.Properties[name] = mp - } - if ignore := stag.Get("ignore"); ignore != "" { - continue - } - } else { - for _, pkg := range astPkgs { - for _, fl := range pkg.Files { - for nameOfObj, obj := range fl.Scope.Objects { - if obj.Name == fmt.Sprint(field.Type) { - parseObject(obj, nameOfObj, m, realTypes, astPkgs, pkg.Name) - } - } - } - } - } - } - } -} - -func typeAnalyser(f *ast.Field) (isSlice bool, realType, swaggerType string) { - if arr, ok := f.Type.(*ast.ArrayType); ok { - if isBasicType(fmt.Sprint(arr.Elt)) { - return true, fmt.Sprintf("[]%v", arr.Elt), basicTypes[fmt.Sprint(arr.Elt)] - } - if mp, ok := arr.Elt.(*ast.MapType); ok { - return false, fmt.Sprintf("map[%v][%v]", mp.Key, mp.Value), "object" - } - if star, ok := arr.Elt.(*ast.StarExpr); ok { - return true, fmt.Sprint(star.X), "object" - } - return true, fmt.Sprint(arr.Elt), "object" - } - switch t := f.Type.(type) { - case *ast.StarExpr: - basicType := fmt.Sprint(t.X) - if k, ok := basicTypes[basicType]; ok { - return false, basicType, k - } - return false, basicType, "object" - case *ast.MapType: - val := fmt.Sprintf("%v", t.Value) - if isBasicType(val) { - return false, "map", basicTypes[val] - } - return false, val, "object" - } - basicType := fmt.Sprint(f.Type) - if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject { - basicType = object - } - if k, ok := basicTypes[basicType]; ok { - return false, basicType, k - } - return false, basicType, "object" -} - -func isBasicType(Type string) bool { - if _, ok := basicTypes[Type]; ok { - return true - } - return false -} - -// append models -func appendModels(pkgpath, controllerName string, realTypes []string) { - for _, realType := range realTypes { - if realType != "" && !isBasicType(strings.TrimLeft(realType, "[]")) && - !strings.HasPrefix(realType, "map") && !strings.HasPrefix(realType, "&") { - if _, ok := modelsList[pkgpath+controllerName][realType]; ok { - continue - } - _, mod, newRealTypes := getModel(realType) - modelsList[pkgpath+controllerName][realType] = mod - appendModels(pkgpath, controllerName, newRealTypes) - } - } -} - func getSecurity(t string) (security map[string][]string) { security = make(map[string][]string) p := getparams(strings.TrimSpace(t[len("@Security"):])) diff --git a/main.go b/main.go index beccb06..ce4fbaf 100644 --- a/main.go +++ b/main.go @@ -21,13 +21,10 @@ import ( "github.com/beego/bee/cmd" "github.com/beego/bee/cmd/commands" "github.com/beego/bee/config" - "github.com/beego/bee/generate/swaggergen" "github.com/beego/bee/utils" ) func main() { - currentpath, _ := os.Getwd() - flag.Usage = cmd.Usage flag.Parse() log.SetFlags(0) @@ -61,12 +58,6 @@ func main() { config.LoadConfig() - // Check if current directory is inside the GOPATH, - // if so parse the packages inside it. - if utils.IsInGOPATH(currentpath) && cmd.IfGenerateDocs(c.Name(), args) { - swaggergen.ParsePackagesFromDir(currentpath) - } - os.Exit(c.Run(c, args)) return }