Merge pull request #52 from ZhengYang/master

Import logic for code gen
This commit is contained in:
astaxie 2014-08-19 17:22:49 +08:00
commit 124876e271
3 changed files with 121 additions and 59 deletions

View File

@ -20,6 +20,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path" "path"
"path/filepath"
"regexp" "regexp"
"strings" "strings"
@ -73,11 +74,12 @@ var typeMapping = map[string]string{
// Table represent a table in a database // Table represent a table in a database
type Table struct { type Table struct {
Name string Name string
Pk string Pk string
Uk []string Uk []string
Fk map[string]*ForeignKey Fk map[string]*ForeignKey
Columns []*Column Columns []*Column
ImportTimePkg bool
} }
// Column reprsents a column for a table // Column reprsents a column for a table
@ -191,7 +193,7 @@ func (tag *OrmTag) String() string {
return fmt.Sprintf("`orm:\"%s\"`", strings.Join(ormOptions, ";")) return fmt.Sprintf("`orm:\"%s\"`", strings.Join(ormOptions, ";"))
} }
func generateAppcode(driver string, connStr string, level string, tables string, currpath string) { func generateAppcode(driver, connStr, level, tables, currpath string) {
var mode byte var mode byte
if level == "1" { if level == "1" {
mode = O_MODEL mode = O_MODEL
@ -216,7 +218,7 @@ func generateAppcode(driver string, connStr string, level string, tables string,
// Generate takes table, column and foreign key information from database connection // Generate takes table, column and foreign key information from database connection
// and generate corresponding golang source files // and generate corresponding golang source files
func gen(dbms string, connStr string, mode byte, selectedTableNames map[string]bool, currpath string) { func gen(dbms, connStr string, mode byte, selectedTableNames map[string]bool, currpath string) {
db, err := sql.Open(dbms, connStr) db, err := sql.Open(dbms, connStr)
if err != nil { if err != nil {
ColorLog("[ERRO] Could not connect to %s: %s\n", dbms, connStr) ColorLog("[ERRO] Could not connect to %s: %s\n", dbms, connStr)
@ -231,7 +233,8 @@ func gen(dbms string, connStr string, mode byte, selectedTableNames map[string]b
mvcPath.ControllerPath = path.Join(currpath, "controllers") mvcPath.ControllerPath = path.Join(currpath, "controllers")
mvcPath.RouterPath = path.Join(currpath, "routers") mvcPath.RouterPath = path.Join(currpath, "routers")
createPaths(mode, mvcPath) createPaths(mode, mvcPath)
writeSourceFiles(tables, mode, mvcPath, selectedTableNames) pkgPath := getPackagePath(currpath)
writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames)
} }
// getTables gets a list table names in current database // getTables gets a list table names in current database
@ -398,6 +401,8 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) {
} else if columnDefault == "CURRENT_TIMESTAMP" { } else if columnDefault == "CURRENT_TIMESTAMP" {
tag.AutoNowAdd = true tag.AutoNowAdd = true
} }
// need to import time package
table.ImportTimePkg = true
} }
if isSQLDecimal(dataType) { if isSQLDecimal(dataType) {
tag.Digits, tag.Decimals = extractDecimal(columnType) tag.Digits, tag.Decimals = extractDecimal(columnType)
@ -425,18 +430,18 @@ func createPaths(mode byte, paths *MvcPath) {
// writeSourceFiles generates source files for model/controller/router // writeSourceFiles generates source files for model/controller/router
// It will wipe the following directories and recreate them:./models, ./controllers, ./routers // It will wipe the following directories and recreate them:./models, ./controllers, ./routers
// Newly geneated files will be inside these folders. // Newly geneated files will be inside these folders.
func writeSourceFiles(tables []*Table, mode byte, paths *MvcPath, selectedTables map[string]bool) { func writeSourceFiles(pkgPath string, tables []*Table, mode byte, paths *MvcPath, selectedTables map[string]bool) {
if (O_MODEL & mode) == O_MODEL { if (O_MODEL & mode) == O_MODEL {
ColorLog("[INFO] Creating model files...\n") ColorLog("[INFO] Creating model files...\n")
writeModelFiles(tables, paths.ModelPath, selectedTables) writeModelFiles(tables, paths.ModelPath, selectedTables)
} }
if (O_CONTROLLER & mode) == O_CONTROLLER { if (O_CONTROLLER & mode) == O_CONTROLLER {
ColorLog("[INFO] Creating controller files...\n") ColorLog("[INFO] Creating controller files...\n")
writeControllerFiles(tables, paths.ControllerPath, selectedTables) writeControllerFiles(tables, paths.ControllerPath, selectedTables, pkgPath)
} }
if (O_ROUTER & mode) == O_ROUTER { if (O_ROUTER & mode) == O_ROUTER {
ColorLog("[INFO] Creating router files...\n") ColorLog("[INFO] Creating router files...\n")
writeRouterFile(tables, paths.RouterPath, selectedTables) writeRouterFile(tables, paths.RouterPath, selectedTables, pkgPath)
} }
} }
@ -480,18 +485,24 @@ func writeModelFiles(tables []*Table, mPath string, selectedTables map[string]bo
} }
fileStr := strings.Replace(template, "{{modelStruct}}", tb.String(), 1) fileStr := strings.Replace(template, "{{modelStruct}}", tb.String(), 1)
fileStr = strings.Replace(fileStr, "{{modelName}}", camelCase(tb.Name), -1) fileStr = strings.Replace(fileStr, "{{modelName}}", camelCase(tb.Name), -1)
// if table contains time field, import time.Time package
timePkg := ""
if tb.ImportTimePkg {
timePkg = "\"time\"\n"
}
fileStr = strings.Replace(fileStr, "{{timePkg}}", timePkg, -1)
if _, err := f.WriteString(fileStr); err != nil { if _, err := f.WriteString(fileStr); err != nil {
ColorLog("[ERRO] Could not write model file to %s\n", fpath) ColorLog("[ERRO] Could not write model file to %s\n", fpath)
os.Exit(2) os.Exit(2)
} }
f.Close() f.Close()
ColorLog("[INFO] model => %s\n", fpath) ColorLog("[INFO] model => %s\n", fpath)
formatAndFixImports(fpath) formatSourceCode(fpath)
} }
} }
// writeControllerFiles generates controller files // writeControllerFiles generates controller files
func writeControllerFiles(tables []*Table, cPath string, selectedTables map[string]bool) { func writeControllerFiles(tables []*Table, cPath string, selectedTables map[string]bool, pkgPath string) {
for _, tb := range tables { for _, tb := range tables {
// if selectedTables map is not nil and this table is not selected, ignore it // if selectedTables map is not nil and this table is not selected, ignore it
if selectedTables != nil { if selectedTables != nil {
@ -526,18 +537,19 @@ func writeControllerFiles(tables []*Table, cPath string, selectedTables map[stri
} }
} }
fileStr := strings.Replace(CTRL_TPL, "{{ctrlName}}", camelCase(tb.Name), -1) fileStr := strings.Replace(CTRL_TPL, "{{ctrlName}}", camelCase(tb.Name), -1)
fileStr = strings.Replace(fileStr, "{{pkgPath}}", pkgPath, -1)
if _, err := f.WriteString(fileStr); err != nil { if _, err := f.WriteString(fileStr); err != nil {
ColorLog("[ERRO] Could not write controller file to %s\n", fpath) ColorLog("[ERRO] Could not write controller file to %s\n", fpath)
os.Exit(2) os.Exit(2)
} }
f.Close() f.Close()
ColorLog("[INFO] controller => %s\n", fpath) ColorLog("[INFO] controller => %s\n", fpath)
formatAndFixImports(fpath) formatSourceCode(fpath)
} }
} }
// writeRouterFile generates router file // writeRouterFile generates router file
func writeRouterFile(tables []*Table, rPath string, selectedTables map[string]bool) { func writeRouterFile(tables []*Table, rPath string, selectedTables map[string]bool, pkgPath string) {
var nameSpaces []string var nameSpaces []string
for _, tb := range tables { for _, tb := range tables {
// if selectedTables map is not nil and this table is not selected, ignore it // if selectedTables map is not nil and this table is not selected, ignore it
@ -557,8 +569,7 @@ func writeRouterFile(tables []*Table, rPath string, selectedTables map[string]bo
// add export controller // add export controller
fpath := path.Join(rPath, "router.go") fpath := path.Join(rPath, "router.go")
routerStr := strings.Replace(ROUTER_TPL, "{{nameSpaces}}", strings.Join(nameSpaces, ""), 1) routerStr := strings.Replace(ROUTER_TPL, "{{nameSpaces}}", strings.Join(nameSpaces, ""), 1)
_, projectName := path.Split(path.Dir(rPath)) routerStr = strings.Replace(routerStr, "{{pkgPath}}", pkgPath, 1)
routerStr = strings.Replace(routerStr, "{{projectName}}", projectName, 1)
var f *os.File var f *os.File
var err error var err error
if isExist(fpath) { if isExist(fpath) {
@ -586,13 +597,15 @@ func writeRouterFile(tables []*Table, rPath string, selectedTables map[string]bo
} }
f.Close() f.Close()
ColorLog("[INFO] router => %s\n", fpath) ColorLog("[INFO] router => %s\n", fpath)
formatAndFixImports(fpath) formatSourceCode(fpath)
} }
// formatAndFixImports formats source files (add imports, too) // formatSourceCode formats source files
func formatAndFixImports(filename string) { func formatSourceCode(filename string) {
cmd := exec.Command("goimports", "-w", filename) cmd := exec.Command("gofmt", "-w", filename)
cmd.Run() if err := cmd.Run(); err != nil {
ColorLog("[WARN] gofmt err: %s\n", err)
}
} }
// camelCase converts a _ delimited string to camel case // camelCase converts a _ delimited string to camel case
@ -662,6 +675,36 @@ func getFileName(tbName string) (filename string) {
return return
} }
func getPackagePath(curpath string) (packpath string) {
gopath := os.Getenv("GOPATH")
Debugf("gopath:%s", gopath)
if gopath == "" {
ColorLog("[ERRO] you should set GOPATH in the env")
os.Exit(2)
}
appsrcpath := ""
haspath := false
wgopath := filepath.SplitList(gopath)
for _, wg := range wgopath {
wg, _ = filepath.EvalSymlinks(path.Join(wg, "src"))
if filepath.HasPrefix(strings.ToLower(curpath), strings.ToLower(wg)) {
haspath = true
appsrcpath = wg
break
}
}
if !haspath {
ColorLog("[ERRO] Can't generate application code outside of GOPATH '%s'\n", gopath)
os.Exit(2)
}
packpath = strings.Join(strings.Split(curpath[len(appsrcpath)+1:], string(filepath.Separator)), "/")
return
}
const ( const (
STRUCT_MODEL_TPL = `package models STRUCT_MODEL_TPL = `package models
@ -670,6 +713,15 @@ const (
MODEL_TPL = `package models MODEL_TPL = `package models
import (
"errors"
"fmt"
"reflect"
"strings"
{{timePkg}}
"github.com/astaxie/beego/orm"
)
{{modelStruct}} {{modelStruct}}
func init() { func init() {
@ -799,8 +851,17 @@ func Delete{{modelName}}(id int) (err error) {
return return
} }
` `
CTRL_TPL = ` CTRL_TPL = `package controllers
package controllers
import (
"{{pkgPath}}/models"
"encoding/json"
"errors"
"strconv"
"strings"
"github.com/astaxie/beego"
)
// oprations for {{ctrlName}} // oprations for {{ctrlName}}
type {{ctrlName}}Controller struct { type {{ctrlName}}Controller struct {
@ -949,8 +1010,7 @@ func (this *{{ctrlName}}Controller) Delete() {
this.ServeJson() this.ServeJson()
} }
` `
ROUTER_TPL = ` ROUTER_TPL = `// @APIVersion 1.0.0
// @APIVersion 1.0.0
// @Title beego Test API // @Title beego Test API
// @Description beego has a very cool tools to autogenerate documents for your API // @Description beego has a very cool tools to autogenerate documents for your API
// @Contact astaxie@gmail.com // @Contact astaxie@gmail.com
@ -960,7 +1020,8 @@ func (this *{{ctrlName}}Controller) Delete() {
package routers package routers
import ( import (
"{{projectName}}/controllers" "{{pkgPath}}/controllers"
"github.com/astaxie/beego" "github.com/astaxie/beego"
) )
@ -972,10 +1033,10 @@ func init() {
} }
` `
NAMESPACE_TPL = ` NAMESPACE_TPL = `
beego.NSNamespace("/{{nameSpace}}", beego.NSNamespace("/{{nameSpace}}",
beego.NSInclude( beego.NSInclude(
&controllers.{{ctrlName}}Controller{}, &controllers.{{ctrlName}}Controller{},
), ),
), ),
` `
) )

View File

@ -17,7 +17,6 @@ package main
import ( import (
"fmt" "fmt"
"os" "os"
"os/exec"
"path" "path"
"strings" "strings"
"time" "time"
@ -60,12 +59,6 @@ func generateMigration(mname, upsql, downsql, curpath string) {
} }
} }
// formatSourceCode formats the source code using gofmt
func formatSourceCode(fpath string) {
cmd := exec.Command("gofmt", "-w", fpath)
cmd.Run()
}
const MIGRATION_TPL = `package main const MIGRATION_TPL = `package main
import ( import (

View File

@ -15,7 +15,7 @@ func generateModel(mname, fields, crupath string) {
i := strings.LastIndex(p[:len(p)-1], "/") i := strings.LastIndex(p[:len(p)-1], "/")
packageName = p[i+1 : len(p)-1] packageName = p[i+1 : len(p)-1]
} }
modelStruct, err := getStruct(modelName, fields) modelStruct, err, hastime := getStruct(modelName, fields)
if err != nil { if err != nil {
ColorLog("[ERRO] Could not genrate models struct: %s\n", err) ColorLog("[ERRO] Could not genrate models struct: %s\n", err)
os.Exit(2) os.Exit(2)
@ -36,9 +36,14 @@ func generateModel(mname, fields, crupath string) {
content := strings.Replace(modelTpl, "{{packageName}}", packageName, -1) content := strings.Replace(modelTpl, "{{packageName}}", packageName, -1)
content = strings.Replace(content, "{{modelName}}", modelName, -1) content = strings.Replace(content, "{{modelName}}", modelName, -1)
content = strings.Replace(content, "{{modelStruct}}", modelStruct, -1) content = strings.Replace(content, "{{modelStruct}}", modelStruct, -1)
if hastime {
content = strings.Replace(content, "{{timePkg}}", `"time"`, -1)
} else {
content = strings.Replace(content, "{{timePkg}}", "", -1)
}
f.WriteString(content) f.WriteString(content)
// gofmt generated source code // gofmt generated source code
formatAndFixImports(fpath) formatSourceCode(fpath)
ColorLog("[INFO] model file generated: %s\n", fpath) ColorLog("[INFO] model file generated: %s\n", fpath)
} else { } else {
// error creating file // error creating file
@ -47,49 +52,53 @@ func generateModel(mname, fields, crupath string) {
} }
} }
func getStruct(structname, fields string) (string, error) { func getStruct(structname, fields string) (string, error, bool) {
if fields == "" { if fields == "" {
return "", errors.New("fields can't empty") return "", errors.New("fields can't empty"), false
} }
hastime := false
structStr := "type " + structname + " struct{\n" structStr := "type " + structname + " struct{\n"
fds := strings.Split(fields, ",") fds := strings.Split(fields, ",")
for i, v := range fds { for i, v := range fds {
kv := strings.SplitN(v, ":", 2) kv := strings.SplitN(v, ":", 2)
if len(kv) != 2 { if len(kv) != 2 {
return "", errors.New("the filds format is wrong. should key:type,key:type " + v) return "", errors.New("the filds format is wrong. should key:type,key:type " + v), false
} }
typ, tag := getType(kv[1]) typ, tag, hastimeinner := getType(kv[1])
if typ == "" { if typ == "" {
return "", errors.New("the filds format is wrong. should key:type,key:type " + v) return "", errors.New("the filds format is wrong. should key:type,key:type " + v), false
} }
if i == 0 && strings.ToLower(kv[0]) != "id" { if i == 0 && strings.ToLower(kv[0]) != "id" {
structStr = structStr + "Id int64 `orm:\"auto\"`\n" structStr = structStr + "Id int64 `orm:\"auto\"`\n"
} }
if hastimeinner {
hastime = true
}
structStr = structStr + camelString(kv[0]) + " " + typ + " " + tag + "\n" structStr = structStr + camelString(kv[0]) + " " + typ + " " + tag + "\n"
} }
structStr += "}\n" structStr += "}\n"
return structStr, nil return structStr, nil, hastime
} }
// fields support type // fields support type
// http://beego.me/docs/mvc/model/models.md#mysql // http://beego.me/docs/mvc/model/models.md#mysql
func getType(ktype string) (kt, tag string) { func getType(ktype string) (kt, tag string, hasTime bool) {
kv := strings.SplitN(ktype, ":", 2) kv := strings.SplitN(ktype, ":", 2)
switch kv[0] { switch kv[0] {
case "string": case "string":
if len(kv) == 2 { if len(kv) == 2 {
return "string", "`orm:\"size(" + kv[1] + ")\"`" return "string", "`orm:\"size(" + kv[1] + ")\"`", false
} else { } else {
return "string", "`orm:\"size(128)\"`" return "string", "`orm:\"size(128)\"`", false
} }
case "text": case "text":
return "string", "`orm:\"type(longtext)\"`" return "string", "`orm:\"type(longtext)\"`", false
case "auto": case "auto":
return "int64", "`orm:\"auto\"`" return "int64", "`orm:\"auto\"`", false
case "pk": case "pk":
return "int64", "`orm:\"pk\"`" return "int64", "`orm:\"pk\"`", false
case "datetime": case "datetime":
return "time.Time", "`orm:\"type(datetime)\"`" return "time.Time", "`orm:\"type(datetime)\"`", true
case "int", "int8", "int16", "int32", "int64": case "int", "int8", "int16", "int32", "int64":
fallthrough fallthrough
case "uint", "uint8", "uint16", "uint32", "uint64": case "uint", "uint8", "uint16", "uint32", "uint64":
@ -97,11 +106,11 @@ func getType(ktype string) (kt, tag string) {
case "bool": case "bool":
fallthrough fallthrough
case "float32", "float64": case "float32", "float64":
return kv[0], "" return kv[0], "", false
case "float": case "float":
return "float64", "" return "float64", "", false
} }
return "", "" return "", "", false
} }
var modelTpl = `package {{packageName}} var modelTpl = `package {{packageName}}
@ -111,8 +120,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time" {{timePkg}}
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
) )