diff --git a/.travis.yml b/.travis.yml index 9c5f4d7..66d2dd9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ script: - cd $(dirname `dirname $(pwd)`)/beego/bee - export GO111MODULE="on" - go mod download + - go test -coverprofile=coverage.txt -covermode=atomic ./... - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - go list ./... | grep -v /vendor/ | grep -v /pkg/mod/ - go vet $(go list ./... | grep -v /vendor/ | grep -v /pkg/mod/ ) diff --git a/generate/swaggergen/g_docs.go b/generate/swaggergen/g_docs.go index 454269c..25a0672 100644 --- a/generate/swaggergen/g_docs.go +++ b/generate/swaggergen/g_docs.go @@ -35,6 +35,8 @@ import ( yaml "gopkg.in/yaml.v2" + bu "github.com/beego/bee/v2/utils" + beeLogger "github.com/beego/bee/v2/logger" "github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/server/web/swagger" @@ -61,6 +63,7 @@ var controllerList map[string]map[string]*swagger.Item //controllername Paths it var modelsList map[string]map[string]swagger.Schema var rootapi swagger.Swagger var astPkgs []*ast.Package +var pkgLoadedCache map[string]struct{} // refer to builtin.go var basicTypes = map[string]string{ @@ -100,6 +103,7 @@ func init() { controllerList = make(map[string]map[string]*swagger.Item) modelsList = make(map[string]map[string]swagger.Schema) astPkgs = make([]*ast.Package, 0) + pkgLoadedCache = make(map[string]struct{}) } // parsePackagesFromDir parses packages from a given directory @@ -152,6 +156,16 @@ func parsePackageFromDir(path string) error { astPkgs = append(astPkgs, v) } + if len(folderPkgs) != 0 { + workPath := bu.GetBeeWorkPath() + parentPath := filepath.Dir(workPath) + rel, err := filepath.Rel(parentPath, path) + if err != nil { + return err + } + pkgLoadedCache[rel] = struct{}{} + } + return nil } @@ -933,7 +947,8 @@ L: // Still searching for the right object continue } - parseObject(d, k, &m, &realTypes, astPkgs, packageName) + + parseObject(fl.Imports, d, k, &m, &realTypes, astPkgs, packageName) // When we've found the correct object, we can stop searching break L @@ -958,7 +973,7 @@ L: return str, m, realTypes } -func parseObject(d *ast.Object, k string, m *swagger.Schema, realTypes *[]string, astPkgs []*ast.Package, packageName string) { +func parseObject(imports []*ast.ImportSpec, d *ast.Object, k string, m *swagger.Schema, realTypes *[]string, astPkgs []*ast.Package, packageName string) { ts, ok := d.Decl.(*ast.TypeSpec) if !ok { beeLogger.Log.Fatalf("Unknown type without TypeSec: %v", d) @@ -983,7 +998,7 @@ func parseObject(d *ast.Object, k string, m *swagger.Schema, realTypes *[]string case *ast.Ident: parseIdent(t, k, m, astPkgs) case *ast.StructType: - parseStruct(t, k, m, realTypes, astPkgs, packageName) + parseStruct(imports, t, k, m, realTypes, astPkgs, packageName) } } @@ -1068,7 +1083,7 @@ func parseIdent(st *ast.Ident, k string, m *swagger.Schema, astPkgs []*ast.Packa } -func parseStruct(st *ast.StructType, k string, m *swagger.Schema, realTypes *[]string, astPkgs []*ast.Package, packageName string) { +func parseStruct(imports []*ast.ImportSpec, st *ast.StructType, k string, m *swagger.Schema, realTypes *[]string, astPkgs []*ast.Package, packageName string) { m.Title = k if st.Fields.List != nil { m.Properties = make(map[string]swagger.Propertie) @@ -1084,6 +1099,11 @@ func parseStruct(st *ast.StructType, k string, m *swagger.Schema, realTypes *[]s realType = packageName + "." + realType } } + + if !isBasicType(realType) && sType == astTypeObject { + checkAndLoadPackage(imports, realType, packageName) + } + *realTypes = append(*realTypes, realType) mp := swagger.Propertie{} isObject := false @@ -1204,7 +1224,7 @@ func parseStruct(st *ast.StructType, k string, m *swagger.Schema, realTypes *[]s for _, fl := range pkg.Files { for nameOfObj, obj := range fl.Scope.Objects { if pkg.Name+"."+obj.Name == realType { - parseObject(obj, nameOfObj, nm, realTypes, astPkgs, pkg.Name) + parseObject(imports, obj, nameOfObj, nm, realTypes, astPkgs, pkg.Name) } } } @@ -1340,3 +1360,70 @@ func str2RealType(s string, typ string) interface{} { return ret } + +func checkAndLoadPackage(imports []*ast.ImportSpec, realType, curPkgName string) { + arr := strings.Split(realType, ".") + if len(arr) != 2 { + return + } + objectPkgName := arr[0] + if objectPkgName == curPkgName { + return + } + pkgPath := "" + for _, im := range imports { + importPath := "" + if im.Path != nil { + importPath = strings.Trim(im.Path.Value, `"`) + } + + if importPath == "" { + continue + } + + if im.Name != nil && im.Name.Name == objectPkgName { + pkgPath = importPath + break + } + + _, pkgName := filepath.Split(importPath) + if pkgName == objectPkgName { + pkgPath = importPath + break + } + } + + if pkgPath == "" { + beeLogger.Log.Warnf("%s missing import package", realType) + return + } + + if isSystemPackage(pkgPath) { + return + } + if _, ok := pkgLoadedCache[pkgPath]; ok { + return + } + + pkg, err := build.Default.Import(pkgPath, ".", build.FindOnly) + if err != nil { + beeLogger.Log.Warnf("Package %s cannot be imported, err:%v", pkgPath, err) + return + } + pkgRealpath := pkg.Dir + + fileSet := token.NewFileSet() + pkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { + name := info.Name() + return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") + }, parser.ParseComments) + if err != nil { + beeLogger.Log.Warnf("Error while parsing dir at '%s': %s", pkgRealpath, err) + } + + for _, pkg := range pkgs { + astPkgs = append(astPkgs, pkg) + } + + pkgLoadedCache[pkgPath] = struct{}{} +} diff --git a/generate/swaggergen/go_docs_test.go b/generate/swaggergen/go_docs_test.go new file mode 100644 index 0000000..6b3aec9 --- /dev/null +++ b/generate/swaggergen/go_docs_test.go @@ -0,0 +1,223 @@ +// Copyright 2013 bee authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package swaggergen + +import ( + "go/ast" + "go/build" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +//package model +// +//import ( +//"sync" +// +//"example.com/pkgnotexist" +//"github.com/shopspring/decimal" +//) +// +//type Object struct { +// Field1 decimal.Decimal +// Field2 pkgnotexist.TestType +// Field3 sync.Map +//} +func TestCheckAndLoadPackageOnGoMod(t *testing.T) { + defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE")) + os.Setenv("GO111MODULE", "on") + + testCases := []struct { + pkgName string + pkgImportPath string + imports []*ast.ImportSpec + realType string + curPkgName string + expected bool + }{ + { + pkgName: "decimal", + pkgImportPath: "github.com/shopspring/decimal", + imports: []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: "github.com/shopspring/decimal", + }, + }, + }, + realType: "decimal.Decimal", + curPkgName: "model", + expected: true, + }, + { + pkgName: "pkgnotexist", + pkgImportPath: "example.com/pkgnotexist", + imports: []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: "example.com/pkgnotexist", + }, + }, + }, + realType: "pkgnotexist.TestType", + curPkgName: "model", + expected: false, + }, + { + pkgName: "sync", + pkgImportPath: "sync", + imports: []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: "sync", + }, + }, + }, + realType: "sync.Map", + curPkgName: "model", + expected: false, + }, + } + + for _, test := range testCases { + checkAndLoadPackage(test.imports, test.realType, test.curPkgName) + result := false + for _, v := range astPkgs { + if v.Name == test.pkgName { + result = true + break + } + } + if test.expected != result { + t.Fatalf("load module error, expected: %v, result: %v", test.expected, result) + } + } +} + +//package model +// +//import ( +//"sync" +// +//"example.com/comm" +//"example.com/pkgnotexist" +//) +// +//type Object struct { +// Field1 comm.Common +// Field2 pkgnotexist.TestType +// Field3 sync.Map +//} +func TestCheckAndLoadPackageOnGoPath(t *testing.T) { + var ( + testCommPkg = ` +package comm + +type Common struct { + Code string + Error string +} +` + ) + + gopath, err := ioutil.TempDir("", "gobuild-gopath") + if err != nil { + t.Fatal(err) + } + + defer os.RemoveAll(gopath) + + if err := os.MkdirAll(filepath.Join(gopath, "src/example.com/comm"), 0777); err != nil { + t.Fatal(err) + } + + if err := ioutil.WriteFile(filepath.Join(gopath, "src/example.com/comm/comm.go"), []byte(testCommPkg), 0666); err != nil { + t.Fatal(err) + } + + defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE")) + os.Setenv("GO111MODULE", "off") + defer os.Setenv("GOPATH", os.Getenv("GOPATH")) + os.Setenv("GOPATH", gopath) + build.Default.GOPATH = gopath + + testCases := []struct { + pkgName string + pkgImportPath string + imports []*ast.ImportSpec + realType string + curPkgName string + expected bool + }{ + { + pkgName: "comm", + pkgImportPath: "example.com/comm", + imports: []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: "example.com/comm", + }, + }, + }, + realType: "comm.Common", + curPkgName: "model", + expected: true, + }, + { + pkgName: "pkgnotexist", + pkgImportPath: "example.com/pkgnotexist", + imports: []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: "example.com/pkgnotexist", + }, + }, + }, + realType: "pkgnotexist.TestType", + curPkgName: "model", + expected: false, + }, + { + pkgName: "sync", + pkgImportPath: "sync", + imports: []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: "sync", + }, + }, + }, + realType: "sync.Map", + curPkgName: "model", + expected: false, + }, + } + + for _, test := range testCases { + checkAndLoadPackage(test.imports, test.realType, test.curPkgName) + result := false + for _, v := range astPkgs { + if v.Name == test.pkgName { + result = true + break + } + } + if test.expected != result { + t.Fatalf("load module error, expected: %v, result: %v", test.expected, result) + } + } +}