// Copyright 2013 bee authors // // Licensed under the Apache License, Version 2.0 (the "License"): you may // not use this file except in compliance with the License. You may obtain // a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. package swaggergen import ( _ "github.com/shopspring/decimal" "go/ast" "go/build" "io/ioutil" "os" "path/filepath" "testing" ) // package model // // import ( // "sync" // // "example.com/pkgnotexist" // "github.com/shopspring/decimal" // ) // // type Object struct { // Field1 decimal.Decimal // Field2 pkgnotexist.TestType // Field3 sync.Map // } func TestCheckAndLoadPackageOnGoMod(t *testing.T) { defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE")) os.Setenv("GO111MODULE", "on") testCases := []struct { pkgName string pkgImportPath string imports []*ast.ImportSpec realType string curPkgName string expected bool }{ { pkgName: "decimal", pkgImportPath: "github.com/shopspring/decimal", imports: []*ast.ImportSpec{ { Path: &ast.BasicLit{ Value: "github.com/shopspring/decimal", }, }, }, realType: "decimal.Decimal", curPkgName: "model", expected: true, }, { pkgName: "pkgnotexist", pkgImportPath: "example.com/pkgnotexist", imports: []*ast.ImportSpec{ { Path: &ast.BasicLit{ Value: "example.com/pkgnotexist", }, }, }, realType: "pkgnotexist.TestType", curPkgName: "model", expected: false, }, { pkgName: "sync", pkgImportPath: "sync", imports: []*ast.ImportSpec{ { Path: &ast.BasicLit{ Value: "sync", }, }, }, realType: "sync.Map", curPkgName: "model", expected: false, }, } for _, test := range testCases { checkAndLoadPackage(test.imports, test.realType, test.curPkgName) result := false for _, v := range astPkgs { if v.Name == test.pkgName { result = true break } } if test.expected != result { t.Fatalf("load module error, expected: %v, result: %v", test.expected, result) } } } // package model // // import ( // "sync" // // "example.com/comm" // "example.com/pkgnotexist" // ) // // type Object struct { // Field1 comm.Common // Field2 pkgnotexist.TestType // Field3 sync.Map // } func TestCheckAndLoadPackageOnGoPath(t *testing.T) { var ( testCommPkg = ` package comm type Common struct { Code string Error string } ` ) gopath, err := ioutil.TempDir("", "gobuild-gopath") if err != nil { t.Fatal(err) } defer os.RemoveAll(gopath) if err := os.MkdirAll(filepath.Join(gopath, "src/example.com/comm"), 0777); err != nil { t.Fatal(err) } if err := ioutil.WriteFile(filepath.Join(gopath, "src/example.com/comm/comm.go"), []byte(testCommPkg), 0666); err != nil { t.Fatal(err) } defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE")) os.Setenv("GO111MODULE", "off") defer os.Setenv("GOPATH", os.Getenv("GOPATH")) os.Setenv("GOPATH", gopath) build.Default.GOPATH = gopath testCases := []struct { pkgName string pkgImportPath string imports []*ast.ImportSpec realType string curPkgName string expected bool }{ { pkgName: "comm", pkgImportPath: "example.com/comm", imports: []*ast.ImportSpec{ { Path: &ast.BasicLit{ Value: "example.com/comm", }, }, }, realType: "comm.Common", curPkgName: "model", expected: true, }, { pkgName: "pkgnotexist", pkgImportPath: "example.com/pkgnotexist", imports: []*ast.ImportSpec{ { Path: &ast.BasicLit{ Value: "example.com/pkgnotexist", }, }, }, realType: "pkgnotexist.TestType", curPkgName: "model", expected: false, }, { pkgName: "sync", pkgImportPath: "sync", imports: []*ast.ImportSpec{ { Path: &ast.BasicLit{ Value: "sync", }, }, }, realType: "sync.Map", curPkgName: "model", expected: false, }, } for _, test := range testCases { checkAndLoadPackage(test.imports, test.realType, test.curPkgName) result := false for _, v := range astPkgs { if v.Name == test.pkgName { result = true break } } if test.expected != result { t.Fatalf("load module error, expected: %v, result: %v", test.expected, result) } } }