add unit test

This commit is contained in:
John 2021-02-06 20:10:12 +08:00
parent 02a66c7bd7
commit d995fc5328
1 changed files with 117 additions and 0 deletions

View File

@ -0,0 +1,117 @@
package swaggergen
import (
"go/ast"
"go/build"
"io/ioutil"
"os"
"path/filepath"
"testing"
)
//package model
//
//import (
// "github.com/shopspring/decimal"
//)
//
//type Object struct{
// Total decimal.Decimal
//}
func TestCheckAndLoadPackageOnGoMod(t *testing.T) {
var (
pkgName = "decimal"
pkgImportPath = "github.com/shopspring/decimal"
)
defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE"))
os.Setenv("GO111MODULE", "on")
imports := []*ast.ImportSpec{
{
Path: &ast.BasicLit{
Value: pkgImportPath,
},
},
}
checkAndLoadPackage(imports, "decimal.Decimal", "model")
if len(astPkgs) == 0 {
t.Fatalf("failed to load module: %s", pkgImportPath)
}
notLoadFlag := true
for _, v := range astPkgs {
if v.Name == pkgName {
notLoadFlag = false
}
}
if notLoadFlag {
t.Fatalf("failed to load module: %s", pkgImportPath)
}
}
//package model
//
//import (
//"example.com/comm"
//)
//
//type Object struct {
// Total comm.Common
//}
func TestCheckAndLoadPackageOnGoPath(t *testing.T) {
var (
pkgName = "comm"
pkgImportPath = "example.com/comm"
testCommPkg = `
package comm
type Common struct {
Code string
Error string
}
`
)
gopath, err := ioutil.TempDir("", "gobuild-gopath")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(gopath)
if err := os.MkdirAll(filepath.Join(gopath, "src/example.com/comm"), 0777); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(gopath, "src/example.com/comm/comm.go"), []byte(testCommPkg), 0666); err != nil {
t.Fatal(err)
}
defer os.Setenv("GO111MODULE", os.Getenv("GO111MODULE"))
os.Setenv("GO111MODULE", "off")
defer os.Setenv("GOPATH", os.Getenv("GOPATH"))
os.Setenv("GOPATH", gopath)
build.Default.GOPATH = gopath
imports := []*ast.ImportSpec{
{
Path: &ast.BasicLit{
Value: pkgImportPath,
},
},
}
checkAndLoadPackage(imports, "comm.Common", "model")
if len(astPkgs) == 0 {
t.Fatalf("failed to load module: %s", pkgImportPath)
}
notLoadFlag := true
for _, v := range astPkgs {
if v.Name == pkgName {
notLoadFlag = false
}
}
if notLoadFlag {
t.Fatalf("failed to load module: %s", pkgImportPath)
}
}