diff --git a/apiapp.go b/apiapp.go index 8ee8c85..ecf61ef 100644 --- a/apiapp.go +++ b/apiapp.go @@ -30,7 +30,7 @@ create an api application base on beego framework bee api [appname] [-tables=""] [-driver=mysql] [-conn=root:@tcp(127.0.0.1:3306)/test] -tables: a list of table names separated by ',', default is empty, indicating all tables - -driver: [mysql | postgresql | sqlite], the default is mysql + -driver: [mysql | postgres | sqlite], the default is mysql -conn: the connection string used by the driver, the default is '' if conn is empty will create a example api application. otherwise generate api application based on an existing database. @@ -89,11 +89,11 @@ import ( "github.com/astaxie/beego" "github.com/astaxie/beego/orm" - _ "github.com/go-sql-driver/mysql" + {{.DriverPkg}} ) func init() { - orm.RegisterDataBase("default", "mysql", "{{.conn}}") + orm.RegisterDataBase("default", "{{.DriverName}}", "{{.conn}}") } func main() { @@ -574,9 +574,16 @@ func createapi(cmd *Command, args []string) int { if conn != "" { fmt.Println("create main.go:", path.Join(apppath, "main.go")) + maingoContent := strings.Replace(apiMainconngo, "{{.Appname}}", packpath, -1) + maingoContent = strings.Replace(maingoContent, "{{.DriverName}}", string(driver), -1) + if driver == "mysql" { + maingoContent = strings.Replace(maingoContent, "{{.DriverPkg}}", `_ "github.com/go-sql-driver/mysql"`, -1) + } else if driver == "postgres" { + maingoContent = strings.Replace(maingoContent, "{{.DriverPkg}}", `_ "github.com/lib/pq"`, -1) + } writetofile(path.Join(apppath, "main.go"), strings.Replace( - strings.Replace(apiMainconngo, "{{.Appname}}", packpath, -1), + maingoContent, "{{.conn}}", conn.String(), -1, @@ -584,8 +591,8 @@ func createapi(cmd *Command, args []string) int { ) ColorLog("[INFO] Using '%s' as 'driver'\n", driver) ColorLog("[INFO] Using '%s' as 'conn'\n", conn) - ColorLog("[INFO] Using '%s' as 'tables'", tables) - generateAppcode(string(driver), string(conn), "3", string(tables), path.Join(curpath, packpath)) + ColorLog("[INFO] Using '%s' as 'tables'\n", tables) + generateAppcode(string(driver), string(conn), "3", string(tables), path.Join(curpath, args[0])) } else { os.Mkdir(path.Join(apppath, "models"), 0755) fmt.Println("create models:", path.Join(apppath, "models")) @@ -621,7 +628,7 @@ func createapi(cmd *Command, args []string) int { writetofile(path.Join(apppath, "main.go"), strings.Replace(apiMaingo, "{{.Appname}}", packpath, -1)) } - return 0 + return 0 } func checkEnv(appname string) (apppath, packpath string, err error) { diff --git a/bee.json b/bee.json index 41dced8..1bd099f 100644 --- a/bee.json +++ b/bee.json @@ -15,7 +15,6 @@ "cmd_args": [], "envs": [], "database": { - "driver": "mysql", - "conn": "root:@tcp(127.0.0.1:3306)/test" + "driver": "mysql" } } \ No newline at end of file diff --git a/conf.go b/conf.go index 5a116b5..9d610d1 100644 --- a/conf.go +++ b/conf.go @@ -38,8 +38,7 @@ var defaultConf = `{ "cmd_args": [], "envs": [], "database": { - "driver": "mysql", - "conn": "root:@tcp(127.0.0.1:3306)/test" + "driver": "mysql" } } ` diff --git a/g.go b/g.go index e221c40..2d7e80e 100644 --- a/g.go +++ b/g.go @@ -50,8 +50,10 @@ bee generate test [routerfile] bee generate appcode [-tables=""] [-driver=mysql] [-conn="root:@tcp(127.0.0.1:3306)/test"] [-level=3] generate appcode based on an existing database -tables: a list of table names separated by ',', default is empty, indicating all tables - -driver: [mysql | postgresql | sqlite], the default is mysql - -conn: the connection string used by the driver, the default is root:@tcp(127.0.0.1:3306)/test + -driver: [mysql | postgres | sqlite], the default is mysql + -conn: the connection string used by the driver. + default for mysql: root:@tcp(127.0.0.1:3306)/test + default for postgres: postgres://postgres:postgres@127.0.0.1:5432/postgres -level: [1 | 2 | 3], 1 = models; 2 = models,controllers; 3 = models,controllers,router `, } @@ -137,7 +139,11 @@ func generateCode(cmd *Command, args []string) int { if conn == "" { conn = docValue(conf.Database.Conn) if conn == "" { - conn = "root:@tcp(127.0.0.1:3306)/test" + if driver == "mysql" { + conn = "root:@tcp(127.0.0.1:3306)/test" + } else if driver == "postgres" { + conn = "postgres://postgres:postgres@127.0.0.1:5432/postgres" + } } } if level == "" { diff --git a/g_appcode.go b/g_appcode.go index d94fbb1..fc5ce83 100644 --- a/g_appcode.go +++ b/g_appcode.go @@ -34,14 +34,36 @@ const ( O_ROUTER ) +// DbTransformer has method to reverse engineer a database schema to restful api code +type DbTransformer interface { + GetTableNames(conn *sql.DB) []string + GetConstraints(conn *sql.DB, table *Table, blackList map[string]bool) + GetColumns(conn *sql.DB, table *Table, blackList map[string]bool) + GetGoDataType(sqlType string) string +} + +// MysqlDB is the MySQL version of DbTransformer +type MysqlDB struct { +} + +// PostgresDB is the PostgreSQL version of DbTransformer +type PostgresDB struct { +} + +// dbDriver maps a DBMS name to its version of DbTransformer +var dbDriver = map[string]DbTransformer{ + "mysql": &MysqlDB{}, + "postgres": &PostgresDB{}, +} + type MvcPath struct { ModelPath string ControllerPath string RouterPath string } -// typeMapping maps a SQL data type to its corresponding Go data type -var typeMapping = map[string]string{ +// typeMapping maps SQL data type to corresponding Go data type +var typeMappingMysql = map[string]string{ "int": "int", // int signed "integer": "int", "tinyint": "int8", @@ -75,6 +97,34 @@ var typeMapping = map[string]string{ "varbinary": "string", } +// typeMappingPostgres maps SQL data type to corresponding Go data type +var typeMappingPostgres = map[string]string{ + "serial": "int", // serial + "big serial": "int64", + "smallint": "int16", // int + "integer": "int", + "bigint": "int64", + "boolean": "bool", // bool + "char": "string", // string + "character": "string", + "character varying": "string", + "varchar": "string", + "text": "string", + "date": "time.Time", // time + "time": "time.Time", + "timestamp": "time.Time", + "timestamp without time zone": "time.Time", + "real": "float32", // float & decimal + "double precision": "float64", + "decimal": "float64", + "numeric": "float64", + "money": "float64", // money + "bytea": "string", // binary + "tsvector": "string", // fulltext + "ARRAY": "string", // array + "USER-DEFINED": "string", // user defined +} + // Table represent a table in a database type Table struct { Name string @@ -240,20 +290,25 @@ func gen(dbms, connStr string, mode byte, selectedTableNames map[string]bool, cu os.Exit(2) } defer db.Close() - ColorLog("[INFO] Analyzing database tables...\n") - tableNames := getTableNames(db) - tables := getTableObjects(tableNames, db) - mvcPath := new(MvcPath) - mvcPath.ModelPath = path.Join(currpath, "models") - mvcPath.ControllerPath = path.Join(currpath, "controllers") - mvcPath.RouterPath = path.Join(currpath, "routers") - createPaths(mode, mvcPath) - pkgPath := getPackagePath(currpath) - writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames) + if trans, ok := dbDriver[dbms]; ok { + ColorLog("[INFO] Analyzing database tables...\n") + tableNames := trans.GetTableNames(db) + tables := getTableObjects(tableNames, db, trans) + mvcPath := new(MvcPath) + mvcPath.ModelPath = path.Join(currpath, "models") + mvcPath.ControllerPath = path.Join(currpath, "controllers") + mvcPath.RouterPath = path.Join(currpath, "routers") + createPaths(mode, mvcPath) + pkgPath := getPackagePath(currpath) + writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames) + } else { + ColorLog("[ERRO] Generating app code from %s database is not supported yet.\n", dbms) + os.Exit(2) + } } // getTables gets a list table names in current database -func getTableNames(db *sql.DB) (tables []string) { +func (*MysqlDB) GetTableNames(db *sql.DB) (tables []string) { rows, err := db.Query("SHOW TABLES") if err != nil { ColorLog("[ERRO] Could not show tables\n") @@ -273,7 +328,7 @@ func getTableNames(db *sql.DB) (tables []string) { } // getTableObjects process each table name -func getTableObjects(tableNames []string, db *sql.DB) (tables []*Table) { +func getTableObjects(tableNames []string, db *sql.DB, dbTransformer DbTransformer) (tables []*Table) { // if a table has a composite pk or doesn't have pk, we can't use it yet // these tables will be put into blacklist so that other struct will not // reference it. @@ -284,19 +339,19 @@ func getTableObjects(tableNames []string, db *sql.DB) (tables []*Table) { tb := new(Table) tb.Name = tableName tb.Fk = make(map[string]*ForeignKey) - getConstraints(db, tb, blackList) + dbTransformer.GetConstraints(db, tb, blackList) tables = append(tables, tb) } // process columns, ignoring blacklisted tables for _, tb := range tables { - getColumns(db, tb, blackList) + dbTransformer.GetColumns(db, tb, blackList) } return } // getConstraints gets primary key, unique key and foreign keys of a table from information_schema // and fill in Table struct -func getConstraints(db *sql.DB, table *Table, blackList map[string]bool) { +func (*MysqlDB) GetConstraints(db *sql.DB, table *Table, blackList map[string]bool) { rows, err := db.Query( `SELECT c.constraint_type, u.column_name, u.referenced_table_schema, u.referenced_table_name, referenced_column_name, u.ordinal_position @@ -344,7 +399,7 @@ func getConstraints(db *sql.DB, table *Table, blackList map[string]bool) { // getColumns retrieve columns details from information_schema // and fill in the Column struct -func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { +func (mysqlDB *MysqlDB) GetColumns(db *sql.DB, table *Table, blackList map[string]bool) { // retrieve columns colDefRows, _ := db.Query( `SELECT @@ -367,7 +422,7 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { // create a column col := new(Column) col.Name = camelCase(colName) - col.Type = getGoDataType(dataType) + col.Type = mysqlDB.GetGoDataType(dataType) // Tag info tag := new(OrmTag) tag.Column = colName @@ -402,7 +457,7 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { if isSQLSignedIntType(dataType) { sign := extractIntSignness(columnType) if sign == "unsigned" && extra != "auto_increment" { - col.Type = getGoDataType(dataType + " " + sign) + col.Type = mysqlDB.GetGoDataType(dataType + " " + sign) } } if isSQLStringType(dataType) { @@ -432,10 +487,204 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { } } +// getGoDataType maps an SQL data type to Golang data type +func (*MysqlDB) GetGoDataType(sqlType string) (goType string) { + if v, ok := typeMappingMysql[sqlType]; ok { + return v + } else { + ColorLog("[ERRO] data type (%s) not found!\n", sqlType) + os.Exit(2) + } + return goType +} + +// GetTableNames for PostgreSQL +func (*PostgresDB) GetTableNames(db *sql.DB) (tables []string) { + rows, err := db.Query(` + SELECT table_name FROM information_schema.tables + WHERE table_catalog = current_database() and table_schema = 'public'`) + if err != nil { + ColorLog("[ERRO] Could not show tables: %s\n", err) + ColorLog("[HINT] Check your connection string\n") + os.Exit(2) + } + defer rows.Close() + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + ColorLog("[ERRO] Could not show tables\n") + os.Exit(2) + } + tables = append(tables, name) + } + return +} + +// GetConstraints for PostgreSQL +func (*PostgresDB) GetConstraints(db *sql.DB, table *Table, blackList map[string]bool) { + rows, err := db.Query( + `SELECT + c.constraint_type, + u.column_name, + cu.table_catalog AS referenced_table_catalog, + cu.table_name AS referenced_table_name, + cu.column_name AS referenced_column_name, + u.ordinal_position + FROM + information_schema.table_constraints c + INNER JOIN + information_schema.key_column_usage u ON c.constraint_name = u.constraint_name + INNER JOIN + information_schema.constraint_column_usage cu ON cu.constraint_name = c.constraint_name + WHERE + c.table_catalog = current_database() AND c.table_schema = 'public' AND c.table_name = $1 + AND u.table_catalog = current_database() AND u.table_schema = 'public' AND u.table_name = $2`, + table.Name, table.Name) // u.position_in_unique_constraint, + if err != nil { + ColorLog("[ERRO] Could not query INFORMATION_SCHEMA for PK/UK/FK information: %s\n", err) + os.Exit(2) + } + for rows.Next() { + var constraintTypeBytes, columnNameBytes, refTableSchemaBytes, refTableNameBytes, refColumnNameBytes, refOrdinalPosBytes []byte + if err := rows.Scan(&constraintTypeBytes, &columnNameBytes, &refTableSchemaBytes, &refTableNameBytes, &refColumnNameBytes, &refOrdinalPosBytes); err != nil { + ColorLog("[ERRO] Could not read INFORMATION_SCHEMA for PK/UK/FK information\n") + os.Exit(2) + } + constraintType, columnName, refTableSchema, refTableName, refColumnName, refOrdinalPos := + string(constraintTypeBytes), string(columnNameBytes), string(refTableSchemaBytes), + string(refTableNameBytes), string(refColumnNameBytes), string(refOrdinalPosBytes) + if constraintType == "PRIMARY KEY" { + if refOrdinalPos == "1" { + table.Pk = columnName + } else { + table.Pk = "" + // add table to blacklist so that other struct will not reference it, because we are not + // registering blacklisted tables + blackList[table.Name] = true + } + } else if constraintType == "UNIQUE" { + table.Uk = append(table.Uk, columnName) + } else if constraintType == "FOREIGN KEY" { + fk := new(ForeignKey) + fk.Name = columnName + fk.RefSchema = refTableSchema + fk.RefTable = refTableName + fk.RefColumn = refColumnName + table.Fk[columnName] = fk + } + } +} + +// GetColumns for PostgreSQL +func (postgresDB *PostgresDB) GetColumns(db *sql.DB, table *Table, blackList map[string]bool) { + // retrieve columns + colDefRows, _ := db.Query( + `SELECT + column_name, + data_type, + data_type || + CASE + WHEN data_type = 'character' THEN '('||character_maximum_length||')' + WHEN data_type = 'numeric' THEN '(' || numeric_precision || ',' || numeric_scale ||')' + ELSE '' + END AS column_type, + is_nullable, + column_default, + '' AS extra + FROM + information_schema.columns + WHERE + table_catalog = current_database() AND table_schema = 'public' AND table_name = $1`, + table.Name) + defer colDefRows.Close() + for colDefRows.Next() { + // datatype as bytes so that SQL values can be retrieved + var colNameBytes, dataTypeBytes, columnTypeBytes, isNullableBytes, columnDefaultBytes, extraBytes []byte + if err := colDefRows.Scan(&colNameBytes, &dataTypeBytes, &columnTypeBytes, &isNullableBytes, &columnDefaultBytes, &extraBytes); err != nil { + ColorLog("[ERRO] Could not query INFORMATION_SCHEMA for column information\n") + os.Exit(2) + } + colName, dataType, columnType, isNullable, columnDefault, extra := + string(colNameBytes), string(dataTypeBytes), string(columnTypeBytes), string(isNullableBytes), string(columnDefaultBytes), string(extraBytes) + // create a column + col := new(Column) + col.Name = camelCase(colName) + col.Type = postgresDB.GetGoDataType(dataType) + // Tag info + tag := new(OrmTag) + tag.Column = colName + if table.Pk == colName { + col.Name = "Id" + col.Type = "int" + if extra == "auto_increment" { + tag.Auto = true + } else { + tag.Pk = true + } + } else { + fkCol, isFk := table.Fk[colName] + isBl := false + if isFk { + _, isBl = blackList[fkCol.RefTable] + } + // check if the current column is a foreign key + if isFk && !isBl { + tag.RelFk = true + refStructName := fkCol.RefTable + col.Name = camelCase(colName) + col.Type = "*" + camelCase(refStructName) + } else { + // if the name of column is Id, and it's not primary key + if colName == "id" { + col.Name = "Id_RENAME" + } + if isNullable == "YES" { + tag.Null = true + } + if isSQLStringType(dataType) { + tag.Size = extractColSize(columnType) + } + if isSQLTemporalType(dataType) || strings.HasPrefix(dataType, "timestamp") { + tag.Type = dataType + //check auto_now, auto_now_add + if columnDefault == "CURRENT_TIMESTAMP" && extra == "on update CURRENT_TIMESTAMP" { + tag.AutoNow = true + } else if columnDefault == "CURRENT_TIMESTAMP" { + tag.AutoNowAdd = true + } + // need to import time package + table.ImportTimePkg = true + } + if isSQLDecimal(dataType) { + tag.Digits, tag.Decimals = extractDecimal(columnType) + } + if isSQLBinaryType(dataType) { + tag.Size = extractColSize(columnType) + } + } + } + col.Tag = tag + table.Columns = append(table.Columns, col) + } +} +func (*PostgresDB) GetGoDataType(sqlType string) (goType string) { + if v, ok := typeMappingPostgres[sqlType]; ok { + return v + } else { + ColorLog("[ERRO] data type (%s) not found!\n", sqlType) + os.Exit(2) + } + return goType +} + // deleteAndRecreatePaths removes several directories completely func createPaths(mode byte, paths *MvcPath) { if (mode & O_MODEL) == O_MODEL { - os.Mkdir(paths.ModelPath, 0777) + err := os.Mkdir(paths.ModelPath, 0777) + if err != nil { + ColorLog("[ERRO]", err) + os.Exit(2) + } } if (mode & O_CONTROLLER) == O_CONTROLLER { os.Mkdir(paths.ControllerPath, 0777) @@ -639,17 +888,6 @@ func camelCase(in string) string { return strings.Join(tokens, "") } -// getGoDataType maps an SQL data type to Golang data type -func getGoDataType(sqlType string) (goType string) { - if v, ok := typeMapping[sqlType]; ok { - return v - } else { - fmt.Println("Error:", sqlType, "not found!") - os.Exit(1) - } - return goType -} - func isSQLTemporalType(t string) bool { return t == "date" || t == "datetime" || t == "timestamp" }