This commit is contained in:
weiyang 2017-08-21 10:54:18 +00:00 committed by GitHub
commit 80360d5fe4
2 changed files with 55 additions and 316 deletions

View File

@ -24,19 +24,19 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"unicode"
"gopkg.in/yaml.v2"
"github.com/ghodss/yaml"
"github.com/astaxie/beego/swagger"
"github.com/astaxie/beego/utils"
beeLogger "github.com/beego/bee/logger"
bu "github.com/beego/bee/utils"
"github.com/go-openapi/spec"
"github.com/wy-z/tspec/tspec"
)
const (
@ -47,13 +47,32 @@ const (
aform = "multipart/form-data"
)
// Swagger redefines definitions
type Swagger struct {
swagger.Swagger
Definitions spec.Definitions `json:"definitions,omitempty" yaml:"definitions,omitempty"`
}
var pkgCache map[string]struct{} //pkg:controller:function:comments comments: key:value
var controllerComments map[string]string
var importlist map[string]string
var controllerList map[string]map[string]*swagger.Item //controllername Paths items
var modelsList map[string]map[string]swagger.Schema
var rootapi swagger.Swagger
var astPkgs []*ast.Package
var rootapi Swagger
var tparser *tspec.Parser
var controllerPkg *ast.Package
func parseModel(pkg *ast.Package, typeStr string) (typeTitle string, err error) {
if pkg == nil {
panic("pkg can not be nil")
}
schema, err := tparser.Parse(pkg, typeStr)
if err != nil {
return
}
typeTitle = schema.Title
return
}
// refer to builtin.go
var basicTypes = map[string]string{
@ -80,70 +99,15 @@ var basicTypes = map[string]string{
"time.Time": "string:string",
}
var stdlibObject = map[string]string{
"&{time Time}": "time.Time",
}
func init() {
pkgCache = make(map[string]struct{})
controllerComments = make(map[string]string)
importlist = make(map[string]string)
controllerList = make(map[string]map[string]*swagger.Item)
modelsList = make(map[string]map[string]swagger.Schema)
astPkgs = make([]*ast.Package, 0)
}
func ParsePackagesFromDir(dirpath string) {
c := make(chan error)
go func() {
filepath.Walk(dirpath, func(fpath string, fileInfo os.FileInfo, err error) error {
if err != nil {
return nil
}
if !fileInfo.IsDir() {
return nil
}
// 7 is length of 'vendor' (6) + length of file path separator (1)
// so we skip dir 'vendor' which is directly under dirpath
if !(len(fpath) == len(dirpath)+7 && strings.HasSuffix(fpath, "vendor")) &&
!strings.Contains(fpath, "tests") &&
!(len(fpath) > len(dirpath) && fpath[len(dirpath)+1] == '.') {
err = parsePackageFromDir(fpath)
if err != nil {
// Send the error to through the channel and continue walking
c <- fmt.Errorf("Error while parsing directory: %s", err.Error())
return nil
}
}
return nil
})
close(c)
}()
for err := range c {
beeLogger.Log.Warnf("%s", err)
}
}
func parsePackageFromDir(path string) error {
fileSet := token.NewFileSet()
folderPkgs, err := parser.ParseDir(fileSet, path, func(info os.FileInfo) bool {
name := info.Name()
return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
}, parser.ParseComments)
if err != nil {
return err
}
for _, v := range folderPkgs {
astPkgs = append(astPkgs, v)
}
return nil
tparser = tspec.NewParser()
}
// GenerateDocs ...
func GenerateDocs(curpath string) {
fset := token.NewFileSet()
@ -307,6 +271,8 @@ func GenerateDocs(curpath string) {
}
}
}
rootapi.Definitions = tparser.Definitions()
os.Mkdir(path.Join(curpath, "swagger"), 0755)
fd, err := os.Create(path.Join(curpath, "swagger", "swagger.json"))
if err != nil {
@ -319,9 +285,14 @@ func GenerateDocs(curpath string) {
defer fdyml.Close()
defer fd.Close()
dt, err := json.MarshalIndent(rootapi, "", " ")
dtyml, erryml := yaml.Marshal(rootapi)
if err != nil || erryml != nil {
panic(err)
if err != nil {
msg := fmt.Sprintf("failed to marshal api doc: %s", err)
panic(msg)
}
dtyml, erryml := yaml.JSONToYAML(dt)
if erryml != nil {
msg := fmt.Sprintf("failed to convert json bytes to yaml bytes: %s", erryml)
panic(msg)
}
_, err = fd.Write(dt)
_, erryml = fdyml.Write(dtyml)
@ -442,6 +413,10 @@ func analyseControllerPkg(vendorPath, localName, pkgpath string) {
beeLogger.Log.Fatalf("Error while parsing dir at '%s': %s", pkgpath, err)
}
for _, pkg := range astPkgs {
if pkg.Name == "controllers" {
controllerPkg = pkg
}
for _, fl := range pkg.Files {
for _, d := range fl.Decls {
switch specDecl := d.(type) {
@ -557,13 +532,11 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
schema.Type = typeFormat[0]
schema.Format = typeFormat[1]
} else {
m, mod, realTypes := getModel(schemaName)
schema.Ref = "#/definitions/" + m
if _, ok := modelsList[pkgpath+controllerName]; !ok {
modelsList[pkgpath+controllerName] = make(map[string]swagger.Schema)
typeTitle, err := parseModel(controllerPkg, schemaName)
if err != nil {
beeLogger.Log.Fatalf("failed to parse model %s: %s", schemaName, err)
}
modelsList[pkgpath+controllerName][schemaName] = mod
appendModels(pkgpath, controllerName, realTypes)
schema.Ref = "#/definitions/" + typeTitle
}
if isArray {
rs.Schema = &swagger.Schema{
@ -613,15 +586,13 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
pp := strings.Split(p[2], ".")
typ := pp[len(pp)-1]
if len(pp) >= 2 {
m, mod, realTypes := getModel(p[2])
typeTitle, err := parseModel(controllerPkg, p[2])
if err != nil {
beeLogger.Log.Fatalf("failed to parse model %s: %s", p[2], err)
}
para.Schema = &swagger.Schema{
Ref: "#/definitions/" + m,
Ref: "#/definitions/" + typeTitle,
}
if _, ok := modelsList[pkgpath+controllerName]; !ok {
modelsList[pkgpath+controllerName] = make(map[string]swagger.Schema)
}
modelsList[pkgpath+controllerName][typ] = mod
appendModels(pkgpath, controllerName, realTypes)
} else {
if typ == "auto" {
typ = paramType
@ -754,15 +725,13 @@ func setParamType(para *swagger.Parameter, typ string, pkgpath, controllerName s
paraType = typeFormat[0]
paraFormat = typeFormat[1]
} else {
m, mod, realTypes := getModel(typ)
typeTitle, err := parseModel(controllerPkg, typ)
if err != nil {
beeLogger.Log.Fatalf("failed to parse model %s: %s", typ, err)
}
para.Schema = &swagger.Schema{
Ref: "#/definitions/" + m,
Ref: "#/definitions/" + typeTitle,
}
if _, ok := modelsList[pkgpath+controllerName]; !ok {
modelsList[pkgpath+controllerName] = make(map[string]swagger.Schema)
}
modelsList[pkgpath+controllerName][typ] = mod
appendModels(pkgpath, controllerName, realTypes)
}
if isArray {
para.Type = "array"
@ -857,227 +826,6 @@ func getparams(str string) []string {
return r
}
func getModel(str string) (objectname string, m swagger.Schema, realTypes []string) {
strs := strings.Split(str, ".")
objectname = strs[len(strs)-1]
packageName := ""
m.Type = "object"
for _, pkg := range astPkgs {
if strs[0] == pkg.Name {
for _, fl := range pkg.Files {
for k, d := range fl.Scope.Objects {
if d.Kind == ast.Typ {
if k != objectname {
continue
}
packageName = pkg.Name
parseObject(d, k, &m, &realTypes, astPkgs, pkg.Name)
}
}
}
}
}
if m.Title == "" {
beeLogger.Log.Warnf("Cannot find the object: %s", str)
// TODO remove when all type have been supported
//os.Exit(1)
}
if len(rootapi.Definitions) == 0 {
rootapi.Definitions = make(map[string]swagger.Schema)
}
objectname = packageName + "." + objectname
rootapi.Definitions[objectname] = m
return
}
func parseObject(d *ast.Object, k string, m *swagger.Schema, realTypes *[]string, astPkgs []*ast.Package, packageName string) {
ts, ok := d.Decl.(*ast.TypeSpec)
if !ok {
beeLogger.Log.Fatalf("Unknown type without TypeSec: %v\n", d)
}
// TODO support other types, such as `ArrayType`, `MapType`, `InterfaceType` etc...
st, ok := ts.Type.(*ast.StructType)
if !ok {
return
}
m.Title = k
if st.Fields.List != nil {
m.Properties = make(map[string]swagger.Propertie)
for _, field := range st.Fields.List {
isSlice, realType, sType := typeAnalyser(field)
if (isSlice && isBasicType(realType)) || sType == "object" {
if len(strings.Split(realType, " ")) > 1 {
realType = strings.Replace(realType, " ", ".", -1)
realType = strings.Replace(realType, "&", "", -1)
realType = strings.Replace(realType, "{", "", -1)
realType = strings.Replace(realType, "}", "", -1)
} else {
realType = packageName + "." + realType
}
}
*realTypes = append(*realTypes, realType)
mp := swagger.Propertie{}
if isSlice {
mp.Type = "array"
if isBasicType(strings.Replace(realType, "[]", "", -1)) {
typeFormat := strings.Split(sType, ":")
mp.Items = &swagger.Propertie{
Type: typeFormat[0],
Format: typeFormat[1],
}
} else {
mp.Items = &swagger.Propertie{
Ref: "#/definitions/" + realType,
}
}
} else {
if sType == "object" {
mp.Ref = "#/definitions/" + realType
} else if isBasicType(realType) {
typeFormat := strings.Split(sType, ":")
mp.Type = typeFormat[0]
mp.Format = typeFormat[1]
} else if realType == "map" {
typeFormat := strings.Split(sType, ":")
mp.AdditionalProperties = &swagger.Propertie{
Type: typeFormat[0],
Format: typeFormat[1],
}
}
}
if field.Names != nil {
// set property name as field name
var name = field.Names[0].Name
// if no tag skip tag processing
if field.Tag == nil {
m.Properties[name] = mp
continue
}
var tagValues []string
stag := reflect.StructTag(strings.Trim(field.Tag.Value, "`"))
defaultValue := stag.Get("doc")
if defaultValue != "" {
r, _ := regexp.Compile(`default\((.*)\)`)
if r.MatchString(defaultValue) {
res := r.FindStringSubmatch(defaultValue)
mp.Default = str2RealType(res[1], realType)
} else {
beeLogger.Log.Warnf("Invalid default value: %s", defaultValue)
}
}
tag := stag.Get("json")
if tag != "" {
tagValues = strings.Split(tag, ",")
}
// dont add property if json tag first value is "-"
if len(tagValues) == 0 || tagValues[0] != "-" {
// set property name to the left most json tag value only if is not omitempty
if len(tagValues) > 0 && tagValues[0] != "omitempty" {
name = tagValues[0]
}
if thrifttag := stag.Get("thrift"); thrifttag != "" {
ts := strings.Split(thrifttag, ",")
if ts[0] != "" {
name = ts[0]
}
}
if required := stag.Get("required"); required != "" {
m.Required = append(m.Required, name)
}
if desc := stag.Get("description"); desc != "" {
mp.Description = desc
}
m.Properties[name] = mp
}
if ignore := stag.Get("ignore"); ignore != "" {
continue
}
} else {
for _, pkg := range astPkgs {
for _, fl := range pkg.Files {
for nameOfObj, obj := range fl.Scope.Objects {
if obj.Name == fmt.Sprint(field.Type) {
parseObject(obj, nameOfObj, m, realTypes, astPkgs, pkg.Name)
}
}
}
}
}
}
}
}
func typeAnalyser(f *ast.Field) (isSlice bool, realType, swaggerType string) {
if arr, ok := f.Type.(*ast.ArrayType); ok {
if isBasicType(fmt.Sprint(arr.Elt)) {
return true, fmt.Sprintf("[]%v", arr.Elt), basicTypes[fmt.Sprint(arr.Elt)]
}
if mp, ok := arr.Elt.(*ast.MapType); ok {
return false, fmt.Sprintf("map[%v][%v]", mp.Key, mp.Value), "object"
}
if star, ok := arr.Elt.(*ast.StarExpr); ok {
return true, fmt.Sprint(star.X), "object"
}
return true, fmt.Sprint(arr.Elt), "object"
}
switch t := f.Type.(type) {
case *ast.StarExpr:
basicType := fmt.Sprint(t.X)
if k, ok := basicTypes[basicType]; ok {
return false, basicType, k
}
return false, basicType, "object"
case *ast.MapType:
val := fmt.Sprintf("%v", t.Value)
if isBasicType(val) {
return false, "map", basicTypes[val]
}
return false, val, "object"
}
basicType := fmt.Sprint(f.Type)
if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
basicType = object
}
if k, ok := basicTypes[basicType]; ok {
return false, basicType, k
}
return false, basicType, "object"
}
func isBasicType(Type string) bool {
if _, ok := basicTypes[Type]; ok {
return true
}
return false
}
// append models
func appendModels(pkgpath, controllerName string, realTypes []string) {
for _, realType := range realTypes {
if realType != "" && !isBasicType(strings.TrimLeft(realType, "[]")) &&
!strings.HasPrefix(realType, "map") && !strings.HasPrefix(realType, "&") {
if _, ok := modelsList[pkgpath+controllerName][realType]; ok {
continue
}
_, mod, newRealTypes := getModel(realType)
modelsList[pkgpath+controllerName][realType] = mod
appendModels(pkgpath, controllerName, newRealTypes)
}
}
}
func getSecurity(t string) (security map[string][]string) {
security = make(map[string][]string)
p := getparams(strings.TrimSpace(t[len("@Security"):]))

View File

@ -21,13 +21,10 @@ import (
"github.com/beego/bee/cmd"
"github.com/beego/bee/cmd/commands"
"github.com/beego/bee/config"
"github.com/beego/bee/generate/swaggergen"
"github.com/beego/bee/utils"
)
func main() {
currentpath, _ := os.Getwd()
flag.Usage = cmd.Usage
flag.Parse()
log.SetFlags(0)
@ -61,12 +58,6 @@ func main() {
config.LoadConfig()
// Check if current directory is inside the GOPATH,
// if so parse the packages inside it.
if utils.IsInGOPATH(currentpath) && cmd.IfGenerateDocs(c.Name(), args) {
swaggergen.ParsePackagesFromDir(currentpath)
}
os.Exit(c.Run(c, args))
return
}