From 4c061feddf0611b10f5fea3a39026b871e1c242f Mon Sep 17 00:00:00 2001 From: slene Date: Thu, 22 Aug 2013 21:19:58 +0800 Subject: [PATCH] orm support custom multi unique / index --- orm/cmd_utils.go | 45 ++++++++++++++++++++++++++++++++++++++++++++ orm/models_info_m.go | 2 +- orm/models_test.go | 12 ++++++++++++ orm/models_utils.go | 32 +++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 1 deletion(-) diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 602d102a..120da7c5 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -39,6 +39,7 @@ func getDbCreateSql(al *alias) (sqls []string) { Q := al.DbBaser.TableQuote() T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) for _, mi := range modelCache.allOrdered() { sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) @@ -49,6 +50,8 @@ func getDbCreateSql(al *alias) (sqls []string) { columns := make([]string, 0, len(mi.fields.fieldsDB)) + sqlIndexes := [][]string{} + for _, fi := range mi.fields.fieldsDB { fieldType := fi.fieldType @@ -120,6 +123,10 @@ func getDbCreateSql(al *alias) (sqls []string) { if fi.unique { column += " " + "UNIQUE" } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{column}) + } } if strings.Index(column, "%COL%") != -1 { @@ -129,6 +136,21 @@ func getDbCreateSql(al *alias) (sqls []string) { columns = append(columns, column) } + if mi.model != nil { + for _, names := range getTableUnique(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + sql += strings.Join(columns, ",\n") sql += "\n)" @@ -136,7 +158,30 @@ func getDbCreateSql(al *alias) (sqls []string) { sql += " ENGINE=INNODB" } + sql += ";" sqls = append(sqls, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + sqls = append(sqls, sql) + } + } return sqls diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 31e849b2..5ede53dd 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -31,7 +31,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) { ind := reflect.Indirect(val) typ := ind.Type() - info.addrField = ind.Addr() + info.addrField = val info.name = typ.Name() info.fullName = getFullName(typ) diff --git a/orm/models_test.go b/orm/models_test.go index 9aafdf27..2367e5f0 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -78,6 +78,18 @@ type User struct { ShouldSkip string `orm:"-"` } +func (u *User) TableIndex() [][]string { + return [][]string{ + []string{"Id", "UserName"}, + } +} + +func (u *User) TableUnique() [][]string { + return [][]string{ + []string{"UserName", "Email"}, + } +} + func NewUser() *User { obj := new(User) return obj diff --git a/orm/models_utils.go b/orm/models_utils.go index d4677a76..7f50c7d0 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -26,6 +26,38 @@ func getTableName(val reflect.Value) string { return snakeString(ind.Type().Name()) } +func getTableIndex(val reflect.Value) [][]string { + fun := val.MethodByName("TableIndex") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 { + val := vals[0] + if val.CanInterface() { + if d, ok := val.Interface().([][]string); ok { + return d + } + } + } + } + return nil +} + +func getTableUnique(val reflect.Value) [][]string { + fun := val.MethodByName("TableUnique") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 { + val := vals[0] + if val.CanInterface() { + if d, ok := val.Interface().([][]string); ok { + return d + } + } + } + } + return nil +} + func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := strings.ToLower(col) if column == "" {