1
0
mirror of https://github.com/beego/bee.git synced 2024-11-25 20:10:55 +00:00
This commit is contained in:
Sergey Lanzman 2016-07-23 02:05:01 +03:00
parent b022ab3277
commit bc963e0070
12 changed files with 93 additions and 101 deletions

12
bale.go
View File

@ -68,7 +68,7 @@ func runBale(cmd *Command, args []string) int {
// Generate auto-uncompress function. // Generate auto-uncompress function.
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.WriteString(fmt.Sprintf(_BALE_HEADER, conf.Bale.Import, buf.WriteString(fmt.Sprintf(BaleHeader, conf.Bale.Import,
strings.Join(resFiles, "\",\n\t\t\""), strings.Join(resFiles, "\",\n\t\t\""),
strings.Join(resFiles, ",\n\t\tbale.R"))) strings.Join(resFiles, ",\n\t\tbale.R")))
@ -90,7 +90,7 @@ func runBale(cmd *Command, args []string) int {
} }
const ( const (
_BALE_HEADER = `package main BaleHeader = `package main
import( import(
"os" "os"
@ -178,7 +178,7 @@ func walkFn(resPath string, info os.FileInfo, err error) error {
defer fw.Close() defer fw.Close()
// Write header. // Write header.
fmt.Fprintf(fw, _HEADER, resPath) fmt.Fprintf(fw, Header, resPath)
// Copy and compress data. // Copy and compress data.
gz := gzip.NewWriter(&ByteWriter{Writer: fw}) gz := gzip.NewWriter(&ByteWriter{Writer: fw})
@ -186,7 +186,7 @@ func walkFn(resPath string, info os.FileInfo, err error) error {
gz.Close() gz.Close()
// Write footer. // Write footer.
fmt.Fprint(fw, _FOOTER) fmt.Fprint(fw, Footer)
resFiles = append(resFiles, resPath) resFiles = append(resFiles, resPath)
return nil return nil
@ -202,7 +202,7 @@ func filterSuffix(name string) bool {
} }
const ( const (
_HEADER = `package bale Header = `package bale
import( import(
"bytes" "bytes"
@ -212,7 +212,7 @@ import(
func R%s() []byte { func R%s() []byte {
gz, err := gzip.NewReader(bytes.NewBuffer([]byte{` gz, err := gzip.NewReader(bytes.NewBuffer([]byte{`
_FOOTER = ` Footer = `
})) }))
if err != nil { if err != nil {

View File

@ -19,7 +19,7 @@ import (
"os" "os"
) )
const CONF_VER = 0 const ConfVer = 0
var defaultConf = `{ var defaultConf = `{
"version": 0, "version": 0,
@ -91,7 +91,7 @@ func loadConfig() error {
} }
// Check format version. // Check format version.
if conf.Version != CONF_VER { if conf.Version != ConfVer {
ColorLog("[WARN] Your bee.json is out-of-date, please update!\n") ColorLog("[WARN] Your bee.json is out-of-date, please update!\n")
ColorLog("[HINT] Compare bee.json under bee source code path and yours\n") ColorLog("[HINT] Compare bee.json under bee source code path and yours\n")
} }

View File

@ -29,9 +29,9 @@ import (
) )
const ( const (
O_MODEL byte = 1 << iota OModel byte = 1 << iota
O_CONTROLLER OController
O_ROUTER ORouter
) )
// DbTransformer has method to reverse engineer a database schema to restful api code // DbTransformer has method to reverse engineer a database schema to restful api code
@ -259,11 +259,11 @@ func generateAppcode(driver, connStr, level, tables, currpath string) {
var mode byte var mode byte
switch level { switch level {
case "1": case "1":
mode = O_MODEL mode = OModel
case "2": case "2":
mode = O_MODEL | O_CONTROLLER mode = OModel | OController
case "3": case "3":
mode = O_MODEL | O_CONTROLLER | O_ROUTER mode = OModel | OController | ORouter
default: default:
ColorLog("[ERRO] Invalid 'level' option: %s\n", level) ColorLog("[ERRO] Invalid 'level' option: %s\n", level)
ColorLog("[HINT] Level must be either 1, 2 or 3\n") ColorLog("[HINT] Level must be either 1, 2 or 3\n")
@ -505,10 +505,9 @@ func (*MysqlDB) GetGoDataType(sqlType string) (goType string) {
typeMapping = typeMappingMysql typeMapping = typeMappingMysql
if v, ok := typeMapping[sqlType]; ok { if v, ok := typeMapping[sqlType]; ok {
return v return v
} else { }
ColorLog("[ERRO] data type (%s) not found!\n", sqlType) ColorLog("[ERRO] data type (%s) not found!\n", sqlType)
os.Exit(2) os.Exit(2)
}
return goType return goType
} }
@ -692,22 +691,21 @@ func (postgresDB *PostgresDB) GetColumns(db *sql.DB, table *Table, blackList map
func (*PostgresDB) GetGoDataType(sqlType string) (goType string) { func (*PostgresDB) GetGoDataType(sqlType string) (goType string) {
if v, ok := typeMappingPostgres[sqlType]; ok { if v, ok := typeMappingPostgres[sqlType]; ok {
return v return v
} else { }
ColorLog("[ERRO] data type (%s) not found!\n", sqlType) ColorLog("[ERRO] data type (%s) not found!\n", sqlType)
os.Exit(2) os.Exit(2)
}
return goType return goType
} }
// deleteAndRecreatePaths removes several directories completely // deleteAndRecreatePaths removes several directories completely
func createPaths(mode byte, paths *MvcPath) { func createPaths(mode byte, paths *MvcPath) {
if (mode & O_MODEL) == O_MODEL { if (mode & OModel) == OModel {
os.Mkdir(paths.ModelPath, 0777) os.Mkdir(paths.ModelPath, 0777)
} }
if (mode & O_CONTROLLER) == O_CONTROLLER { if (mode & OController) == OController {
os.Mkdir(paths.ControllerPath, 0777) os.Mkdir(paths.ControllerPath, 0777)
} }
if (mode & O_ROUTER) == O_ROUTER { if (mode & ORouter) == ORouter {
os.Mkdir(paths.RouterPath, 0777) os.Mkdir(paths.RouterPath, 0777)
} }
} }
@ -716,15 +714,15 @@ func createPaths(mode byte, paths *MvcPath) {
// It will wipe the following directories and recreate them:./models, ./controllers, ./routers // It will wipe the following directories and recreate them:./models, ./controllers, ./routers
// Newly geneated files will be inside these folders. // Newly geneated files will be inside these folders.
func writeSourceFiles(pkgPath string, tables []*Table, mode byte, paths *MvcPath, selectedTables map[string]bool) { func writeSourceFiles(pkgPath string, tables []*Table, mode byte, paths *MvcPath, selectedTables map[string]bool) {
if (O_MODEL & mode) == O_MODEL { if (OModel & mode) == OModel {
ColorLog("[INFO] Creating model files...\n") ColorLog("[INFO] Creating model files...\n")
writeModelFiles(tables, paths.ModelPath, selectedTables) writeModelFiles(tables, paths.ModelPath, selectedTables)
} }
if (O_CONTROLLER & mode) == O_CONTROLLER { if (OController & mode) == OController {
ColorLog("[INFO] Creating controller files...\n") ColorLog("[INFO] Creating controller files...\n")
writeControllerFiles(tables, paths.ControllerPath, selectedTables, pkgPath) writeControllerFiles(tables, paths.ControllerPath, selectedTables, pkgPath)
} }
if (O_ROUTER & mode) == O_ROUTER { if (ORouter & mode) == ORouter {
ColorLog("[INFO] Creating router files...\n") ColorLog("[INFO] Creating router files...\n")
writeRouterFile(tables, paths.RouterPath, selectedTables, pkgPath) writeRouterFile(tables, paths.RouterPath, selectedTables, pkgPath)
} }
@ -764,9 +762,9 @@ func writeModelFiles(tables []*Table, mPath string, selectedTables map[string]bo
} }
template := "" template := ""
if tb.Pk == "" { if tb.Pk == "" {
template = STRUCT_MODEL_TPL template = StructModelTPL
} else { } else {
template = MODEL_TPL template = ModelTPL
} }
fileStr := strings.Replace(template, "{{modelStruct}}", tb.String(), 1) fileStr := strings.Replace(template, "{{modelStruct}}", tb.String(), 1)
fileStr = strings.Replace(fileStr, "{{modelName}}", camelCase(tb.Name), -1) fileStr = strings.Replace(fileStr, "{{modelName}}", camelCase(tb.Name), -1)
@ -825,7 +823,7 @@ func writeControllerFiles(tables []*Table, cPath string, selectedTables map[stri
continue continue
} }
} }
fileStr := strings.Replace(CTRL_TPL, "{{ctrlName}}", camelCase(tb.Name), -1) fileStr := strings.Replace(CtrlTPL, "{{ctrlName}}", camelCase(tb.Name), -1)
fileStr = strings.Replace(fileStr, "{{pkgPath}}", pkgPath, -1) fileStr = strings.Replace(fileStr, "{{pkgPath}}", pkgPath, -1)
if _, err := f.WriteString(fileStr); err != nil { if _, err := f.WriteString(fileStr); err != nil {
ColorLog("[ERRO] Could not write controller file to %s\n", fpath) ColorLog("[ERRO] Could not write controller file to %s\n", fpath)
@ -851,13 +849,13 @@ func writeRouterFile(tables []*Table, rPath string, selectedTables map[string]bo
continue continue
} }
// add name spaces // add name spaces
nameSpace := strings.Replace(NAMESPACE_TPL, "{{nameSpace}}", tb.Name, -1) nameSpace := strings.Replace(NamespaceTPL, "{{nameSpace}}", tb.Name, -1)
nameSpace = strings.Replace(nameSpace, "{{ctrlName}}", camelCase(tb.Name), -1) nameSpace = strings.Replace(nameSpace, "{{ctrlName}}", camelCase(tb.Name), -1)
nameSpaces = append(nameSpaces, nameSpace) nameSpaces = append(nameSpaces, nameSpace)
} }
// add export controller // add export controller
fpath := path.Join(rPath, "router.go") fpath := path.Join(rPath, "router.go")
routerStr := strings.Replace(ROUTER_TPL, "{{nameSpaces}}", strings.Join(nameSpaces, ""), 1) routerStr := strings.Replace(RouterTPL, "{{nameSpaces}}", strings.Join(nameSpaces, ""), 1)
routerStr = strings.Replace(routerStr, "{{pkgPath}}", pkgPath, 1) routerStr = strings.Replace(routerStr, "{{pkgPath}}", pkgPath, 1)
var f *os.File var f *os.File
var err error var err error
@ -1001,12 +999,12 @@ func getPackagePath(curpath string) (packpath string) {
} }
const ( const (
STRUCT_MODEL_TPL = `package models StructModelTPL = `package models
{{importTimePkg}} {{importTimePkg}}
{{modelStruct}} {{modelStruct}}
` `
MODEL_TPL = `package models ModelTPL = `package models
import ( import (
"errors" "errors"
@ -1150,7 +1148,7 @@ func Delete{{modelName}}(id int) (err error) {
return return
} }
` `
CTRL_TPL = `package controllers CtrlTPL = `package controllers
import ( import (
"{{pkgPath}}/models" "{{pkgPath}}/models"
@ -1316,7 +1314,7 @@ func (c *{{ctrlName}}Controller) Delete() {
c.ServeJSON() c.ServeJSON()
} }
` `
ROUTER_TPL = `// @APIVersion 1.0.0 RouterTPL = `// @APIVersion 1.0.0
// @Title beego Test API // @Title beego Test API
// @Description beego has a very cool tools to autogenerate documents for your API // @Description beego has a very cool tools to autogenerate documents for your API
// @Contact astaxie@gmail.com // @Contact astaxie@gmail.com
@ -1338,7 +1336,7 @@ func init() {
beego.AddNamespace(ns) beego.AddNamespace(ns)
} }
` `
NAMESPACE_TPL = ` NamespaceTPL = `
beego.NSNamespace("/{{nameSpace}}", beego.NSNamespace("/{{nameSpace}}",
beego.NSInclude( beego.NSInclude(
&controllers.{{ctrlName}}Controller{}, &controllers.{{ctrlName}}Controller{},

View File

@ -393,7 +393,7 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
if j == 0 || j == 1 { if j == 0 || j == 1 {
st[j] = string(tmp) st[j] = string(tmp)
tmp = make([]rune, 0) tmp = make([]rune, 0)
j += 1 j++
start = false start = false
if j == 1 { if j == 1 {
continue continue
@ -655,17 +655,15 @@ func typeAnalyser(f *ast.Field) (isSlice bool, realType string) {
} }
if star, ok := arr.Elt.(*ast.StarExpr); ok { if star, ok := arr.Elt.(*ast.StarExpr); ok {
return true, fmt.Sprint(star.X) return true, fmt.Sprint(star.X)
} else { }
return true, fmt.Sprint(arr.Elt) return true, fmt.Sprint(arr.Elt)
} }
} else {
switch t := f.Type.(type) { switch t := f.Type.(type) {
case *ast.StarExpr: case *ast.StarExpr:
return false, fmt.Sprint(t.X) return false, fmt.Sprint(t.X)
} }
return false, fmt.Sprint(f.Type) return false, fmt.Sprint(f.Type)
} }
}
func isBasicType(Type string) bool { func isBasicType(Type string) bool {
for _, v := range basicTypes { for _, v := range basicTypes {
@ -688,7 +686,7 @@ var basicTypes = []string{
} }
// regexp get json tag // regexp get json tag
func grepJsonTag(tag string) string { func grepJSONTag(tag string) string {
r, _ := regexp.Compile(`json:"([^"]*)"`) r, _ := regexp.Compile(`json:"([^"]*)"`)
matches := r.FindAllStringSubmatch(tag, -1) matches := r.FindAllStringSubmatch(tag, -1)
if len(matches) > 0 { if len(matches) > 0 {

View File

@ -31,11 +31,11 @@ func generateHproseAppcode(driver, connStr, level, tables, currpath string) {
var mode byte var mode byte
switch level { switch level {
case "1": case "1":
mode = O_MODEL mode = OModel
case "2": case "2":
mode = O_MODEL | O_CONTROLLER mode = OModel | OController
case "3": case "3":
mode = O_MODEL | O_CONTROLLER | O_ROUTER mode = OModel | OController | ORouter
default: default:
ColorLog("[ERRO] Invalid 'level' option: %s\n", level) ColorLog("[ERRO] Invalid 'level' option: %s\n", level)
ColorLog("[HINT] Level must be either 1, 2 or 3\n") ColorLog("[HINT] Level must be either 1, 2 or 3\n")
@ -90,7 +90,7 @@ func genHprose(dbms, connStr string, mode byte, selectedTableNames map[string]bo
// It will wipe the following directories and recreate them:./models, ./controllers, ./routers // It will wipe the following directories and recreate them:./models, ./controllers, ./routers
// Newly geneated files will be inside these folders. // Newly geneated files will be inside these folders.
func writeHproseSourceFiles(pkgPath string, tables []*Table, mode byte, paths *MvcPath, selectedTables map[string]bool) { func writeHproseSourceFiles(pkgPath string, tables []*Table, mode byte, paths *MvcPath, selectedTables map[string]bool) {
if (O_MODEL & mode) == O_MODEL { if (OModel & mode) == OModel {
ColorLog("[INFO] Creating model files...\n") ColorLog("[INFO] Creating model files...\n")
writeHproseModelFiles(tables, paths.ModelPath, selectedTables) writeHproseModelFiles(tables, paths.ModelPath, selectedTables)
} }
@ -130,10 +130,10 @@ func writeHproseModelFiles(tables []*Table, mPath string, selectedTables map[str
} }
template := "" template := ""
if tb.Pk == "" { if tb.Pk == "" {
template = HPROSE_STRUCT_MODEL_TPL template = HproseStructModelTPL
} else { } else {
template = HPROSE_MODEL_TPL template = HproseModelTPL
hproseAddFunctions = append(hproseAddFunctions, strings.Replace(HPROSE_ADDFUNCTION, "{{modelName}}", camelCase(tb.Name), -1)) hproseAddFunctions = append(hproseAddFunctions, strings.Replace(HproseAddFunction, "{{modelName}}", camelCase(tb.Name), -1))
} }
fileStr := strings.Replace(template, "{{modelStruct}}", tb.String(), 1) fileStr := strings.Replace(template, "{{modelStruct}}", tb.String(), 1)
fileStr = strings.Replace(fileStr, "{{modelName}}", camelCase(tb.Name), -1) fileStr = strings.Replace(fileStr, "{{modelName}}", camelCase(tb.Name), -1)
@ -157,7 +157,7 @@ func writeHproseModelFiles(tables []*Table, mPath string, selectedTables map[str
} }
const ( const (
HPROSE_ADDFUNCTION = ` HproseAddFunction = `
// publish about {{modelName}} function // publish about {{modelName}} function
service.AddFunction("Add{{modelName}}", models.Add{{modelName}}) service.AddFunction("Add{{modelName}}", models.Add{{modelName}})
service.AddFunction("Get{{modelName}}ById", models.Get{{modelName}}ById) service.AddFunction("Get{{modelName}}ById", models.Get{{modelName}}ById)
@ -166,12 +166,12 @@ const (
service.AddFunction("Delete{{modelName}}", models.Delete{{modelName}}) service.AddFunction("Delete{{modelName}}", models.Delete{{modelName}})
` `
HPROSE_STRUCT_MODEL_TPL = `package models HproseStructModelTPL = `package models
{{importTimePkg}} {{importTimePkg}}
{{modelStruct}} {{modelStruct}}
` `
HPROSE_MODEL_TPL = `package models HproseModelTPL = `package models
import ( import (
"errors" "errors"

View File

@ -23,15 +23,15 @@ import (
) )
const ( const (
M_PATH = "migrations" MPath = "migrations"
M_DATE_FORMAT = "20060102_150405" MDateFormat = "20060102_150405"
) )
// generateMigration generates migration file template for database schema update. // generateMigration generates migration file template for database schema update.
// The generated file template consists of an up() method for updating schema and // The generated file template consists of an up() method for updating schema and
// a down() method for reverting the update. // a down() method for reverting the update.
func generateMigration(mname, upsql, downsql, curpath string) { func generateMigration(mname, upsql, downsql, curpath string) {
migrationFilePath := path.Join(curpath, "database", M_PATH) migrationFilePath := path.Join(curpath, "database", MPath)
if _, err := os.Stat(migrationFilePath); os.IsNotExist(err) { if _, err := os.Stat(migrationFilePath); os.IsNotExist(err) {
// create migrations directory // create migrations directory
if err := os.MkdirAll(migrationFilePath, 0777); err != nil { if err := os.MkdirAll(migrationFilePath, 0777); err != nil {
@ -40,11 +40,11 @@ func generateMigration(mname, upsql, downsql, curpath string) {
} }
} }
// create file // create file
today := time.Now().Format(M_DATE_FORMAT) today := time.Now().Format(MDateFormat)
fpath := path.Join(migrationFilePath, fmt.Sprintf("%s_%s.go", today, mname)) fpath := path.Join(migrationFilePath, fmt.Sprintf("%s_%s.go", today, mname))
if f, err := os.OpenFile(fpath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666); err == nil { if f, err := os.OpenFile(fpath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666); err == nil {
defer f.Close() defer f.Close()
content := strings.Replace(MIGRATION_TPL, "{{StructName}}", camelCase(mname)+"_"+today, -1) content := strings.Replace(MigrationTPL, "{{StructName}}", camelCase(mname)+"_"+today, -1)
content = strings.Replace(content, "{{CurrTime}}", today, -1) content = strings.Replace(content, "{{CurrTime}}", today, -1)
content = strings.Replace(content, "{{UpSQL}}", upsql, -1) content = strings.Replace(content, "{{UpSQL}}", upsql, -1)
content = strings.Replace(content, "{{DownSQL}}", downsql, -1) content = strings.Replace(content, "{{DownSQL}}", downsql, -1)
@ -59,7 +59,7 @@ func generateMigration(mname, upsql, downsql, curpath string) {
} }
} }
const MIGRATION_TPL = `package main const MigrationTPL = `package main
import ( import (
"github.com/astaxie/beego/migration" "github.com/astaxie/beego/migration"

View File

@ -2,10 +2,10 @@ package main
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"path" "path"
"strings" "strings"
"fmt"
) )
func generateModel(mname, fields, crupath string) { func generateModel(mname, fields, crupath string) {
@ -16,7 +16,7 @@ func generateModel(mname, fields, crupath string) {
i := strings.LastIndex(p[:len(p)-1], "/") i := strings.LastIndex(p[:len(p)-1], "/")
packageName = p[i+1 : len(p)-1] packageName = p[i+1 : len(p)-1]
} }
modelStruct, err, hastime := getStruct(modelName, fields) modelStruct, hastime, err := getStruct(modelName, fields)
if err != nil { if err != nil {
ColorLog("[ERRO] Could not genrate models struct: %s\n", err) ColorLog("[ERRO] Could not genrate models struct: %s\n", err)
os.Exit(2) os.Exit(2)
@ -53,9 +53,9 @@ func generateModel(mname, fields, crupath string) {
} }
} }
func getStruct(structname, fields string) (string, error, bool) { func getStruct(structname, fields string) (string, bool, error) {
if fields == "" { if fields == "" {
return "", errors.New("fields can't empty"), false return "", false, errors.New("fields can't empty")
} }
hastime := false hastime := false
structStr := "type " + structname + " struct{\n" structStr := "type " + structname + " struct{\n"
@ -63,11 +63,11 @@ func getStruct(structname, fields string) (string, error, bool) {
for i, v := range fds { for i, v := range fds {
kv := strings.SplitN(v, ":", 2) kv := strings.SplitN(v, ":", 2)
if len(kv) != 2 { if len(kv) != 2 {
return "", errors.New("the filds format is wrong. should key:type,key:type " + v), false return "", false, errors.New("the filds format is wrong. should key:type,key:type " + v)
} }
typ, tag, hastimeinner := getType(kv[1]) typ, tag, hastimeinner := getType(kv[1])
if typ == "" { if typ == "" {
return "", errors.New("the filds format is wrong. should key:type,key:type " + v), false return "", false, errors.New("the filds format is wrong. should key:type,key:type " + v)
} }
if i == 0 && strings.ToLower(kv[0]) != "id" { if i == 0 && strings.ToLower(kv[0]) != "id" {
structStr = structStr + "Id int64 `orm:\"auto\"`\n" structStr = structStr + "Id int64 `orm:\"auto\"`\n"
@ -78,7 +78,7 @@ func getStruct(structname, fields string) (string, error, bool) {
structStr = structStr + camelString(kv[0]) + " " + typ + " " + tag + "\n" structStr = structStr + camelString(kv[0]) + " " + typ + " " + tag + "\n"
} }
structStr += "}\n" structStr += "}\n"
return structStr, nil, hastime return structStr, hastime, nil
} }
// fields support type // fields support type
@ -89,9 +89,8 @@ func getType(ktype string) (kt, tag string, hasTime bool) {
case "string": case "string":
if len(kv) == 2 { if len(kv) == 2 {
return "string", "`orm:\"size(" + kv[1] + ")\"`", false return "string", "`orm:\"size(" + kv[1] + ")\"`", false
} else {
return "string", "`orm:\"size(128)\"`", false
} }
return "string", "`orm:\"size(128)\"`", false
case "text": case "text":
return "string", "`orm:\"type(longtext)\"`", false return "string", "`orm:\"type(longtext)\"`", false
case "auto": case "auto":

View File

@ -54,11 +54,11 @@ func generateSQLFromFields(fields string) string {
typ, tag := "", "" typ, tag := "", ""
switch driver { switch driver {
case "mysql": case "mysql":
typ, tag = getSqlTypeMysql(kv[1]) typ, tag = getSQLTypeMysql(kv[1])
case "postgres": case "postgres":
typ, tag = getSqlTypePostgresql(kv[1]) typ, tag = getSQLTypePostgresql(kv[1])
default: default:
typ, tag = getSqlTypeMysql(kv[1]) typ, tag = getSQLTypeMysql(kv[1])
} }
if typ == "" { if typ == "" {
ColorLog("[ERRO] Fields format is wrong. Should be: key:type,key:type " + v + "\n") ColorLog("[ERRO] Fields format is wrong. Should be: key:type,key:type " + v + "\n")
@ -90,15 +90,14 @@ func generateSQLFromFields(fields string) string {
return sql return sql
} }
func getSqlTypeMysql(ktype string) (tp, tag string) { func getSQLTypeMysql(ktype string) (tp, tag string) {
kv := strings.SplitN(ktype, ":", 2) kv := strings.SplitN(ktype, ":", 2)
switch kv[0] { switch kv[0] {
case "string": case "string":
if len(kv) == 2 { if len(kv) == 2 {
return "varchar(" + kv[1] + ") NOT NULL", "" return "varchar(" + kv[1] + ") NOT NULL", ""
} else {
return "varchar(128) NOT NULL", ""
} }
return "varchar(128) NOT NULL", ""
case "text": case "text":
return "longtext NOT NULL", "" return "longtext NOT NULL", ""
case "auto": case "auto":
@ -121,15 +120,14 @@ func getSqlTypeMysql(ktype string) (tp, tag string) {
return "", "" return "", ""
} }
func getSqlTypePostgresql(ktype string) (tp, tag string) { func getSQLTypePostgresql(ktype string) (tp, tag string) {
kv := strings.SplitN(ktype, ":", 2) kv := strings.SplitN(ktype, ":", 2)
switch kv[0] { switch kv[0] {
case "string": case "string":
if len(kv) == 2 { if len(kv) == 2 {
return "char(" + kv[1] + ") NOT NULL", "" return "char(" + kv[1] + ") NOT NULL", ""
} else {
return "TEXT NOT NULL", ""
} }
return "TEXT NOT NULL", ""
case "text": case "text":
return "TEXT NOT NULL", "" return "TEXT NOT NULL", ""
case "auto", "pk": case "auto", "pk":

View File

@ -167,23 +167,23 @@ func migrate(goal, crupath, driver, connStr string) {
// checkForSchemaUpdateTable checks the existence of migrations table. // checkForSchemaUpdateTable checks the existence of migrations table.
// It checks for the proper table structures and creates the table using MYSQL_MIGRATION_DDL if it does not exist. // It checks for the proper table structures and creates the table using MYSQL_MIGRATION_DDL if it does not exist.
func checkForSchemaUpdateTable(db *sql.DB, driver string) { func checkForSchemaUpdateTable(db *sql.DB, driver string) {
showTableSql := showMigrationsTableSql(driver) showTableSQL := showMigrationsTableSQL(driver)
if rows, err := db.Query(showTableSql); err != nil { if rows, err := db.Query(showTableSQL); err != nil {
ColorLog("[ERRO] Could not show migrations table: %s\n", err) ColorLog("[ERRO] Could not show migrations table: %s\n", err)
os.Exit(2) os.Exit(2)
} else if !rows.Next() { } else if !rows.Next() {
// no migrations table, create anew // no migrations table, create anew
createTableSql := createMigrationsTableSql(driver) createTableSQL := createMigrationsTableSQL(driver)
ColorLog("[INFO] Creating 'migrations' table...\n") ColorLog("[INFO] Creating 'migrations' table...\n")
if _, err := db.Query(createTableSql); err != nil { if _, err := db.Query(createTableSQL); err != nil {
ColorLog("[ERRO] Could not create migrations table: %s\n", err) ColorLog("[ERRO] Could not create migrations table: %s\n", err)
os.Exit(2) os.Exit(2)
} }
} }
// checking that migrations table schema are expected // checking that migrations table schema are expected
selectTableSql := selectMigrationsTableSql(driver) selectTableSQL := selectMigrationsTableSQL(driver)
if rows, err := db.Query(selectTableSql); err != nil { if rows, err := db.Query(selectTableSQL); err != nil {
ColorLog("[ERRO] Could not show columns of migrations table: %s\n", err) ColorLog("[ERRO] Could not show columns of migrations table: %s\n", err)
os.Exit(2) os.Exit(2)
} else { } else {
@ -219,7 +219,7 @@ func checkForSchemaUpdateTable(db *sql.DB, driver string) {
} }
} }
func showMigrationsTableSql(driver string) string { func showMigrationsTableSQL(driver string) string {
switch driver { switch driver {
case "mysql": case "mysql":
return "SHOW TABLES LIKE 'migrations'" return "SHOW TABLES LIKE 'migrations'"
@ -230,18 +230,18 @@ func showMigrationsTableSql(driver string) string {
} }
} }
func createMigrationsTableSql(driver string) string { func createMigrationsTableSQL(driver string) string {
switch driver { switch driver {
case "mysql": case "mysql":
return MYSQL_MIGRATION_DDL return MYSQLMigrationDDL
case "postgres": case "postgres":
return POSTGRES_MIGRATION_DDL return POSTGRESMigrationDDL
default: default:
return MYSQL_MIGRATION_DDL return MYSQLMigrationDDL
} }
} }
func selectMigrationsTableSql(driver string) string { func selectMigrationsTableSQL(driver string) string {
switch driver { switch driver {
case "mysql": case "mysql":
return "DESC migrations" return "DESC migrations"
@ -290,7 +290,7 @@ func writeMigrationSourceFile(dir, source, driver, connStr string, latestTime in
ColorLog("[ERRO] Could not create file: %s\n", err) ColorLog("[ERRO] Could not create file: %s\n", err)
os.Exit(2) os.Exit(2)
} else { } else {
content := strings.Replace(MIGRATION_MAIN_TPL, "{{DBDriver}}", driver, -1) content := strings.Replace(MigrationMainTPL, "{{DBDriver}}", driver, -1)
content = strings.Replace(content, "{{ConnStr}}", connStr, -1) content = strings.Replace(content, "{{ConnStr}}", connStr, -1)
content = strings.Replace(content, "{{LatestTime}}", strconv.FormatInt(latestTime, 10), -1) content = strings.Replace(content, "{{LatestTime}}", strconv.FormatInt(latestTime, 10), -1)
content = strings.Replace(content, "{{LatestName}}", latestName, -1) content = strings.Replace(content, "{{LatestName}}", latestName, -1)
@ -369,7 +369,7 @@ func formatShellOutput(o string) {
} }
const ( const (
MIGRATION_MAIN_TPL = `package main MigrationMainTPL = `package main
import( import(
"os" "os"
@ -408,7 +408,7 @@ func main(){
} }
` `
MYSQL_MIGRATION_DDL = ` MYSQLMigrationDDL = `
CREATE TABLE migrations ( CREATE TABLE migrations (
id_migration int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', id_migration int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key',
name varchar(255) DEFAULT NULL COMMENT 'migration name, unique', name varchar(255) DEFAULT NULL COMMENT 'migration name, unique',
@ -420,7 +420,7 @@ CREATE TABLE migrations (
) ENGINE=InnoDB DEFAULT CHARSET=utf8 ) ENGINE=InnoDB DEFAULT CHARSET=utf8
` `
POSTGRES_MIGRATION_DDL = ` POSTGRESMigrationDDL = `
CREATE TYPE migrations_status AS ENUM('update', 'rollback'); CREATE TYPE migrations_status AS ENUM('update', 'rollback');
CREATE TABLE migrations ( CREATE TABLE migrations (

View File

@ -246,9 +246,8 @@ func (wft *walkFileTree) walkLeaf(fpath string, fi os.FileInfo, err error) error
} }
wft.allfiles[name] = true wft.allfiles[name] = true
return err return err
} else {
return err
} }
return err
} }
func (wft *walkFileTree) iterDirectory(fpath string, fi os.FileInfo) error { func (wft *walkFileTree) iterDirectory(fpath string, fi os.FileInfo) error {

2
run.go
View File

@ -114,7 +114,7 @@ func runApp(cmd *Command, args []string) int {
if downdoc == "true" { if downdoc == "true" {
if _, err := os.Stat(path.Join(crupath, "swagger")); err != nil { if _, err := os.Stat(path.Join(crupath, "swagger")); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
downloadFromUrl(swaggerlink, "swagger.zip") downloadFromURL(swaggerlink, "swagger.zip")
unzipAndDelete("swagger.zip", "swagger") unzipAndDelete("swagger.zip", "swagger")
} }
} }

View File

@ -59,7 +59,7 @@ func init() {
func runDocs(cmd *Command, args []string) int { func runDocs(cmd *Command, args []string) int {
if isDownload == "true" { if isDownload == "true" {
downloadFromUrl(swaggerlink, "swagger.zip") downloadFromURL(swaggerlink, "swagger.zip")
err := unzipAndDelete("swagger.zip", "swagger") err := unzipAndDelete("swagger.zip", "swagger")
if err != nil { if err != nil {
fmt.Println("has err exet unzipAndDelete", err) fmt.Println("has err exet unzipAndDelete", err)
@ -77,7 +77,7 @@ func runDocs(cmd *Command, args []string) int {
return 0 return 0
} }
func downloadFromUrl(url, fileName string) { func downloadFromURL(url, fileName string) {
fmt.Println("Downloading", url, "to", fileName) fmt.Println("Downloading", url, "to", fileName)
output, err := os.Create(fileName) output, err := os.Create(fileName)