From 49bbca0ce355ac018c654af1d7328684ceb2313f Mon Sep 17 00:00:00 2001 From: slene Date: Tue, 27 Aug 2013 12:33:27 +0800 Subject: [PATCH] orm Improve syncdb --- orm/cmd.go | 130 +++++++++++++++++++++++++++++++++++++++----- orm/cmd_utils.go | 131 +++++++++++++++++++++++++++------------------ orm/db.go | 58 ++++++++++++++++++++ orm/db_mysql.go | 21 ++++++++ orm/db_postgres.go | 16 ++++++ orm/db_sqlite.go | 46 ++++++++++++++++ orm/orm_test.go | 24 +-------- orm/types.go | 5 ++ 8 files changed, 344 insertions(+), 87 deletions(-) diff --git a/orm/cmd.go b/orm/cmd.go index 59f0833e..97545da4 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -9,7 +9,7 @@ import ( type commander interface { Parse([]string) - Run() + Run() error } var ( @@ -59,9 +59,11 @@ func RunCommand() { } type commandSyncDb struct { - al *alias - force bool - verbose bool + al *alias + force bool + verbose bool + noInfo bool + rtOnError bool } func (d *commandSyncDb) Parse(args []string) { @@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) { d.al = getDbAlias(name) } -func (d *commandSyncDb) Run() { +func (d *commandSyncDb) Run() error { var drops []string if d.force { drops = getDbDropSql(d.al) @@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() { if d.force { for i, mi := range modelCache.allOrdered() { query := drops[i] - _, err := db.Exec(query) - result := "" - if err != nil { - result = err.Error() + if !d.noInfo { + fmt.Printf("drop table `%s`\n", mi.table) } - fmt.Printf("drop table `%s` %s\n", mi.table, result) + _, err := db.Exec(query) if d.verbose { fmt.Printf(" %s\n\n", query) } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } } } sqls, indexes := getDbCreateSql(d.al) + tables, err := d.al.DbBaser.GetTables(db) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + for i, mi := range modelCache.allOrdered() { - fmt.Printf("create table `%s` \n", mi.table) + if tables[mi.table] { + if !d.noInfo { + fmt.Printf("table `%s` already exists, skip\n", mi.table) + } + + var fields []*fieldInfo + columns, err := d.al.DbBaser.GetColumns(db, mi.table) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for _, fi := range mi.fields.fieldsDB { + if _, ok := columns[fi.column]; ok == false { + fields = append(fields, fi) + } + } + + for _, fi := range fields { + query := getColumnAddQuery(d.al, fi) + + if !d.noInfo { + fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) + } + + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + + for _, idx := range indexes[mi.table] { + if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false { + if !d.noInfo { + fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) + } + + query := idx.Sql + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + continue + } + + if !d.noInfo { + fmt.Printf("create table `%s` \n", mi.table) + } queries := []string{sqls[i]} - queries = append(queries, indexes[mi.table]...) + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.Sql) + } for _, query := range queries { _, err := db.Exec(query) @@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() { fmt.Println(query) } if err != nil { + if d.rtOnError { + return err + } fmt.Printf(" %s\n", err.Error()) } } @@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() { fmt.Println("") } } + + return nil } type commandSqlAll struct { @@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) { d.al = getDbAlias(name) } -func (d *commandSqlAll) Run() { +func (d *commandSqlAll) Run() error { sqls, indexes := getDbCreateSql(d.al) var all []string for i, mi := range modelCache.allOrdered() { queries := []string{sqls[i]} - queries = append(queries, indexes[mi.table]...) + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.Sql) + } sql := strings.Join(queries, "\n") all = append(all, sql) } fmt.Println(strings.Join(all, "\n\n")) + + return nil } func init() { commands["syncdb"] = new(commandSyncDb) commands["sqlall"] = new(commandSqlAll) } + +func RunSyncdb(name string, force bool, verbose bool) error { + BootStrap() + + al := getDbAlias(name) + cmd := new(commandSyncDb) + cmd.al = al + cmd.force = force + cmd.noInfo = !verbose + cmd.verbose = verbose + cmd.rtOnError = true + return cmd.Run() +} diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 2b722619..9ddcbea9 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -6,6 +6,12 @@ import ( "strings" ) +type dbIndex struct { + Table string + Name string + Sql string +} + func getDbAlias(name string) *alias { if al, ok := dataBaseCache.get(name); ok { return al @@ -31,7 +37,71 @@ func getDbDropSql(al *alias) (sqls []string) { return sqls } -func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) { +func getColumnTyp(al *alias, fi *fieldInfo) (col string) { + T := al.DbBaser.DbTypes() + fieldType := fi.fieldType + +checkColumn: + switch fieldType { + case TypeBooleanField: + col = T["bool"] + case TypeCharField: + col = fmt.Sprintf(T["string"], fi.size) + case TypeTextField: + col = T["string-text"] + case TypeDateField: + col = T["time.Time-date"] + case TypeDateTimeField: + col = T["time.Time"] + case TypeBitField: + col = T["int8"] + case TypeSmallIntegerField: + col = T["int16"] + case TypeIntegerField: + col = T["int32"] + case TypeBigIntegerField: + if al.Driver == DR_Sqlite { + fieldType = TypeIntegerField + goto checkColumn + } + col = T["int64"] + case TypePositiveBitField: + col = T["uint8"] + case TypePositiveSmallIntegerField: + col = T["uint16"] + case TypePositiveIntegerField: + col = T["uint32"] + case TypePositiveBigIntegerField: + col = T["uint64"] + case TypeFloatField: + col = T["float64"] + case TypeDecimalField: + s := T["float64-decimal"] + if strings.Index(s, "%d") == -1 { + col = s + } else { + col = fmt.Sprintf(s, fi.digits, fi.decimals) + } + case RelForeignKey, RelOneToOne: + fieldType = fi.relModelInfo.fields.pk.fieldType + goto checkColumn + } + + return +} + +func getColumnAddQuery(al *alias, fi *fieldInfo) string { + Q := al.DbBaser.TableQuote() + typ := getColumnTyp(al, fi) + + if fi.null == false { + typ += " " + "NOT NULL" + } + + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) +} + +func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") os.Exit(2) @@ -41,7 +111,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) T := al.DbBaser.DbTypes() sep := fmt.Sprintf("%s, %s", Q, Q) - tableIndexes = make(map[string][]string) + tableIndexes = make(map[string][]dbIndex) for _, mi := range modelCache.allOrdered() { sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) @@ -56,55 +126,8 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) for _, fi := range mi.fields.fieldsDB { - fieldType := fi.fieldType column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := "" - - checkColumn: - switch fieldType { - case TypeBooleanField: - col = T["bool"] - case TypeCharField: - col = fmt.Sprintf(T["string"], fi.size) - case TypeTextField: - col = T["string-text"] - case TypeDateField: - col = T["time.Time-date"] - case TypeDateTimeField: - col = T["time.Time"] - case TypeBitField: - col = T["int8"] - case TypeSmallIntegerField: - col = T["int16"] - case TypeIntegerField: - col = T["int32"] - case TypeBigIntegerField: - if al.Driver == DR_Sqlite { - fieldType = TypeIntegerField - goto checkColumn - } - col = T["int64"] - case TypePositiveBitField: - col = T["uint8"] - case TypePositiveSmallIntegerField: - col = T["uint16"] - case TypePositiveIntegerField: - col = T["uint32"] - case TypePositiveBigIntegerField: - col = T["uint64"] - case TypeFloatField: - col = T["float64"] - case TypeDecimalField: - s := T["float64-decimal"] - if strings.Index(s, "%d") == -1 { - col = s - } else { - col = fmt.Sprintf(s, fi.digits, fi.decimals) - } - case RelForeignKey, RelOneToOne: - fieldType = fi.relModelInfo.fields.pk.fieldType - goto checkColumn - } + col := getColumnTyp(al, fi) if fi.auto { switch al.Driver { @@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) name := mi.table + "_" + strings.Join(names, "_") cols := strings.Join(names, sep) sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - tableIndexes[mi.table] = append(tableIndexes[mi.table], sql) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.Sql = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) } } diff --git a/orm/db.go b/orm/db.go index 7b30797f..433061f9 100644 --- a/orm/db.go +++ b/orm/db.go @@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { func (d *dbBase) DbTypes() map[string]string { return nil } + +func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { + tables := make(map[string]bool) + query := d.ins.ShowTablesQuery() + rows, err := db.Query(query) + if err != nil { + return tables, err + } + + for rows.Next() { + var table string + err := rows.Scan(&table) + if err != nil { + return tables, err + } + if table != "" { + tables[table] = true + } + } + + return tables, nil +} + +func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + columns := make(map[string][3]string) + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return columns, err + } + + for rows.Next() { + var ( + name string + typ string + null string + ) + err := rows.Scan(&name, &typ, &null) + if err != nil { + return columns, err + } + columns[name] = [3]string{name, typ, null} + } + + return columns, nil +} + +func (d *dbBase) ShowTablesQuery() string { + panic(ErrNotImplement) +} + +func (d *dbBase) ShowColumnsQuery(table string) string { + panic(ErrNotImplement) +} + +func (d *dbBase) IndexExists(dbQuerier, string, string) bool { + panic(ErrNotImplement) +} diff --git a/orm/db_mysql.go b/orm/db_mysql.go index 08a9b509..da123079 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -1,5 +1,9 @@ package orm +import ( + "fmt" +) + var mysqlOperators = map[string]string{ "exact": "= ?", "iexact": "LIKE ?", @@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string { return mysqlTypes } +func (d *dbBaseMysql) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +func (d *dbBaseMysql) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + func newdbBaseMysql() dbBaser { b := new(dbBaseMysql) b.ins = b diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 5e20a110..4058fc10 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) return } +func (d *dbBasePostgres) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" +} + +func (d *dbBasePostgres) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) +} + func (d *dbBasePostgres) DbTypes() map[string]string { return postgresTypes } +func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) + row := db.QueryRow(query) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) b.ins = b diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index 46588125..7711ded0 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -1,6 +1,7 @@ package orm import ( + "database/sql" "fmt" ) @@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string { return sqliteTypes } +func (d *dbBaseSqlite) ShowTablesQuery() string { + return "SELECT name FROM sqlite_master WHERE type = 'table'" +} + +func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return nil, err + } + + columns := make(map[string][3]string) + for rows.Next() { + var tmp, name, typ, null sql.NullString + err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) + if err != nil { + return nil, err + } + columns[name.String] = [3]string{name.String, typ.String, null.String} + } + + return columns, nil +} + +func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { + return fmt.Sprintf("pragma table_info('%s')", table) +} + +func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("PRAGMA index_list('%s')", table) + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() + for rows.Next() { + var tmp, index sql.NullString + rows.Scan(&tmp, &index, &tmp) + if name == index.String { + return true + } + } + return false +} + func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) b.ins = b diff --git a/orm/orm_test.go b/orm/orm_test.go index 278c30ee..bbca3831 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Comment)) RegisterModel(new(UserBig)) - BootStrap() - - al := dataBaseCache.getDefault() - db := al.DB - - drops := getDbDropSql(al) - for _, query := range drops { - _, err := db.Exec(query) - throwFail(t, err, query) - } - - sqls, indexes := getDbCreateSql(al) - - for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} - queries = append(queries, indexes[mi.table]...) - - for _, query := range queries { - _, err := db.Exec(query) - throwFail(t, err, query) - } - } + err := RunSyncdb("default", true, false) + throwFail(t, err) modelCache.clean() } diff --git a/orm/types.go b/orm/types.go index b4fd5c84..ce25d037 100644 --- a/orm/types.go +++ b/orm/types.go @@ -133,4 +133,9 @@ type dbBaser interface { TimeFromDB(*time.Time, *time.Location) TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string + GetTables(dbQuerier) (map[string]bool, error) + GetColumns(dbQuerier, string) (map[string][3]string, error) + ShowTablesQuery() string + ShowColumnsQuery(string) string + IndexExists(dbQuerier, string, string) bool }