diff --git a/generate/swaggergen/go_docs_test.go b/generate/swaggergen/go_docs_test.go new file mode 100644 index 0000000..ddfbdb7 --- /dev/null +++ b/generate/swaggergen/go_docs_test.go @@ -0,0 +1,117 @@ +package swaggergen + +import ( + "go/ast" + "go/build" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +//package model +// +//import ( +// "github.com/shopspring/decimal" +//) +// +//type Object struct{ +// Total decimal.Decimal +//} +func TestCheckAndLoadPackageOnGoMod(t *testing.T) { + var ( + pkgName = "decimal" + pkgImportPath = "github.com/shopspring/decimal" + ) + + defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE")) + os.Setenv("GO111MODULE", "on") + + imports := []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: pkgImportPath, + }, + }, + } + checkAndLoadPackage(imports, "decimal.Decimal", "model") + if len(astPkgs) == 0 { + t.Fatalf("failed to load module: %s", pkgImportPath) + } + notLoadFlag := true + for _, v := range astPkgs { + if v.Name == pkgName { + notLoadFlag = false + } + } + if notLoadFlag { + t.Fatalf("failed to load module: %s", pkgImportPath) + } +} + +//package model +// +//import ( +//"example.com/comm" +//) +// +//type Object struct { +// Total comm.Common +//} +func TestCheckAndLoadPackageOnGoPath(t *testing.T) { + var ( + pkgName = "comm" + pkgImportPath = "example.com/comm" + + testCommPkg = ` +package comm + +type Common struct { + Code string + Error string +} +` + ) + + gopath, err := ioutil.TempDir("", "gobuild-gopath") + if err != nil { + t.Fatal(err) + } + + defer os.RemoveAll(gopath) + + if err := os.MkdirAll(filepath.Join(gopath, "src/example.com/comm"), 0777); err != nil { + t.Fatal(err) + } + + if err := ioutil.WriteFile(filepath.Join(gopath, "src/example.com/comm/comm.go"), []byte(testCommPkg), 0666); err != nil { + t.Fatal(err) + } + + defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE")) + os.Setenv("GO111MODULE", "off") + defer os.Setenv("GOPATH", os.Getenv("GOPATH")) + os.Setenv("GOPATH", gopath) + build.Default.GOPATH = gopath + + imports := []*ast.ImportSpec{ + { + Path: &ast.BasicLit{ + Value: pkgImportPath, + }, + }, + } + checkAndLoadPackage(imports, "comm.Common", "model") + if len(astPkgs) == 0 { + t.Fatalf("failed to load module: %s", pkgImportPath) + } + notLoadFlag := true + for _, v := range astPkgs { + if v.Name == pkgName { + notLoadFlag = false + } + } + if notLoadFlag { + t.Fatalf("failed to load module: %s", pkgImportPath) + } +}