diff --git a/orm/README.md b/orm/README.md index 62eb8094..1976b911 100644 --- a/orm/README.md +++ b/orm/README.md @@ -33,7 +33,6 @@ import ( type User struct { Id int `orm:"auto"` Name string `orm:"size(100)"` - orm.Manager } func init() { @@ -72,7 +71,6 @@ type Post struct { Id int `orm:"auto"` Title string `orm:"size(100)"` User *User `orm:"rel(fk)"` - orm.Manager } var posts []*Post diff --git a/orm/db.go b/orm/db.go index 77c457ad..8f47e235 100644 --- a/orm/db.go +++ b/orm/db.go @@ -582,8 +582,6 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { return err } else { elm := reflect.New(mi.addrField.Elem().Type()) - md := elm.Interface().(Modeler) - md.Init(md) mind := reflect.Indirect(elm) d.setColsValues(mi, &mind, mi.fields.dbcols, refs) @@ -803,25 +801,27 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi val := reflect.ValueOf(container) ind := reflect.Indirect(val) - typ := ind.Type() errTyp := true one := true if val.Kind() == reflect.Ptr { - tp := typ + fn := "" if ind.Kind() == reflect.Slice { one = false if ind.Type().Elem().Kind() == reflect.Ptr { - tp = ind.Type().Elem().Elem() + typ := ind.Type().Elem().Elem() + fn = getFullName(typ) } + } else { + fn = getFullName(ind.Type()) } - errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName + errTyp = fn != mi.fullName } if errTyp { - panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName)) + panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", ind.Type(), mi.fullName, mi.fullName)) } rlimit := qs.limit @@ -873,8 +873,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } elm := reflect.New(mi.addrField.Elem().Type()) - md := elm.Interface().(Modeler) - md.Init(md) mind := reflect.Indirect(elm) cacheV := make(map[string]*reflect.Value) @@ -989,9 +987,9 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params if ind.Kind() == reflect.Struct { typ := ind.Type() - fullName := typ.PkgPath() + "." + typ.Name() + name := getFullName(typ) var value interface{} - if mmi, ok := modelCache.get(fullName); ok { + if mmi, ok := modelCache.getByFN(name); ok { if _, vu, exist := d.existPk(mmi, ind); exist { value = vu } @@ -999,7 +997,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params arg = value if arg == nil { - panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%v`", operator, val.Type())) + panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name)) } } else { arg = ind.Interface() @@ -1266,8 +1264,6 @@ setValue: if value != nil { fieldType = fi.relModelInfo.fields.pk.fieldType mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - md := mf.Interface().(Modeler) - md.Init(md) field.Set(mf) f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex) field = &f diff --git a/orm/docs/zh/Models.md b/orm/docs/zh/Models.md index d5ce7463..5f82e8eb 100644 --- a/orm/docs/zh/Models.md +++ b/orm/docs/zh/Models.md @@ -11,6 +11,17 @@ orm:"null;rel(fk)" 多个设置间使用 `;` 分隔,设置的值如果是多个,使用 `,` 分隔。 +#### 忽略字段 + +设置 `-` 即可忽略 struct 中的字段 + +```go +type User struct { +... + AnyField string `orm:"-"` +... +``` + #### auto 设置为 Autoincrement Primary Key @@ -49,23 +60,6 @@ type User struct { ... Status int `orm:"default(1)"` ``` -仅当进行 orm.Manager 初始化时才会赋值 -```go -func NewUser() *User { - obj := new(User) - obj.Manager.Init(obj) - return obj -} - -u := NewUser() -fmt.Println(u.Status) // 1 -``` -#### choices - -为字段设置一组可选的值,类型必须符合。其他值 clean 会返回错误 -```go -Status int `orm:"choices(1,2,3,4)"` -``` #### size (string) string 类型字段设置 size 以后,db type 将使用 varchar diff --git a/orm/docs/zh/Orm.md b/orm/docs/zh/Orm.md index 0030ac0b..edd44f17 100644 --- a/orm/docs/zh/Orm.md +++ b/orm/docs/zh/Orm.md @@ -17,14 +17,12 @@ type User struct { Id int `orm:"auto"` // 设置为auto主键 Name string Profile *Profile `orm:"rel(one)"` // OneToOne relation - orm.Manager // 每个model都需要定义orm.Manager } type Profile struct { Id int `orm:"auto"` Age int16 User *User `orm:"reverse(one)"` // 设置反向关系(可选) - orm.Manager } func init() { diff --git a/orm/models.go b/orm/models.go index 57099089..1c2270ce 100644 --- a/orm/models.go +++ b/orm/models.go @@ -18,6 +18,7 @@ var ( cacheByFN: make(map[string]*modelInfo), } supportTag = map[string]int{ + "-": 1, "null": 1, "blank": 1, "index": 1, @@ -27,7 +28,6 @@ var ( "auto_now": 1, "auto_now_add": 1, "size": 2, - "choices": 2, "column": 2, "default": 2, "rel": 2, @@ -67,9 +67,11 @@ func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { mi, ok = mc.cache[table] - if ok == false { - mi, ok = mc.cacheByFN[table] - } + return +} + +func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { + mi, ok = mc.cacheByFN[name] return } diff --git a/orm/models_boot.go b/orm/models_boot.go index f7ca982c..10be132b 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -8,20 +8,36 @@ import ( "strings" ) -func registerModel(model Modeler) { - info := newModelInfo(model) - model.Init(model) - table := model.GetTableName() +func registerModel(model interface{}) { + val := reflect.ValueOf(model) + ind := reflect.Indirect(val) + typ := ind.Type() + + if val.Kind() != reflect.Ptr { + panic(fmt.Sprintf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + + info := newModelInfo(val) + + name := getFullName(typ) + if _, ok := modelCache.getByFN(name); ok { + fmt.Printf(" model `%s` redeclared, must be unique\n", name) + os.Exit(2) + } + + table := getTableName(val) if _, ok := modelCache.get(table); ok { - fmt.Printf("model <%T> redeclared, must be unique\n", model) + fmt.Printf(" table name `%s` redeclared, must be unique\n", table) os.Exit(2) } + if info.fields.pk == nil { - fmt.Printf("model <%T> need a primary key field\n", model) + fmt.Printf(" `%s` need a primary key field\n", name) os.Exit(2) } + info.table = table - info.pkg = getPkgPath(model) + info.pkg = typ.PkgPath() info.model = model info.manual = true modelCache.set(table, info) @@ -52,8 +68,8 @@ func bootStrap() { elm = elm.Elem() } - tn := getTableName(reflect.New(elm).Interface().(Modeler)) - mii, ok := modelCache.get(tn) + name := getFullName(elm) + mii, ok := modelCache.getByFN(name) if ok == false || mii.pkg != elm.PkgPath() { err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) goto end @@ -202,7 +218,7 @@ end: } } -func RegisterModel(models ...Modeler) { +func RegisterModel(models ...interface{}) { if modelCache.done { panic(fmt.Errorf("RegisterModel must be run begore BootStrap")) } diff --git a/orm/models_info_f.go b/orm/models_info_f.go index bd8b4e94..0cbbdf65 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -7,30 +7,7 @@ import ( "strings" ) -type fieldChoices []StrTo - -func (f *fieldChoices) Add(s StrTo) { - if f.Have(s) == false { - *f = append(*f, s) - } -} - -func (f *fieldChoices) Clear() { - *f = fieldChoices([]StrTo{}) -} - -func (f *fieldChoices) Have(s StrTo) bool { - for _, v := range *f { - if v == s { - return true - } - } - return false -} - -func (f *fieldChoices) Clone() fieldChoices { - return *f -} +var errSkipField = errors.New("skip field") type fields struct { pk *fieldInfo @@ -111,7 +88,7 @@ type fieldInfo struct { name string fullName string column string - addrValue *reflect.Value + addrValue reflect.Value sf *reflect.StructField auto bool pk bool @@ -120,7 +97,6 @@ type fieldInfo struct { index bool unique bool initial StrTo - choices fieldChoices size int auto_now bool auto_now_add bool @@ -142,13 +118,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (f var ( tag string tagValue string - choices fieldChoices - values fieldChoices initial StrTo fieldType int attrs map[string]bool tags map[string]string - parts []string addrField reflect.Value ) @@ -162,11 +135,20 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (f parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags) + if _, ok := attrs["-"]; ok { + return nil, errSkipField + } + digits := tags["digits"] decimals := tags["decimals"] size := tags["size"] onDelete := tags["on_delete"] + initial.Clear() + if v, ok := tags["default"]; ok { + initial.Set(v) + } + checkType: switch f := addrField.Interface().(type) { case Fielder: @@ -237,10 +219,6 @@ checkType: switch fieldType { case RelForeignKey, RelOneToOne, RelReverseOne: - if _, ok := addrField.Interface().(Modeler); ok == false { - err = fmt.Errorf("rel/reverse:one field must be implements Modeler") - goto end - } if field.Kind() != reflect.Ptr { err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) goto end @@ -254,10 +232,6 @@ checkType: err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) goto end } - if _, ok := reflect.New(field.Type().Elem()).Elem().Interface().(Modeler); ok == false { - err = fmt.Errorf("rel/reverse:many slice element must be implements Modeler") - goto end - } } } @@ -269,7 +243,7 @@ checkType: fi.fieldType = fieldType fi.name = sf.Name fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) - fi.addrValue = &addrField + fi.addrValue = addrField fi.sf = &sf fi.fullName = mi.fullName + "." + sf.Name @@ -306,7 +280,7 @@ checkType: switch onDelete { case od_CASCADE, od_DO_NOTHING: case od_SET_DEFAULT: - if tags["default"] == "" { + if initial.Exist() == false { err = errors.New("on_delete: set_default need set field a default value") goto end } @@ -397,31 +371,13 @@ checkType: fi.index = false } - parts = strings.Split(tags["choices"], ",") - if len(parts) > 1 { - for _, v := range parts { - choices.Add(StrTo(strings.TrimSpace(v))) - } - } - - initial.Clear() - if v, ok := tags["default"]; ok { - initial.Set(v) - } - if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField { // can not set default - choices.Clear() initial.Clear() } - values = choices.Clone() - if initial.Exist() { - values.Add(initial) - } - - for i, v := range values { + v := initial switch fieldType { case TypeBooleanField: _, err = v.Bool() @@ -441,23 +397,11 @@ checkType: _, err = v.Uint64() } if err != nil { - if initial.Exist() && len(values) == i { - tag, tagValue = "default", tags["default"] - } else { - tag, tagValue = "choices", tags["choices"] - } + tag, tagValue = "default", tags["default"] goto wrongTag } } - if len(choices) > 0 && initial.Exist() { - if choices.Have(initial) == false { - err = fmt.Errorf("default value `%s` not in choices `%s`", tags["default"], tags["choices"]) - goto end - } - } - - fi.choices = choices fi.initial = initial end: if err != nil { diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 40b4bc8d..6737ced0 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -12,13 +12,13 @@ type modelInfo struct { name string fullName string table string - model Modeler + model interface{} fields *fields manual bool addrField reflect.Value } -func newModelInfo(model Modeler) (info *modelInfo) { +func newModelInfo(val reflect.Value) (info *modelInfo) { var ( err error fi *fieldInfo @@ -28,26 +28,24 @@ func newModelInfo(model Modeler) (info *modelInfo) { info = &modelInfo{} info.fields = newFields() - val := reflect.ValueOf(model) ind := reflect.Indirect(val) typ := ind.Type() info.addrField = ind.Addr() info.name = typ.Name() - info.fullName = typ.PkgPath() + "." + typ.Name() + info.fullName = getFullName(typ) for i := 0; i < ind.NumField(); i++ { field := ind.Field(i) sf = ind.Type().Field(i) - if field.CanAddr() { - addr := field.Addr() - if _, ok := addr.Interface().(*Manager); ok { + fi, err = newFieldInfo(info, field, sf) + + if err != nil { + if err == errSkipField { + err = nil continue } - } - fi, err = newFieldInfo(info, field, sf) - if err != nil { break } diff --git a/orm/models_manager.go b/orm/models_manager.go deleted file mode 100644 index aa4df9f4..00000000 --- a/orm/models_manager.go +++ /dev/null @@ -1,91 +0,0 @@ -package orm - -import () - -type fieldError struct { - name string - err error -} - -func (f *fieldError) Name() string { - return f.name -} - -func (f *fieldError) Error() error { - return f.err -} - -func NewFieldError(name string, err error) IFieldError { - return &fieldError{name, err} -} - -// non cleaned field errors -type fieldErrors struct { - errors map[string]IFieldError - errorList []IFieldError -} - -func (fe *fieldErrors) Get(name string) IFieldError { - return fe.errors[name] -} - -func (fe *fieldErrors) Set(name string, value IFieldError) { - fe.errors[name] = value -} - -func (fe *fieldErrors) List() []IFieldError { - return fe.errorList -} - -func NewFieldErrors() IFieldErrors { - return &fieldErrors{errors: make(map[string]IFieldError)} -} - -type Manager struct { - ins Modeler - inited bool -} - -// func (m *Manager) init(model reflect.Value) { -// elm := model.Elem() -// for i := 0; i < elm.NumField(); i++ { -// field := elm.Field(i) -// if _, ok := field.Interface().(Fielder); ok && field.CanSet() { -// if field.Elem().Kind() != reflect.Struct { -// field.Set(reflect.New(field.Type().Elem())) -// } -// } -// } -// } - -func (m *Manager) Init(model Modeler, args ...interface{}) Modeler { - if m.inited { - return m.ins - } - m.inited = true - m.ins = model - skipInitial := false - if len(args) > 0 { - if b, ok := args[0].(bool); ok && b { - skipInitial = true - } - } - _ = skipInitial - return model -} - -func (m *Manager) IsInited() bool { - return m.inited -} - -func (m *Manager) Clean() IFieldErrors { - return nil -} - -func (m *Manager) CleanFields(name string) IFieldErrors { - return nil -} - -func (m *Manager) GetTableName() string { - return getTableName(m.ins) -} diff --git a/orm/models_test.go b/orm/models_test.go index de611916..78853fcb 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -11,32 +11,30 @@ import ( ) type User struct { - Id int `orm:"auto"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"choices(0,1,2,3);defalut(0)"` - IsStaff bool `orm:"default(false)"` - IsActive bool `orm:"default(1)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - Manager `json:"-"` + Id int `orm:"auto"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 + IsStaff bool + IsActive bool `orm:"default(1)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` } func NewUser() *User { obj := new(User) - obj.Manager.Init(obj) return obj } type Profile struct { - Id int `orm:"auto"` - Age int16 `` - Money float64 `` - User *User `orm:"reverse(one)" json:"-"` - Manager `json:"-"` + Id int `orm:"auto"` + Age int16 `` + Money float64 `` + User *User `orm:"reverse(one)" json:"-"` } func (u *Profile) TableName() string { @@ -45,7 +43,6 @@ func (u *Profile) TableName() string { func NewProfile() *Profile { obj := new(Profile) - obj.Manager.Init(obj) return obj } @@ -57,25 +54,21 @@ type Post struct { Created time.Time `orm:"auto_now_add"` Updated time.Time `orm:"auto_now"` Tags []*Tag `orm:"rel(m2m)"` - Manager `json:"-"` } func NewPost() *Post { obj := new(Post) - obj.Manager.Init(obj) return obj } type Tag struct { - Id int `orm:"auto"` - Name string `orm:"size(30)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - Manager `json:"-"` + Id int `orm:"auto"` + Name string `orm:"size(30)"` + Posts []*Post `orm:"reverse(many)" json:"-"` } func NewTag() *Tag { obj := new(Tag) - obj.Manager.Init(obj) return obj } @@ -85,12 +78,10 @@ type Comment struct { Content string `` Parent *Comment `orm:"null;rel(fk)"` Created time.Time `orm:"auto_now_add"` - Manager `json:"-"` } func NewComment() *Comment { obj := new(Comment) - obj.Manager.Init(obj) return obj } diff --git a/orm/models_utils.go b/orm/models_utils.go index fafc033e..d15447a1 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -7,8 +7,11 @@ import ( "time" ) -func getTableName(model Modeler) string { - val := reflect.ValueOf(model) +func getFullName(typ reflect.Type) string { + return typ.PkgPath() + "." + typ.Name() +} + +func getTableName(val reflect.Value) string { ind := reflect.Indirect(val) fun := val.MethodByName("TableName") if fun.IsValid() { @@ -23,11 +26,6 @@ func getTableName(model Modeler) string { return snakeString(ind.Type().Name()) } -func getPkgPath(model Modeler) string { - val := reflect.ValueOf(model) - return val.Type().Elem().PkgPath() -} - func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := strings.ToLower(col) if column == "" { diff --git a/orm/orm.go b/orm/orm.go index d14ef152..868dde58 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -39,16 +39,21 @@ type orm struct { var _ Ormer = new(orm) -func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) { - md.Init(md, true) - name := md.GetTableName() - if mi, ok := modelCache.get(name); ok { - return mi, reflect.Indirect(reflect.ValueOf(md)) +func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { + val := reflect.ValueOf(md) + ind = reflect.Indirect(val) + typ := ind.Type() + if val.Kind() != reflect.Ptr { + panic(fmt.Sprintf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } - panic(fmt.Sprintf(" table name: `%s` not exists", name)) + name := getFullName(typ) + if mi, ok := modelCache.getByFN(name); ok { + return mi, ind + } + panic(fmt.Sprintf(" table: `%s` not found, maybe not RegisterModel", name)) } -func (o *orm) Read(md Modeler) error { +func (o *orm) Read(md interface{}) error { mi, ind := o.getMiInd(md) err := o.alias.DbBaser.Read(o.db, mi, ind) if err != nil { @@ -57,7 +62,7 @@ func (o *orm) Read(md Modeler) error { return nil } -func (o *orm) Insert(md Modeler) (int64, error) { +func (o *orm) Insert(md interface{}) (int64, error) { mi, ind := o.getMiInd(md) id, err := o.alias.DbBaser.Insert(o.db, mi, ind) if err != nil { @@ -71,7 +76,7 @@ func (o *orm) Insert(md Modeler) (int64, error) { return id, nil } -func (o *orm) Update(md Modeler) (int64, error) { +func (o *orm) Update(md interface{}) (int64, error) { mi, ind := o.getMiInd(md) num, err := o.alias.DbBaser.Update(o.db, mi, ind) if err != nil { @@ -80,7 +85,7 @@ func (o *orm) Update(md Modeler) (int64, error) { return num, nil } -func (o *orm) Delete(md Modeler) (int64, error) { +func (o *orm) Delete(md interface{}) (int64, error) { mi, ind := o.getMiInd(md) num, err := o.alias.DbBaser.Delete(o.db, mi, ind) if err != nil { @@ -94,41 +99,48 @@ func (o *orm) Delete(md Modeler) (int64, error) { return num, nil } -func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) { +func (o *orm) M2mAdd(md interface{}, name string, mds ...interface{}) (int64, error) { // TODO panic(ErrNotImplement) return 0, nil } -func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) { +func (o *orm) M2mDel(md interface{}, name string, mds ...interface{}) (int64, error) { // TODO panic(ErrNotImplement) return 0, nil } -func (o *orm) LoadRel(md Modeler, name string) (int64, error) { +func (o *orm) LoadRel(md interface{}, name string) (int64, error) { // TODO panic(ErrNotImplement) return 0, nil } -func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { +func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { name := "" if table, ok := ptrStructOrTableName.(string); ok { name = snakeString(table) - } else if md, ok := ptrStructOrTableName.(Modeler); ok { - md.Init(md, true) - name = md.GetTableName() + if mi, ok := modelCache.get(name); ok { + qs = newQuerySet(o, mi) + } + } else { + val := reflect.ValueOf(ptrStructOrTableName) + ind := reflect.Indirect(val) + name = getFullName(ind.Type()) + if mi, ok := modelCache.getByFN(name); ok { + qs = newQuerySet(o, mi) + } } - if mi, ok := modelCache.get(name); ok { - return newQuerySet(o, mi) + if qs == nil { + panic(fmt.Sprintf(" table name: `%s` not exists", name)) } - panic(fmt.Sprintf(" table name: `%s` not exists", name)) + return } func (o *orm) Using(name string) error { if o.isTx { - panic(" transaction has been start, cannot change db") + panic(" transaction has been start, cannot change db") } if al, ok := dataBaseCache.get(name); ok { o.alias = al @@ -138,7 +150,7 @@ func (o *orm) Using(name string) error { o.db = al.DB } } else { - return errors.New(fmt.Sprintf(" unknown db alias name `%s`", name)) + return errors.New(fmt.Sprintf(" unknown db alias name `%s`", name)) } return nil } diff --git a/orm/orm_object.go b/orm/orm_object.go index b5935f1a..819a18bc 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -14,15 +14,19 @@ type insertSet struct { var _ Inserter = new(insertSet) -func (o *insertSet) Insert(md Modeler) (int64, error) { +func (o *insertSet) Insert(md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } - md.Init(md, true) val := reflect.ValueOf(md) ind := reflect.Indirect(val) - if val.Type() != o.mi.addrField.Type() { - panic(fmt.Sprintf(" need type `%s` but found `%s`", o.mi.addrField.Type(), val.Type())) + typ := ind.Type() + name := getFullName(typ) + if val.Kind() != reflect.Ptr { + panic(fmt.Sprintf(" cannot use non-ptr model struct `%s`", name)) + } + if name != o.mi.fullName { + panic(fmt.Sprintf(" need model `%s` but found `%s`", o.mi.fullName, name)) } id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind) if err != nil { diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 92178a6c..2f8c270f 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -63,7 +63,7 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter { case int: o.relDepth = val default: - panic(fmt.Sprintf(" wrong param kind: %v", val)) + panic(fmt.Sprintf(" wrong param kind: %v", val)) } } } @@ -96,7 +96,7 @@ func (o *querySet) All(container interface{}) (int64, error) { return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) } -func (o *querySet) One(container Modeler) error { +func (o *querySet) One(container interface{}) error { num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) if err != nil { return err diff --git a/orm/orm_test.go b/orm/orm_test.go index 9e29ad3a..119fe995 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -152,6 +152,14 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) { } } +func TestModelSyntax(t *testing.T) { + mi, ok := modelCache.get("user") + throwFail(t, AssertIs(ok, T_Equal, true)) + if ok { + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true)) + } +} + func TestCRUD(t *testing.T) { profile := NewProfile() profile.Age = 30 diff --git a/orm/types.go b/orm/types.go index 01ac92ca..fc0c6bb3 100644 --- a/orm/types.go +++ b/orm/types.go @@ -18,22 +18,14 @@ type Fielder interface { Clean() error } -type Modeler interface { - Init(Modeler, ...interface{}) Modeler - IsInited() bool - Clean() IFieldErrors - CleanFields(string) IFieldErrors - GetTableName() string -} - type Ormer interface { - Read(Modeler) error - Insert(Modeler) (int64, error) - Update(Modeler) (int64, error) - Delete(Modeler) (int64, error) - M2mAdd(Modeler, string, ...interface{}) (int64, error) - M2mDel(Modeler, string, ...interface{}) (int64, error) - LoadRel(Modeler, string) (int64, error) + Read(interface{}) error + Insert(interface{}) (int64, error) + Update(interface{}) (int64, error) + Delete(interface{}) (int64, error) + M2mAdd(interface{}, string, ...interface{}) (int64, error) + M2mDel(interface{}, string, ...interface{}) (int64, error) + LoadRel(interface{}, string) (int64, error) QueryTable(interface{}) QuerySeter Using(string) error Begin() error @@ -44,7 +36,7 @@ type Ormer interface { } type Inserter interface { - Insert(Modeler) (int64, error) + Insert(interface{}) (int64, error) Close() error } @@ -61,7 +53,7 @@ type QuerySeter interface { Delete() (int64, error) PrepareInsert() (Inserter, error) All(interface{}) (int64, error) - One(Modeler) error + One(interface{}) error Values(*[]Params, ...string) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error)