diff --git a/orm/db_alias.go b/orm/db_alias.go index 7cd29f03..8d474591 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -21,8 +21,12 @@ const ( var ( dataBaseCache = &_dbCache{cache: make(map[string]*alias)} - drivers = make(map[string]driverType) - dbBasers = map[driverType]dbBaser{ + drivers = map[string]driverType{ + "mysql": DR_MySQL, + "postgres": DR_Postgres, + "sqlite3": DR_Sqlite, + } + dbBasers = map[driverType]dbBaser{ DR_MySQL: newdbBaseMysql(), DR_Sqlite: newdbBaseSqlite(), DR_Oracle: newdbBaseMysql(), @@ -122,9 +126,3 @@ func RegisterDriver(name string, typ driverType) { } } } - -func init() { - RegisterDriver("mysql", DR_MySQL) - RegisterDriver("postgres", DR_Postgres) - RegisterDriver("sqlite3", DR_Sqlite) -} diff --git a/orm/models_manager.go b/orm/models_manager.go index b4d05219..aa4df9f4 100644 --- a/orm/models_manager.go +++ b/orm/models_manager.go @@ -58,12 +58,19 @@ type Manager struct { // } // } -func (m *Manager) Init(model Modeler) Modeler { +func (m *Manager) Init(model Modeler, args ...interface{}) Modeler { if m.inited { return m.ins } m.inited = true m.ins = model + skipInitial := false + if len(args) > 0 { + if b, ok := args[0].(bool); ok && b { + skipInitial = true + } + } + _ = skipInitial return model } diff --git a/orm/orm.go b/orm/orm.go index 769b86c5..9fff0d08 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -27,6 +27,7 @@ type orm struct { } func (o *orm) Object(md Modeler) ObjectSeter { + md.Init(md, true) name := md.GetTableName() if mi, ok := modelCache.get(name); ok { return newObject(o, mi, md) @@ -38,8 +39,9 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { name := "" if table, ok := ptrStructOrTableName.(string); ok { name = snakeString(table) - } else if m, ok := ptrStructOrTableName.(Modeler); ok { - name = m.GetTableName() + } else if md, ok := ptrStructOrTableName.(Modeler); ok { + md.Init(md, true) + name = md.GetTableName() } if mi, ok := modelCache.get(name); ok { return newQuerySet(o, mi) diff --git a/orm/orm_object.go b/orm/orm_object.go index 341a86e2..323675f0 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -17,6 +17,7 @@ func (o *insertSet) Insert(md Modeler) (int64, error) { if o.closed { return 0, ErrStmtClosed } + md.Init(md, true) val := reflect.ValueOf(md) ind := reflect.Indirect(val) if val.Type() != o.mi.addrField.Type() { diff --git a/orm/types.go b/orm/types.go index c819da2f..e73a11aa 100644 --- a/orm/types.go +++ b/orm/types.go @@ -14,7 +14,7 @@ type Fielder interface { } type Modeler interface { - Init(Modeler) Modeler + Init(Modeler, ...interface{}) Modeler IsInited() bool Clean() IFieldErrors CleanFields(string) IFieldErrors