From b14868c908e153ef9e2318d9a5b40d9d3bf7c355 Mon Sep 17 00:00:00 2001 From: ZhengYang Date: Fri, 22 Aug 2014 15:50:13 +0800 Subject: [PATCH] code gen for postgres --- g_appcode.go | 314 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 264 insertions(+), 50 deletions(-) diff --git a/g_appcode.go b/g_appcode.go index 107be78..fc5ce83 100644 --- a/g_appcode.go +++ b/g_appcode.go @@ -34,6 +34,28 @@ const ( O_ROUTER ) +// DbTransformer has method to reverse engineer a database schema to restful api code +type DbTransformer interface { + GetTableNames(conn *sql.DB) []string + GetConstraints(conn *sql.DB, table *Table, blackList map[string]bool) + GetColumns(conn *sql.DB, table *Table, blackList map[string]bool) + GetGoDataType(sqlType string) string +} + +// MysqlDB is the MySQL version of DbTransformer +type MysqlDB struct { +} + +// PostgresDB is the PostgreSQL version of DbTransformer +type PostgresDB struct { +} + +// dbDriver maps a DBMS name to its version of DbTransformer +var dbDriver = map[string]DbTransformer{ + "mysql": &MysqlDB{}, + "postgres": &PostgresDB{}, +} + type MvcPath struct { ModelPath string ControllerPath string @@ -77,26 +99,30 @@ var typeMappingMysql = map[string]string{ // typeMappingPostgres maps SQL data type to corresponding Go data type var typeMappingPostgres = map[string]string{ - "serial": "int", // serial - "big serial": "int64", - "smallint": "int16", // int - "integer": "int", - "bigint": "int64", - "boolean": "bool", // bool - "char": "string", // string - "character": "string", - "character varying": "string", - "varchar": "string", - "text": "string", - "date": "time.Time", // time - "time": "time.Time", - "timestamp": "time.Time", - "real": "float32", // float & decimal - "double precision": "float64", - "decimal": "float64", - "numeric": "float64", - "money": "float64", // money - "bytea": "string", // binary + "serial": "int", // serial + "big serial": "int64", + "smallint": "int16", // int + "integer": "int", + "bigint": "int64", + "boolean": "bool", // bool + "char": "string", // string + "character": "string", + "character varying": "string", + "varchar": "string", + "text": "string", + "date": "time.Time", // time + "time": "time.Time", + "timestamp": "time.Time", + "timestamp without time zone": "time.Time", + "real": "float32", // float & decimal + "double precision": "float64", + "decimal": "float64", + "numeric": "float64", + "money": "float64", // money + "bytea": "string", // binary + "tsvector": "string", // fulltext + "ARRAY": "string", // array + "USER-DEFINED": "string", // user defined } // Table represent a table in a database @@ -264,20 +290,25 @@ func gen(dbms, connStr string, mode byte, selectedTableNames map[string]bool, cu os.Exit(2) } defer db.Close() - ColorLog("[INFO] Analyzing database tables...\n") - tableNames := getTableNames(db) - tables := getTableObjects(tableNames, db) - mvcPath := new(MvcPath) - mvcPath.ModelPath = path.Join(currpath, "models") - mvcPath.ControllerPath = path.Join(currpath, "controllers") - mvcPath.RouterPath = path.Join(currpath, "routers") - createPaths(mode, mvcPath) - pkgPath := getPackagePath(currpath) - writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames) + if trans, ok := dbDriver[dbms]; ok { + ColorLog("[INFO] Analyzing database tables...\n") + tableNames := trans.GetTableNames(db) + tables := getTableObjects(tableNames, db, trans) + mvcPath := new(MvcPath) + mvcPath.ModelPath = path.Join(currpath, "models") + mvcPath.ControllerPath = path.Join(currpath, "controllers") + mvcPath.RouterPath = path.Join(currpath, "routers") + createPaths(mode, mvcPath) + pkgPath := getPackagePath(currpath) + writeSourceFiles(pkgPath, tables, mode, mvcPath, selectedTableNames) + } else { + ColorLog("[ERRO] Generating app code from %s database is not supported yet.\n", dbms) + os.Exit(2) + } } // getTables gets a list table names in current database -func getTableNames(db *sql.DB) (tables []string) { +func (*MysqlDB) GetTableNames(db *sql.DB) (tables []string) { rows, err := db.Query("SHOW TABLES") if err != nil { ColorLog("[ERRO] Could not show tables\n") @@ -297,7 +328,7 @@ func getTableNames(db *sql.DB) (tables []string) { } // getTableObjects process each table name -func getTableObjects(tableNames []string, db *sql.DB) (tables []*Table) { +func getTableObjects(tableNames []string, db *sql.DB, dbTransformer DbTransformer) (tables []*Table) { // if a table has a composite pk or doesn't have pk, we can't use it yet // these tables will be put into blacklist so that other struct will not // reference it. @@ -308,19 +339,19 @@ func getTableObjects(tableNames []string, db *sql.DB) (tables []*Table) { tb := new(Table) tb.Name = tableName tb.Fk = make(map[string]*ForeignKey) - getConstraints(db, tb, blackList) + dbTransformer.GetConstraints(db, tb, blackList) tables = append(tables, tb) } // process columns, ignoring blacklisted tables for _, tb := range tables { - getColumns(db, tb, blackList) + dbTransformer.GetColumns(db, tb, blackList) } return } // getConstraints gets primary key, unique key and foreign keys of a table from information_schema // and fill in Table struct -func getConstraints(db *sql.DB, table *Table, blackList map[string]bool) { +func (*MysqlDB) GetConstraints(db *sql.DB, table *Table, blackList map[string]bool) { rows, err := db.Query( `SELECT c.constraint_type, u.column_name, u.referenced_table_schema, u.referenced_table_name, referenced_column_name, u.ordinal_position @@ -368,7 +399,7 @@ func getConstraints(db *sql.DB, table *Table, blackList map[string]bool) { // getColumns retrieve columns details from information_schema // and fill in the Column struct -func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { +func (mysqlDB *MysqlDB) GetColumns(db *sql.DB, table *Table, blackList map[string]bool) { // retrieve columns colDefRows, _ := db.Query( `SELECT @@ -391,7 +422,7 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { // create a column col := new(Column) col.Name = camelCase(colName) - col.Type = getGoDataType(dataType) + col.Type = mysqlDB.GetGoDataType(dataType) // Tag info tag := new(OrmTag) tag.Column = colName @@ -426,7 +457,7 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { if isSQLSignedIntType(dataType) { sign := extractIntSignness(columnType) if sign == "unsigned" && extra != "auto_increment" { - col.Type = getGoDataType(dataType + " " + sign) + col.Type = mysqlDB.GetGoDataType(dataType + " " + sign) } } if isSQLStringType(dataType) { @@ -456,10 +487,204 @@ func getColumns(db *sql.DB, table *Table, blackList map[string]bool) { } } +// getGoDataType maps an SQL data type to Golang data type +func (*MysqlDB) GetGoDataType(sqlType string) (goType string) { + if v, ok := typeMappingMysql[sqlType]; ok { + return v + } else { + ColorLog("[ERRO] data type (%s) not found!\n", sqlType) + os.Exit(2) + } + return goType +} + +// GetTableNames for PostgreSQL +func (*PostgresDB) GetTableNames(db *sql.DB) (tables []string) { + rows, err := db.Query(` + SELECT table_name FROM information_schema.tables + WHERE table_catalog = current_database() and table_schema = 'public'`) + if err != nil { + ColorLog("[ERRO] Could not show tables: %s\n", err) + ColorLog("[HINT] Check your connection string\n") + os.Exit(2) + } + defer rows.Close() + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + ColorLog("[ERRO] Could not show tables\n") + os.Exit(2) + } + tables = append(tables, name) + } + return +} + +// GetConstraints for PostgreSQL +func (*PostgresDB) GetConstraints(db *sql.DB, table *Table, blackList map[string]bool) { + rows, err := db.Query( + `SELECT + c.constraint_type, + u.column_name, + cu.table_catalog AS referenced_table_catalog, + cu.table_name AS referenced_table_name, + cu.column_name AS referenced_column_name, + u.ordinal_position + FROM + information_schema.table_constraints c + INNER JOIN + information_schema.key_column_usage u ON c.constraint_name = u.constraint_name + INNER JOIN + information_schema.constraint_column_usage cu ON cu.constraint_name = c.constraint_name + WHERE + c.table_catalog = current_database() AND c.table_schema = 'public' AND c.table_name = $1 + AND u.table_catalog = current_database() AND u.table_schema = 'public' AND u.table_name = $2`, + table.Name, table.Name) // u.position_in_unique_constraint, + if err != nil { + ColorLog("[ERRO] Could not query INFORMATION_SCHEMA for PK/UK/FK information: %s\n", err) + os.Exit(2) + } + for rows.Next() { + var constraintTypeBytes, columnNameBytes, refTableSchemaBytes, refTableNameBytes, refColumnNameBytes, refOrdinalPosBytes []byte + if err := rows.Scan(&constraintTypeBytes, &columnNameBytes, &refTableSchemaBytes, &refTableNameBytes, &refColumnNameBytes, &refOrdinalPosBytes); err != nil { + ColorLog("[ERRO] Could not read INFORMATION_SCHEMA for PK/UK/FK information\n") + os.Exit(2) + } + constraintType, columnName, refTableSchema, refTableName, refColumnName, refOrdinalPos := + string(constraintTypeBytes), string(columnNameBytes), string(refTableSchemaBytes), + string(refTableNameBytes), string(refColumnNameBytes), string(refOrdinalPosBytes) + if constraintType == "PRIMARY KEY" { + if refOrdinalPos == "1" { + table.Pk = columnName + } else { + table.Pk = "" + // add table to blacklist so that other struct will not reference it, because we are not + // registering blacklisted tables + blackList[table.Name] = true + } + } else if constraintType == "UNIQUE" { + table.Uk = append(table.Uk, columnName) + } else if constraintType == "FOREIGN KEY" { + fk := new(ForeignKey) + fk.Name = columnName + fk.RefSchema = refTableSchema + fk.RefTable = refTableName + fk.RefColumn = refColumnName + table.Fk[columnName] = fk + } + } +} + +// GetColumns for PostgreSQL +func (postgresDB *PostgresDB) GetColumns(db *sql.DB, table *Table, blackList map[string]bool) { + // retrieve columns + colDefRows, _ := db.Query( + `SELECT + column_name, + data_type, + data_type || + CASE + WHEN data_type = 'character' THEN '('||character_maximum_length||')' + WHEN data_type = 'numeric' THEN '(' || numeric_precision || ',' || numeric_scale ||')' + ELSE '' + END AS column_type, + is_nullable, + column_default, + '' AS extra + FROM + information_schema.columns + WHERE + table_catalog = current_database() AND table_schema = 'public' AND table_name = $1`, + table.Name) + defer colDefRows.Close() + for colDefRows.Next() { + // datatype as bytes so that SQL values can be retrieved + var colNameBytes, dataTypeBytes, columnTypeBytes, isNullableBytes, columnDefaultBytes, extraBytes []byte + if err := colDefRows.Scan(&colNameBytes, &dataTypeBytes, &columnTypeBytes, &isNullableBytes, &columnDefaultBytes, &extraBytes); err != nil { + ColorLog("[ERRO] Could not query INFORMATION_SCHEMA for column information\n") + os.Exit(2) + } + colName, dataType, columnType, isNullable, columnDefault, extra := + string(colNameBytes), string(dataTypeBytes), string(columnTypeBytes), string(isNullableBytes), string(columnDefaultBytes), string(extraBytes) + // create a column + col := new(Column) + col.Name = camelCase(colName) + col.Type = postgresDB.GetGoDataType(dataType) + // Tag info + tag := new(OrmTag) + tag.Column = colName + if table.Pk == colName { + col.Name = "Id" + col.Type = "int" + if extra == "auto_increment" { + tag.Auto = true + } else { + tag.Pk = true + } + } else { + fkCol, isFk := table.Fk[colName] + isBl := false + if isFk { + _, isBl = blackList[fkCol.RefTable] + } + // check if the current column is a foreign key + if isFk && !isBl { + tag.RelFk = true + refStructName := fkCol.RefTable + col.Name = camelCase(colName) + col.Type = "*" + camelCase(refStructName) + } else { + // if the name of column is Id, and it's not primary key + if colName == "id" { + col.Name = "Id_RENAME" + } + if isNullable == "YES" { + tag.Null = true + } + if isSQLStringType(dataType) { + tag.Size = extractColSize(columnType) + } + if isSQLTemporalType(dataType) || strings.HasPrefix(dataType, "timestamp") { + tag.Type = dataType + //check auto_now, auto_now_add + if columnDefault == "CURRENT_TIMESTAMP" && extra == "on update CURRENT_TIMESTAMP" { + tag.AutoNow = true + } else if columnDefault == "CURRENT_TIMESTAMP" { + tag.AutoNowAdd = true + } + // need to import time package + table.ImportTimePkg = true + } + if isSQLDecimal(dataType) { + tag.Digits, tag.Decimals = extractDecimal(columnType) + } + if isSQLBinaryType(dataType) { + tag.Size = extractColSize(columnType) + } + } + } + col.Tag = tag + table.Columns = append(table.Columns, col) + } +} +func (*PostgresDB) GetGoDataType(sqlType string) (goType string) { + if v, ok := typeMappingPostgres[sqlType]; ok { + return v + } else { + ColorLog("[ERRO] data type (%s) not found!\n", sqlType) + os.Exit(2) + } + return goType +} + // deleteAndRecreatePaths removes several directories completely func createPaths(mode byte, paths *MvcPath) { if (mode & O_MODEL) == O_MODEL { - os.Mkdir(paths.ModelPath, 0777) + err := os.Mkdir(paths.ModelPath, 0777) + if err != nil { + ColorLog("[ERRO]", err) + os.Exit(2) + } } if (mode & O_CONTROLLER) == O_CONTROLLER { os.Mkdir(paths.ControllerPath, 0777) @@ -663,17 +888,6 @@ func camelCase(in string) string { return strings.Join(tokens, "") } -// getGoDataType maps an SQL data type to Golang data type -func getGoDataType(sqlType string) (goType string) { - if v, ok := typeMappingMysql[sqlType]; ok { - return v - } else { - fmt.Println("Error:", sqlType, "not found!") - os.Exit(1) - } - return goType -} - func isSQLTemporalType(t string) bool { return t == "date" || t == "datetime" || t == "timestamp" }