diff --git a/orm/cmd.go b/orm/cmd.go index 174091f4..59f0833e 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -99,18 +99,23 @@ func (d *commandSyncDb) Run() { } } - tables := getDbCreateSql(d.al) + sqls, indexes := getDbCreateSql(d.al) for i, mi := range modelCache.allOrdered() { - query := tables[i] - _, err := db.Exec(query) fmt.Printf("create table `%s` \n", mi.table) - if d.verbose { - query = " " + strings.Join(strings.Split(query, "\n"), "\n ") - fmt.Println(query) - } - if err != nil { - fmt.Printf(" %s\n", err.Error()) + + queries := []string{sqls[i]} + queries = append(queries, indexes[mi.table]...) + + for _, query := range queries { + _, err := db.Exec(query) + if d.verbose { + query = " " + strings.Join(strings.Split(query, "\n"), "\n ") + fmt.Println(query) + } + if err != nil { + fmt.Printf(" %s\n", err.Error()) + } } if d.verbose { fmt.Println("") @@ -133,9 +138,15 @@ func (d *commandSqlAll) Parse(args []string) { } func (d *commandSqlAll) Run() { - sqls := getDbCreateSql(d.al) - sql := strings.Join(sqls, "\n\n") - fmt.Println(sql) + sqls, indexes := getDbCreateSql(d.al) + var all []string + for i, mi := range modelCache.allOrdered() { + queries := []string{sqls[i]} + queries = append(queries, indexes[mi.table]...) + sql := strings.Join(queries, "\n") + all = append(all, sql) + } + fmt.Println(strings.Join(all, "\n\n")) } func init() { diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 120da7c5..61cd94b1 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -31,7 +31,7 @@ func getDbDropSql(al *alias) (sqls []string) { return sqls } -func getDbCreateSql(al *alias) (sqls []string) { +func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") os.Exit(2) @@ -41,6 +41,8 @@ func getDbCreateSql(al *alias) (sqls []string) { T := al.DbBaser.DbTypes() sep := fmt.Sprintf("%s, %s", Q, Q) + tableIndexes = make(map[string][]string) + for _, mi := range modelCache.allOrdered() { sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) @@ -125,7 +127,7 @@ func getDbCreateSql(al *alias) (sqls []string) { } if fi.index { - sqlIndexes = append(sqlIndexes, []string{column}) + sqlIndexes = append(sqlIndexes, []string{fi.column}) } } @@ -179,10 +181,10 @@ func getDbCreateSql(al *alias) (sqls []string) { name := strings.Join(names, "_") cols := strings.Join(names, sep) sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - sqls = append(sqls, sql) + tableIndexes[mi.table] = append(tableIndexes[mi.table], sql) } } - return sqls + return } diff --git a/orm/models.go b/orm/models.go index 98d0aba3..277e5d0f 100644 --- a/orm/models.go +++ b/orm/models.go @@ -58,8 +58,8 @@ func (mc *_modelCache) all() map[string]*modelInfo { func (mc *_modelCache) allOrdered() []*modelInfo { m := make([]*modelInfo, 0, len(mc.orders)) - for _, v := range mc.cache { - m = append(m, v) + for _, table := range mc.orders { + m = append(m, mc.cache[table]) } return m } diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 7316b612..29ac3700 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -344,13 +344,6 @@ checkType: err = fmt.Errorf("non-integer type cannot set auto") goto end } - - if fi.pk || fi.index || fi.unique { - if fieldType != TypeCharField && fieldType != RelOneToOne { - err = fmt.Errorf("cannot set pk/index/unique") - goto end - } - } } if fi.auto || fi.pk { diff --git a/orm/orm_test.go b/orm/orm_test.go index 1eb8764c..278c30ee 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -206,13 +206,19 @@ func TestSyncDb(t *testing.T) { drops := getDbDropSql(al) for _, query := range drops { _, err := db.Exec(query) - throwFailNow(t, err, query) + throwFail(t, err, query) } - tables := getDbCreateSql(al) - for _, query := range tables { - _, err := db.Exec(query) - throwFailNow(t, err, query) + sqls, indexes := getDbCreateSql(al) + + for i, mi := range modelCache.allOrdered() { + queries := []string{sqls[i]} + queries = append(queries, indexes[mi.table]...) + + for _, query := range queries { + _, err := db.Exec(query) + throwFail(t, err, query) + } } modelCache.clean()