package orm import ( "database/sql" "errors" "fmt" "os" "reflect" "time" ) const ( Debug_Queries = iota ) var ( // DebugLevel = Debug_Queries Debug = false DebugLog = NewLog(os.Stderr) DefaultRowsLimit = 1000 DefaultRelsDepth = 2 DefaultTimeLoc = time.Local ErrTxHasBegan = errors.New(" transaction already begin") ErrTxDone = errors.New(" transaction not begin") ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") ErrStmtClosed = errors.New(" stmt already closed") ErrArgs = errors.New(" args error may be empty") ErrNotImplement = errors.New("have not implement") ) type Params map[string]interface{} type ParamsList []interface{} type orm struct { alias *alias db dbQuerier isTx bool } var _ Ormer = new(orm) func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() if needPtr && val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } name := getFullName(typ) if mi, ok := modelCache.getByFN(name); ok { return mi, ind } panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name)) } func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) } return fi } func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return err } return nil } func (o *orm) Insert(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } o.setPk(mi, ind, id) return id, nil } func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) } else { ind.Field(mi.fields.pk.fieldIndex).SetInt(id) } } } func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { var cnt int64 sind := reflect.Indirect(reflect.ValueOf(mds)) switch sind.Kind() { case reflect.Array, reflect.Slice: if sind.Len() == 0 { return cnt, ErrArgs } default: return cnt, ErrArgs } if bulk <= 1 { for i := 0; i < sind.Len(); i++ { ind := sind.Index(i) mi, _ := o.getMiInd(ind.Interface(), false) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err } o.setPk(mi, ind, id) cnt += 1 } } else { mi, _ := o.getMiInd(sind.Index(0).Interface(), false) return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil } func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err } return num, nil } func (o *orm) Delete(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) if err != nil { return num, err } if num > 0 { o.setPk(mi, ind, 0) } return num, nil } func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) switch { case fi.fieldType == RelManyToMany: case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: default: panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) } return newQueryM2M(md, o, mi, fi, ind) } func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) qs := qseter.(*querySet) var relDepth int var limit, offset int64 var order string for i, arg := range args { switch i { case 0: if v, ok := arg.(bool); ok { if v { relDepth = DefaultRelsDepth } } else if v, ok := arg.(int); ok { relDepth = v } case 1: limit = ToInt64(arg) case 2: offset = ToInt64(arg) case 3: order, _ = arg.(string) } } switch fi.fieldType { case RelOneToOne, RelForeignKey, RelReverseOne: limit = 1 offset = 0 } qs.limit = limit qs.offset = offset qs.relDepth = relDepth if len(order) > 0 { qs.orders = []string{order} } find := ind.Field(fi.fieldIndex) var nums int64 var err error switch fi.fieldType { case RelOneToOne, RelForeignKey, RelReverseOne: val := reflect.New(find.Type().Elem()) container := val.Interface() err = qs.One(container) if err == nil { find.Set(val) nums = 1 } default: nums, err = qs.All(find.Addr().Interface()) } return nums, err } func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) return qs } func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) if exist == false { panic(ErrMissPK) } var qs *querySet switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: if !fi.inModel { break } qs = o.getRelQs(md, mi, fi) case RelReverseOne, RelReverseMany: if !fi.inModel { break } qs = o.getReverseQs(md, mi, fi) } if qs == nil { panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field")) } return mi, fi, ind, qs } func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelReverseOne, RelReverseMany: default: panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) } var q *querySet if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { q = newQuerySet(o, fi.relModelInfo).(*querySet) q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) } else { q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) } return q } func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: default: panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) } q := newQuerySet(o, fi.relModelInfo).(*querySet) q.cond = NewCondition() if fi.fieldType == RelManyToMany { q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) } else { q.cond = q.cond.And(fi.reverseFieldInfo.column, md) } return q } func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { name := "" if table, ok := ptrStructOrTableName.(string); ok { name = snakeString(table) if mi, ok := modelCache.get(name); ok { qs = newQuerySet(o, mi) } } else { name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) if mi, ok := modelCache.getByFN(name); ok { qs = newQuerySet(o, mi) } } if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } return } func (o *orm) Using(name string) error { if o.isTx { panic(fmt.Errorf(" transaction has been start, cannot change db")) } if al, ok := dataBaseCache.get(name); ok { o.alias = al if Debug { o.db = newDbQueryLog(al, al.DB) } else { o.db = al.DB } } else { return fmt.Errorf(" unknown db alias name `%s`", name) } return nil } func (o *orm) Begin() error { if o.isTx { return ErrTxHasBegan } var tx *sql.Tx tx, err := o.db.(txer).Begin() if err != nil { return err } o.isTx = true if Debug { o.db.(*dbQueryLog).SetDB(tx) } else { o.db = tx } return nil } func (o *orm) Commit() error { if o.isTx == false { return ErrTxDone } err := o.db.(txEnder).Commit() if err == nil { o.isTx = false o.Using(o.alias.Name) } else if err == sql.ErrTxDone { return ErrTxDone } return err } func (o *orm) Rollback() error { if o.isTx == false { return ErrTxDone } err := o.db.(txEnder).Rollback() if err == nil { o.isTx = false o.Using(o.alias.Name) } else if err == sql.ErrTxDone { return ErrTxDone } return err } func (o *orm) Raw(query string, args ...interface{}) RawSeter { return newRawSet(o, query, args) } func (o *orm) Driver() Driver { return driver(o.alias.Name) } func NewOrm() Ormer { BootStrap() // execute only once o := new(orm) err := o.Using("default") if err != nil { panic(err) } return o }