diff --git a/pkg/client/orm/cmd.go b/pkg/client/orm/cmd.go index f03382e9..e03fc0ee 100644 --- a/pkg/client/orm/cmd.go +++ b/pkg/client/orm/cmd.go @@ -99,13 +99,17 @@ func (d *commandSyncDb) Parse(args []string) { // Run orm line command. func (d *commandSyncDb) Run() error { var drops []string + var err error if d.force { - drops = getDbDropSQL(d.al) + drops, err = modelCache.getDbDropSQL(d.al) + if err != nil { + return err + } } db := d.al.DB - if d.force { + if d.force && len(drops) > 0 { for i, mi := range modelCache.allOrdered() { query := drops[i] if !d.noInfo { @@ -124,7 +128,10 @@ func (d *commandSyncDb) Run() error { } } - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := modelCache.getDbCreateSQL(d.al) + if err != nil { + return err + } tables, err := d.al.DbBaser.GetTables(db) if err != nil { @@ -201,7 +208,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf("create table `%s` \n", mi.table) } - queries := []string{sqls[i]} + queries := []string{createQueries[i]} for _, idx := range indexes[mi.table] { queries = append(queries, idx.SQL) } @@ -245,10 +252,13 @@ func (d *commandSQLAll) Parse(args []string) { // Run orm line command. func (d *commandSQLAll) Run() error { - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := modelCache.getDbCreateSQL(d.al) + if err != nil { + return err + } var all []string for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} + queries := []string{createQueries[i]} for _, idx := range indexes[mi.table] { queries = append(queries, idx.SQL) } diff --git a/pkg/client/orm/cmd_utils.go b/pkg/client/orm/cmd_utils.go index e045e847..8d6c0c33 100644 --- a/pkg/client/orm/cmd_utils.go +++ b/pkg/client/orm/cmd_utils.go @@ -16,7 +16,6 @@ package orm import ( "fmt" - "os" "strings" ) @@ -26,21 +25,6 @@ type dbIndex struct { SQL string } -// create database drop sql. -func getDbDropSQL(al *alias) (sqls []string) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - - for _, mi := range modelCache.allOrdered() { - sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) - } - return sqls -} - // get database column type string. func getColumnTyp(al *alias, fi *fieldInfo) (col string) { T := al.DbBaser.DbTypes() @@ -140,146 +124,6 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { ) } -// create database creation string. -func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - T := al.DbBaser.DbTypes() - sep := fmt.Sprintf("%s, %s", Q, Q) - - tableIndexes = make(map[string][]dbIndex) - - for _, mi := range modelCache.allOrdered() { - sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) - sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) - - columns := make([]string, 0, len(mi.fields.fieldsDB)) - - sqlIndexes := [][]string{} - - for _, fi := range mi.fields.fieldsDB { - - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := getColumnTyp(al, fi) - - if fi.auto { - switch al.Driver { - case DRSqlite, DRPostgres: - column += T["auto"] - default: - column += col + " " + T["auto"] - } - } else if fi.pk { - column += col + " " + T["pk"] - } else { - column += col - - if !fi.null { - column += " " + "NOT NULL" - } - - //if fi.initial.String() != "" { - // column += " DEFAULT " + fi.initial.String() - //} - - // Append attribute DEFAULT - column += getColumnDefault(fi) - - if fi.unique { - column += " " + "UNIQUE" - } - - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) - } - } - - if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) - } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) - } - - columns = append(columns, column) - } - - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) - } - for _, names := range allnames { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) - } - } - column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) - columns = append(columns, column) - } - } - - sql += strings.Join(columns, ",\n") - sql += "\n)" - - if al.Driver == DRMySQL { - var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) - } - if engine == "" { - engine = al.Engine - } - sql += " ENGINE=" + engine - } - - sql += ";" - sqls = append(sqls, sql) - - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) - } - } - sqlIndexes = append(sqlIndexes, cols) - } - } - - for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") - cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - - index := dbIndex{} - index.Table = mi.table - index.Name = name - index.SQL = sql - - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) - } - - } - - return -} - // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands func getColumnDefault(fi *fieldInfo) string { var ( diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index c8fbcced..a7de10f7 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -15,7 +15,11 @@ package orm import ( + "errors" + "fmt" "reflect" + "runtime/debug" + "strings" "sync" ) @@ -95,12 +99,464 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { // clean all model info. func (mc *_modelCache) clean() { + mc.Lock() + defer mc.Unlock() + mc.orders = make([]string, 0) mc.cache = make(map[string]*modelInfo) mc.cacheByFullName = make(map[string]*modelInfo) mc.done = false } +//bootstrap bootstrap for models +func (mc *_modelCache) bootstrap() { + mc.Lock() + defer mc.Unlock() + if mc.done { + return + } + var ( + err error + models map[string]*modelInfo + ) + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register DataBase alias named `default`") + goto end + } + + // set rel and reverse model + // RelManyToMany set the relTable + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already register + name := getFullName(elm) + mii, ok := mc.getByFullName(name) + if !ok || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + rmi, ok := mc.getByFullName(fi.relThrough) + if !ok || pn != rmi.pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + goto end + } + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + goto end + } + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + if v := mc.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + + fi.relThroughModelInfo.isThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if !rmi.fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelManyToMany: + for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { + switch ffi.fieldType { + case RelOneToOne, RelForeignKey: + if ffi.relModelInfo == fi.relModelInfo { + fi.reverseFieldInfoTwo = ffi + } + if ffi.relModelInfo == mi { + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + } + } + } + if fi.reverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.relThroughModelInfo.fullName) + goto end + } + } + } + } + + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsReverse { + switch fi.fieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { + found = true + + fi.reverseField = ffi.reverseFieldInfoTwo.name + fi.reverseFieldInfo = ffi.reverseFieldInfoTwo + fi.relThroughModelInfo = ffi.relThroughModelInfo + fi.reverseFieldInfoTwo = ffi.reverseFieldInfo + fi.reverseFieldInfoM2M = ffi + ffi.reverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + } + modelCache.done = true + return +} + +// register register models to model cache +func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { + if mc.done { + err = fmt.Errorf("register must be run before BootStrap") + return + } + + for _, model := range models { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + err = fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ)) + return + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) + return + } + + table := getTableName(val) + + if prefixOrSuffixStr != "" { + if prefixOrSuffix { + table = prefixOrSuffixStr + table + } else { + table = table + prefixOrSuffixStr + } + } + + // models's fullname is pkgpath + struct name + name := getFullName(typ) + if _, ok := mc.getByFullName(name); ok { + err = fmt.Errorf(" model `%s` repeat register, must be unique\n", name) + return + } + + if _, ok := mc.get(table); ok { + err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) + return + } + + mi := newModelInfo(val) + if mi.fields.pk == nil { + outFor: + for _, fi := range mi.fields.fieldsDB { + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + mi.fields.pk = fi + break outFor + } + } + } + + if mi.fields.pk == nil { + err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) + return + } + + } + + mi.table = table + mi.pkg = typ.PkgPath() + mi.model = model + mi.manual = true + + mc.set(table, mi) + } + return +} + +//getDbDropSQL get database scheme drop sql queries +func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { + if len(modelCache.cache) == 0 { + err = errors.New("no Model found, need register your model") + return + } + + Q := al.DbBaser.TableQuote() + + for _, mi := range modelCache.allOrdered() { + queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + } + return queries,nil +} + +//getDbCreateSQL get database scheme creation sql queries +func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) { + if len(modelCache.cache) == 0 { + err = errors.New("no Model found, need register your model") + return + } + + Q := al.DbBaser.TableQuote() + T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range modelCache.allOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + + columns := make([]string, 0, len(mi.fields.fieldsDB)) + + sqlIndexes := [][]string{} + + for _, fi := range mi.fields.fieldsDB { + + column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + col := getColumnTyp(al, fi) + + if fi.auto { + switch al.Driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.null { + column += " " + "NOT NULL" + } + + //if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + //} + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.unique { + column += " " + "UNIQUE" + } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{fi.column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.column, -1) + } + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + } + + columns = append(columns, column) + } + + if mi.model != nil { + allnames := getTableUnique(mi.addrField) + if !mi.manual && len(mi.uniques) > 0 { + allnames = append(allnames, mi.uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.Driver == DRMySQL { + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + queries = append(queries, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.SQL = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + } + + } + + return +} + // ResetModelCache Clean model cache. Then you can re-RegisterModel. // Common use this api for test case. func ResetModelCache() { diff --git a/pkg/client/orm/models_boot.go b/pkg/client/orm/models_boot.go index 8c56b3c4..407cf536 100644 --- a/pkg/client/orm/models_boot.go +++ b/pkg/client/orm/models_boot.go @@ -16,294 +16,8 @@ package orm import ( "fmt" - "os" - "reflect" - "runtime/debug" - "strings" ) -// register models. -// PrefixOrSuffix means table name prefix or suffix. -// isPrefix whether the prefix is prefix or suffix -func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { - val := reflect.ValueOf(model) - typ := reflect.Indirect(val).Type() - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - // For this case: - // u := &User{} - // registerModel(&u) - if typ.Kind() == reflect.Ptr { - panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) - } - - table := getTableName(val) - - if PrefixOrSuffix != "" { - if isPrefix { - table = PrefixOrSuffix + table - } else { - table = table + PrefixOrSuffix - } - } - // models's fullname is pkgpath + struct name - name := getFullName(typ) - if _, ok := modelCache.getByFullName(name); ok { - fmt.Printf(" model `%s` repeat register, must be unique\n", name) - os.Exit(2) - } - - if _, ok := modelCache.get(table); ok { - fmt.Printf(" table name `%s` repeat register, must be unique\n", table) - os.Exit(2) - } - - mi := newModelInfo(val) - if mi.fields.pk == nil { - outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi - break outFor - } - } - } - - if mi.fields.pk == nil { - fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - os.Exit(2) - } - - } - - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true - - modelCache.set(table, mi) -} - -// bootstrap models -func bootStrap() { - if modelCache.done { - return - } - var ( - err error - models map[string]*modelInfo - ) - if dataBaseCache.getDefault() == nil { - err = fmt.Errorf("must have one register DataBase alias named `default`") - goto end - } - - // set rel and reverse model - // RelManyToMany set the relTable - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { - elm = elm.Elem() - } - // check the rel or reverse model already register - name := getFullName(elm) - mii, ok := modelCache.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) - goto end - } - fi.relModelInfo = mii - - switch fi.fieldType { - case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) - goto end - } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table - } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) - goto end - } - } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable - } - if v := modelCache.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) - goto end - } - fi.relTable = i.table - fi.relThroughModelInfo = i - } - - fi.relThroughModelInfo.isThrough = true - } - } - } - } - - // check the rel filed while the relModelInfo also has filed point to current model - // if not exist, add a new field to the relModelInfo - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { - inModel = true - break - } - } - if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne - } else { - ffi.fieldType = RelReverseMany - } - if !rmi.fields.Add(ffi) { - added := false - for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { - break - } - } - if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) - } - } - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { - case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi - } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - } - } - } - if fi.reverseFieldInfoTwo == nil { - err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) - goto end - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { - case RelReverseOne: - found := false - mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - break mForA - } - } - if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - case RelReverseMany: - found := false - mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - - break mForB - } - } - if !found { - mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { - found = true - - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi - - break mForC - } - } - } - if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - } - } - } - -end: - if err != nil { - fmt.Println(err) - debug.PrintStack() - os.Exit(2) - } -} - // RegisterModel register models func RegisterModel(models ...interface{}) { if modelCache.done { @@ -314,34 +28,20 @@ func RegisterModel(models ...interface{}) { // RegisterModelWithPrefix register models with a prefix func RegisterModelWithPrefix(prefix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(prefix, model, true) + if err := modelCache.register(prefix, true, models...); err != nil { + panic(err) } } // RegisterModelWithSuffix register models with a suffix func RegisterModelWithSuffix(suffix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(suffix, model, false) + if err := modelCache.register(suffix, false, models...); err != nil { + panic(err) } } // BootStrap bootstrap models. // make all model parsed and can not add more models func BootStrap() { - modelCache.Lock() - defer modelCache.Unlock() - if modelCache.done { - return - } - bootStrap() - modelCache.done = true + modelCache.bootstrap() }