1
0
mirror of https://github.com/beego/bee.git synced 2025-01-10 11:47:13 +00:00

code gen for postgres

This commit is contained in:
ZhengYang 2014-08-22 15:50:13 +08:00
parent 4dc2d67bd2
commit b14868c908

@ -34,6 +34,28 @@ const (
O_ROUTER
)
// DbTransformer has method to reverse engineer a database schema to restful api code
type DbTransformer interface {
GetTableNames(conn *sql.DB) []string
GetConstraints(conn *sql.DB, table *Table, blackList map[string]bool)
GetColumns(conn *sql.DB, table *Table, blackList map[string]bool)
GetGoDataType(sqlType string) string
}
// MysqlDB is the MySQL version of DbTransformer
type MysqlDB struct {
}
// PostgresDB is the PostgreSQL version of DbTransformer
type PostgresDB struct {
}
// dbDriver maps a DBMS name to its version of DbTransformer
var dbDriver = map[string]DbTransformer{
"mysql": &MysqlDB{},
"postgres": &PostgresDB{},
}
type MvcPath struct {
ModelPath string
ControllerPath string
@ -77,26 +99,30 @@ var typeMappingMysql = map[string]string{
// typeMappingPostgres maps SQL data type to corresponding Go data type
var typeMappingPostgres = map[string]string{
"serial": "int", // serial
"big serial": "int64",
"smallint": "int16", // int
"integer": "int",
"bigint": "int64",
"boolean": "bool", // bool
"char": "string", // string
"character": "string",
"character varying": "string",
"varchar": "string",
"text": "string",
"date": "time.Time", // time
"time": "time.Time",
"timestamp": "time.Time",
"real": "float32", // float & decimal
"double precision": "float64",
"decimal": "float64",
"numeric": "float64",
"money": "float64", // money
"bytea": "string", // binary
"serial": "int", // serial
"big serial": "int64",
"smallint": "int16", // int
"integer": "int",
"bigint": "int64",
"boolean": "bool", // bool
"char": "string", // string
"character": "string",
"character varying": "string",
"varchar": "string",
"text": "string",
"date": "time.Time", // time
"time": "time.Time",
"timestamp": "time.Time",
"timestamp without time zone": "time.Time",
"real": "float32", // float & decimal
"double precision": "float64",
"decimal": "float64",
"numeric": "float64",
"money": "float64", // money
"bytea": "string", // binary
"tsvector": "string", // fulltext
"ARRAY": "string", // array
"USER-DEFINED": "string", // user defined
}
// Table represent a table in a database
@ -264,20 +290,25 @@ func gen(dbms, connStr string, mode byte, selectedTableNames map[string]bool, cu
os.Exit(2)
}
defer db.Close()
ColorLog("[INFO] Analyzing database tables...\n")
tableNames := getTableNames(db)
tables := getTableObjects(tableNames, db)
mvcPath := new(MvcPath)
mvcPath.ModelPath = path.Join(currpath, "models")
mvcPath.ControllerPath = path.Join(currpath, "controllers")
mvcPath.RouterPath = path.Join(currpath, "routers")
createPaths(mode, mvcPath)
pkgPath := getPackagePath(currpath)
writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames)
if trans, ok := dbDriver[dbms]; ok {
ColorLog("[INFO] Analyzing database tables...\n")
tableNames := trans.GetTableNames(db)
tables := getTableObjects(tableNames, db, trans)
mvcPath := new(MvcPath)
mvcPath.ModelPath = path.Join(currpath, "models")
mvcPath.ControllerPath = path.Join(currpath, "controllers")
mvcPath.RouterPath = path.Join(currpath, "routers")
createPaths(mode, mvcPath)
pkgPath := getPackagePath(currpath)
writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames)
} else {
ColorLog("[ERRO] Generating app code from %s database is not supported yet.\n", dbms)
os.Exit(2)
}
}
// getTables gets a list table names in current database
func getTableNames(db *sql.DB) (tables []string) {
func (*MysqlDB) GetTableNames(db *sql.DB) (tables []string) {
rows, err := db.Query("SHOW TABLES")
if err != nil {
ColorLog("[ERRO] Could not show tables\n")
@ -297,7 +328,7 @@ func getTableNames(db *sql.DB) (tables []string) {
}
// getTableObjects process each table name
func getTableObjects(tableNames []string, db *sql.DB) (tables []*Table) {
func getTableObjects(tableNames []string, db *sql.DB, dbTransformer DbTransformer) (tables []*Table) {
// if a table has a composite pk or doesn't have pk, we can't use it yet
// these tables will be put into blacklist so that other struct will not
// reference it.
@ -308,19 +339,19 @@ func getTableObjects(tableNames []string, db *sql.DB) (tables []*Table) {
tb := new(Table)
tb.Name = tableName
tb.Fk = make(map[string]*ForeignKey)
getConstraints(db, tb, blackList)
dbTransformer.GetConstraints(db, tb, blackList)
tables = append(tables, tb)
}
// process columns, ignoring blacklisted tables
for _, tb := range tables {
getColumns(db, tb, blackList)
dbTransformer.GetColumns(db, tb, blackList)
}
return
}
// getConstraints gets primary key, unique key and foreign keys of a table from information_schema
// and fill in Table struct
func getConstraints(db *sql.DB, table *Table, blackList map[string]bool) {
func (*MysqlDB) GetConstraints(db *sql.DB, table *Table, blackList map[string]bool) {
rows, err := db.Query(
`SELECT
c.constraint_type, u.column_name, u.referenced_table_schema, u.referenced_table_name, referenced_column_name, u.ordinal_position
@ -368,7 +399,7 @@ func getConstraints(db *sql.DB, table *Table, blackList map[string]bool) {
// getColumns retrieve columns details from information_schema
// and fill in the Column struct
func getColumns(db *sql.DB, table *Table, blackList map[string]bool) {
func (mysqlDB *MysqlDB) GetColumns(db *sql.DB, table *Table, blackList map[string]bool) {
// retrieve columns
colDefRows, _ := db.Query(
`SELECT
@ -391,7 +422,7 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) {
// create a column
col := new(Column)
col.Name = camelCase(colName)
col.Type = getGoDataType(dataType)
col.Type = mysqlDB.GetGoDataType(dataType)
// Tag info
tag := new(OrmTag)
tag.Column = colName
@ -426,7 +457,7 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) {
if isSQLSignedIntType(dataType) {
sign := extractIntSignness(columnType)
if sign == "unsigned" && extra != "auto_increment" {
col.Type = getGoDataType(dataType + " " + sign)
col.Type = mysqlDB.GetGoDataType(dataType + " " + sign)
}
}
if isSQLStringType(dataType) {
@ -456,10 +487,204 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) {
}
}
// getGoDataType maps an SQL data type to Golang data type
func (*MysqlDB) GetGoDataType(sqlType string) (goType string) {
if v, ok := typeMappingMysql[sqlType]; ok {
return v
} else {
ColorLog("[ERRO] data type (%s) not found!\n", sqlType)
os.Exit(2)
}
return goType
}
// GetTableNames for PostgreSQL
func (*PostgresDB) GetTableNames(db *sql.DB) (tables []string) {
rows, err := db.Query(`
SELECT table_name FROM information_schema.tables
WHERE table_catalog = current_database() and table_schema = 'public'`)
if err != nil {
ColorLog("[ERRO] Could not show tables: %s\n", err)
ColorLog("[HINT] Check your connection string\n")
os.Exit(2)
}
defer rows.Close()
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
ColorLog("[ERRO] Could not show tables\n")
os.Exit(2)
}
tables = append(tables, name)
}
return
}
// GetConstraints for PostgreSQL
func (*PostgresDB) GetConstraints(db *sql.DB, table *Table, blackList map[string]bool) {
rows, err := db.Query(
`SELECT
c.constraint_type,
u.column_name,
cu.table_catalog AS referenced_table_catalog,
cu.table_name AS referenced_table_name,
cu.column_name AS referenced_column_name,
u.ordinal_position
FROM
information_schema.table_constraints c
INNER JOIN
information_schema.key_column_usage u ON c.constraint_name = u.constraint_name
INNER JOIN
information_schema.constraint_column_usage cu ON cu.constraint_name = c.constraint_name
WHERE
c.table_catalog = current_database() AND c.table_schema = 'public' AND c.table_name = $1
AND u.table_catalog = current_database() AND u.table_schema = 'public' AND u.table_name = $2`,
table.Name, table.Name) // u.position_in_unique_constraint,
if err != nil {
ColorLog("[ERRO] Could not query INFORMATION_SCHEMA for PK/UK/FK information: %s\n", err)
os.Exit(2)
}
for rows.Next() {
var constraintTypeBytes, columnNameBytes, refTableSchemaBytes, refTableNameBytes, refColumnNameBytes, refOrdinalPosBytes []byte
if err := rows.Scan(&constraintTypeBytes, &columnNameBytes, &refTableSchemaBytes, &refTableNameBytes, &refColumnNameBytes, &refOrdinalPosBytes); err != nil {
ColorLog("[ERRO] Could not read INFORMATION_SCHEMA for PK/UK/FK information\n")
os.Exit(2)
}
constraintType, columnName, refTableSchema, refTableName, refColumnName, refOrdinalPos :=
string(constraintTypeBytes), string(columnNameBytes), string(refTableSchemaBytes),
string(refTableNameBytes), string(refColumnNameBytes), string(refOrdinalPosBytes)
if constraintType == "PRIMARY KEY" {
if refOrdinalPos == "1" {
table.Pk = columnName
} else {
table.Pk = ""
// add table to blacklist so that other struct will not reference it, because we are not
// registering blacklisted tables
blackList[table.Name] = true
}
} else if constraintType == "UNIQUE" {
table.Uk = append(table.Uk, columnName)
} else if constraintType == "FOREIGN KEY" {
fk := new(ForeignKey)
fk.Name = columnName
fk.RefSchema = refTableSchema
fk.RefTable = refTableName
fk.RefColumn = refColumnName
table.Fk[columnName] = fk
}
}
}
// GetColumns for PostgreSQL
func (postgresDB *PostgresDB) GetColumns(db *sql.DB, table *Table, blackList map[string]bool) {
// retrieve columns
colDefRows, _ := db.Query(
`SELECT
column_name,
data_type,
data_type ||
CASE
WHEN data_type = 'character' THEN '('||character_maximum_length||')'
WHEN data_type = 'numeric' THEN '(' || numeric_precision || ',' || numeric_scale ||')'
ELSE ''
END AS column_type,
is_nullable,
column_default,
'' AS extra
FROM
information_schema.columns
WHERE
table_catalog = current_database() AND table_schema = 'public' AND table_name = $1`,
table.Name)
defer colDefRows.Close()
for colDefRows.Next() {
// datatype as bytes so that SQL <null> values can be retrieved
var colNameBytes, dataTypeBytes, columnTypeBytes, isNullableBytes, columnDefaultBytes, extraBytes []byte
if err := colDefRows.Scan(&colNameBytes, &dataTypeBytes, &columnTypeBytes, &isNullableBytes, &columnDefaultBytes, &extraBytes); err != nil {
ColorLog("[ERRO] Could not query INFORMATION_SCHEMA for column information\n")
os.Exit(2)
}
colName, dataType, columnType, isNullable, columnDefault, extra :=
string(colNameBytes), string(dataTypeBytes), string(columnTypeBytes), string(isNullableBytes), string(columnDefaultBytes), string(extraBytes)
// create a column
col := new(Column)
col.Name = camelCase(colName)
col.Type = postgresDB.GetGoDataType(dataType)
// Tag info
tag := new(OrmTag)
tag.Column = colName
if table.Pk == colName {
col.Name = "Id"
col.Type = "int"
if extra == "auto_increment" {
tag.Auto = true
} else {
tag.Pk = true
}
} else {
fkCol, isFk := table.Fk[colName]
isBl := false
if isFk {
_, isBl = blackList[fkCol.RefTable]
}
// check if the current column is a foreign key
if isFk && !isBl {
tag.RelFk = true
refStructName := fkCol.RefTable
col.Name = camelCase(colName)
col.Type = "*" + camelCase(refStructName)
} else {
// if the name of column is Id, and it's not primary key
if colName == "id" {
col.Name = "Id_RENAME"
}
if isNullable == "YES" {
tag.Null = true
}
if isSQLStringType(dataType) {
tag.Size = extractColSize(columnType)
}
if isSQLTemporalType(dataType) || strings.HasPrefix(dataType, "timestamp") {
tag.Type = dataType
//check auto_now, auto_now_add
if columnDefault == "CURRENT_TIMESTAMP" && extra == "on update CURRENT_TIMESTAMP" {
tag.AutoNow = true
} else if columnDefault == "CURRENT_TIMESTAMP" {
tag.AutoNowAdd = true
}
// need to import time package
table.ImportTimePkg = true
}
if isSQLDecimal(dataType) {
tag.Digits, tag.Decimals = extractDecimal(columnType)
}
if isSQLBinaryType(dataType) {
tag.Size = extractColSize(columnType)
}
}
}
col.Tag = tag
table.Columns = append(table.Columns, col)
}
}
func (*PostgresDB) GetGoDataType(sqlType string) (goType string) {
if v, ok := typeMappingPostgres[sqlType]; ok {
return v
} else {
ColorLog("[ERRO] data type (%s) not found!\n", sqlType)
os.Exit(2)
}
return goType
}
// deleteAndRecreatePaths removes several directories completely
func createPaths(mode byte, paths *MvcPath) {
if (mode & O_MODEL) == O_MODEL {
os.Mkdir(paths.ModelPath, 0777)
err := os.Mkdir(paths.ModelPath, 0777)
if err != nil {
ColorLog("[ERRO]", err)
os.Exit(2)
}
}
if (mode & O_CONTROLLER) == O_CONTROLLER {
os.Mkdir(paths.ControllerPath, 0777)
@ -663,17 +888,6 @@ func camelCase(in string) string {
return strings.Join(tokens, "")
}
// getGoDataType maps an SQL data type to Golang data type
func getGoDataType(sqlType string) (goType string) {
if v, ok := typeMappingMysql[sqlType]; ok {
return v
} else {
fmt.Println("Error:", sqlType, "not found!")
os.Exit(1)
}
return goType
}
func isSQLTemporalType(t string) bool {
return t == "date" || t == "datetime" || t == "timestamp"
}