diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go index 24f564ab..19941d2e 100644 --- a/pkg/client/orm/models.go +++ b/pkg/client/orm/models.go @@ -33,10 +33,7 @@ const ( ) var ( - modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), - } + modelCache = NewModelCacheHandler() ) // model info collection @@ -48,6 +45,14 @@ type _modelCache struct { done bool } +//NewModelCacheHandler generator of _modelCache +func NewModelCacheHandler() *_modelCache { + return &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), + } +} + // get all model info func (mc *_modelCache) all() map[string]*modelInfo { m := make(map[string]*modelInfo, len(mc.cache)) @@ -321,7 +326,7 @@ end: fmt.Println(err) debug.PrintStack() } - modelCache.done = true + mc.done = true return } @@ -404,14 +409,14 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m //getDbDropSQL get database scheme drop sql queries func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { - if len(modelCache.cache) == 0 { + if len(mc.cache) == 0 { err = errors.New("no Model found, need register your model") return } Q := al.DbBaser.TableQuote() - for _, mi := range modelCache.allOrdered() { + for _, mi := range mc.allOrdered() { queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) } return queries, nil @@ -419,7 +424,7 @@ func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { //getDbCreateSQL get database scheme creation sql queries func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) { - if len(modelCache.cache) == 0 { + if len(mc.cache) == 0 { err = errors.New("no Model found, need register your model") return } @@ -430,7 +435,7 @@ func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes tableIndexes = make(map[string][]dbIndex) - for _, mi := range modelCache.allOrdered() { + for _, mi := range mc.allOrdered() { sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) diff --git a/pkg/client/orm/models_boot.go b/pkg/client/orm/models_boot.go index 407cf536..9a0ce893 100644 --- a/pkg/client/orm/models_boot.go +++ b/pkg/client/orm/models_boot.go @@ -14,15 +14,8 @@ package orm -import ( - "fmt" -) - // RegisterModel register models func RegisterModel(models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModel must be run before BootStrap")) - } RegisterModelWithPrefix("", models...) }