1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 21:11:29 +00:00

optimize the ORM

This commit is contained in:
astaxie 2016-08-31 00:07:19 +08:00
parent 161c061376
commit 8c37a07adb
9 changed files with 119 additions and 124 deletions

View File

@ -154,7 +154,7 @@ outFor:
typ := val.Type() typ := val.Type()
name := getFullName(typ) name := getFullName(typ)
var value interface{} var value interface{}
if mmi, ok := modelCache.getByFN(name); ok { if mmi, ok := modelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist { if _, vu, exist := getExistPk(mmi, val); exist {
value = vu value = vu
} }

View File

@ -29,39 +29,18 @@ const (
var ( var (
modelCache = &_modelCache{ modelCache = &_modelCache{
cache: make(map[string]*modelInfo), cache: make(map[string]*modelInfo),
cacheByFN: make(map[string]*modelInfo), cacheByFullName: make(map[string]*modelInfo),
}
supportTag = map[string]int{
"-": 1,
"null": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"size": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
"type": 2,
} }
) )
// model info collection // model info collection
type _modelCache struct { type _modelCache struct {
sync.RWMutex // only used outsite for bootStrap sync.RWMutex // only used outsite for bootStrap
orders []string orders []string
cache map[string]*modelInfo cache map[string]*modelInfo
cacheByFN map[string]*modelInfo cacheByFullName map[string]*modelInfo
done bool done bool
} }
// get all model info // get all model info
@ -88,9 +67,9 @@ func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
return return
} }
// get model info by field name // get model info by full name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name] mi, ok = mc.cacheByFullName[name]
return return
} }
@ -98,7 +77,7 @@ func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table] mii := mc.cache[table]
mc.cache[table] = mi mc.cache[table] = mi
mc.cacheByFN[mi.fullName] = mi mc.cacheByFullName[mi.fullName] = mi
if mii == nil { if mii == nil {
mc.orders = append(mc.orders, table) mc.orders = append(mc.orders, table)
} }
@ -109,7 +88,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
func (mc *_modelCache) clean() { func (mc *_modelCache) clean() {
mc.orders = make([]string, 0) mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo) mc.cache = make(map[string]*modelInfo)
mc.cacheByFN = make(map[string]*modelInfo) mc.cacheByFullName = make(map[string]*modelInfo)
mc.done = false mc.done = false
} }

View File

@ -26,12 +26,14 @@ import (
// prefix means table name prefix. // prefix means table name prefix.
func registerModel(prefix string, model interface{}) { func registerModel(prefix string, model interface{}) {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
ind := reflect.Indirect(val) typ := reflect.Indirect(val).Type()
typ := ind.Type()
if val.Kind() != reflect.Ptr { if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
} }
// For this case:
// u := &User{}
// registerModel(&u)
if typ.Kind() == reflect.Ptr { if typ.Kind() == reflect.Ptr {
panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ))
} }
@ -41,9 +43,9 @@ func registerModel(prefix string, model interface{}) {
if prefix != "" { if prefix != "" {
table = prefix + table table = prefix + table
} }
// models's fullname is pkgpath + struct name
name := getFullName(typ) name := getFullName(typ)
if _, ok := modelCache.getByFN(name); ok { if _, ok := modelCache.getByFullName(name); ok {
fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name) fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
os.Exit(2) os.Exit(2)
} }
@ -110,7 +112,7 @@ func bootStrap() {
} }
name := getFullName(elm) name := getFullName(elm)
mii, ok := modelCache.getByFN(name) mii, ok := modelCache.getByFullName(name)
if ok == false || mii.pkg != elm.PkgPath() { if ok == false || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end goto end
@ -123,7 +125,7 @@ func bootStrap() {
msg := fmt.Sprintf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) msg := fmt.Sprintf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i] pn := fi.relThrough[:i]
rmi, ok := modelCache.getByFN(fi.relThrough) rmi, ok := modelCache.getByFullName(fi.relThrough)
if ok == false || pn != rmi.pkg { if ok == false || pn != rmi.pkg {
err = errors.New(msg + " cannot find table") err = errors.New(msg + " cannot find table")
goto end goto end

View File

@ -152,6 +152,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
fi = new(fieldInfo) fi = new(fieldInfo)
// if field which CanAddr is the follow type
// A value is addressable if it is an element of a slice,
// an element of an addressable array, a field of an
// addressable struct, or the result of dereferencing a pointer.
addrField = field addrField = field
if field.CanAddr() && field.Kind() != reflect.Ptr { if field.CanAddr() && field.Kind() != reflect.Ptr {
addrField = field.Addr() addrField = field.Addr()
@ -162,7 +166,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
} }
} }
parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags) attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName))
if _, ok := attrs["-"]; ok { if _, ok := attrs["-"]; ok {
return nil, errSkipField return nil, errSkipField
@ -188,7 +192,7 @@ checkType:
} }
fieldType = f.FieldType() fieldType = f.FieldType()
if fieldType&IsRelField > 0 { if fieldType&IsRelField > 0 {
err = fmt.Errorf("unsupport rel type custom field") err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42")
goto end goto end
} }
default: default:

View File

@ -29,31 +29,25 @@ type modelInfo struct {
model interface{} model interface{}
fields *fields fields *fields
manual bool manual bool
addrField reflect.Value addrField reflect.Value //store the original struct value
uniques []string uniques []string
isThrough bool isThrough bool
} }
// new model info // new model info
func newModelInfo(val reflect.Value) (info *modelInfo) { func newModelInfo(val reflect.Value) (mi *modelInfo) {
mi = &modelInfo{}
info = &modelInfo{} mi.fields = newFields()
info.fields = newFields()
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type() mi.addrField = val
mi.name = ind.Type().Name()
info.addrField = val mi.fullName = getFullName(ind.Type())
addModelFields(mi, ind, "", []int{})
info.name = typ.Name()
info.fullName = getFullName(typ)
addModelFields(info, ind, "", []int{})
return return
} }
func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []int) { // index: FieldByIndex returns the nested field corresponding to index
func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
var ( var (
err error err error
fi *fieldInfo fi *fieldInfo
@ -63,43 +57,39 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in
for i := 0; i < ind.NumField(); i++ { for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i) field := ind.Field(i)
sf = ind.Type().Field(i) sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" { if sf.PkgPath != "" {
continue continue
} }
// add anonymous struct fields // add anonymous struct fields
if sf.Anonymous { if sf.Anonymous {
addModelFields(info, field, mName+"."+sf.Name, append(index, i)) addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue continue
} }
fi, err = newFieldInfo(info, field, sf, mName) fi, err = newFieldInfo(mi, field, sf, mName)
if err == errSkipField {
if err != nil { err = nil
if err == errSkipField { continue
err = nil } else if err != nil {
continue
}
break break
} }
//record current field index
added := info.fields.Add(fi) fi.fieldIndex = append(index, i)
if added == false { fi.mi = mi
fi.inModel = true
if mi.fields.Add(fi) == false {
err = fmt.Errorf("duplicate column name: %s", fi.column) err = fmt.Errorf("duplicate column name: %s", fi.column)
break break
} }
if fi.pk { if fi.pk {
if info.fields.pk != nil { if mi.fields.pk != nil {
err = fmt.Errorf("one model must have one pk field only") err = fmt.Errorf("one model must have one pk field only")
break break
} else { } else {
info.fields.pk = fi mi.fields.pk = fi
} }
} }
fi.fieldIndex = append(index, i)
fi.mi = info
fi.inModel = true
} }
if err != nil { if err != nil {
@ -110,12 +100,12 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in
// combine related model info to new model info. // combine related model info to new model info.
// prepare for relation models query. // prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
info = new(modelInfo) mi = new(modelInfo)
info.fields = newFields() mi.fields = newFields()
info.table = m1.table + "_" + m2.table + "s" mi.table = m1.table + "_" + m2.table + "s"
info.name = camelString(info.table) mi.name = camelString(mi.table)
info.fullName = m1.pkg + "." + info.name mi.fullName = m1.pkg + "." + mi.name
fa := new(fieldInfo) fa := new(fieldInfo)
f1 := new(fieldInfo) f1 := new(fieldInfo)
@ -126,7 +116,7 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
fa.dbcol = true fa.dbcol = true
fa.name = "Id" fa.name = "Id"
fa.column = "id" fa.column = "id"
fa.fullName = info.fullName + "." + fa.name fa.fullName = mi.fullName + "." + fa.name
f1.dbcol = true f1.dbcol = true
f2.dbcol = true f2.dbcol = true
@ -134,8 +124,8 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
f2.fieldType = RelForeignKey f2.fieldType = RelForeignKey
f1.name = camelString(m1.table) f1.name = camelString(m1.table)
f2.name = camelString(m2.table) f2.name = camelString(m2.table)
f1.fullName = info.fullName + "." + f1.name f1.fullName = mi.fullName + "." + f1.name
f2.fullName = info.fullName + "." + f2.name f2.fullName = mi.fullName + "." + f2.name
f1.column = m1.table + "_id" f1.column = m1.table + "_id"
f2.column = m2.table + "_id" f2.column = m2.table + "_id"
f1.rel = true f1.rel = true
@ -144,14 +134,14 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
f2.relTable = m2.table f2.relTable = m2.table
f1.relModelInfo = m1 f1.relModelInfo = m1
f2.relModelInfo = m2 f2.relModelInfo = m2
f1.mi = info f1.mi = mi
f2.mi = info f2.mi = mi
info.fields.Add(fa) mi.fields.Add(fa)
info.fields.Add(f1) mi.fields.Add(f1)
info.fields.Add(f2) mi.fields.Add(f2)
info.fields.pk = fa mi.fields.pk = fa
info.uniques = []string{f1.column, f2.column} mi.uniques = []string{f1.column, f2.column}
return return
} }

View File

@ -22,25 +22,47 @@ import (
"time" "time"
) )
// 1 is attr
// 2 is tag
var supportTag = map[string]int{
"-": 1,
"null": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"size": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
"type": 2,
}
// get reflect.Type name with package path. // get reflect.Type name with package path.
func getFullName(typ reflect.Type) string { func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name() return typ.PkgPath() + "." + typ.Name()
} }
// get table name. method, or field name. auto snaked. // getTableName get struct table name.
// If the struct implement the TableName, then get the result as tablename
// else use the struct name which will apply snakeString.
func getTableName(val reflect.Value) string { func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val) if fun := val.MethodByName("TableName"); fun.IsValid() {
fun := val.MethodByName("TableName")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{}) vals := fun.Call([]reflect.Value{})
if len(vals) > 0 { // has return and the first val is string
val := vals[0] if len(vals) > 0 && vals[0].Kind() == reflect.String {
if val.Kind() == reflect.String { return vals[0].String()
return val.String()
}
} }
} }
return snakeString(ind.Type().Name()) return snakeString(reflect.Indirect(val).Type().Name())
} }
// get table engine, mysiam or innodb. // get table engine, mysiam or innodb.
@ -189,21 +211,25 @@ func getFieldType(val reflect.Value) (ft int, err error) {
} }
// parse struct tag string // parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
attr := make(map[string]bool) attrs = make(map[string]bool)
tag := make(map[string]string) tags = make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) { for _, v := range strings.Split(data, defaultStructTagDelim) {
if v == "" {
continue
}
v = strings.TrimSpace(v) v = strings.TrimSpace(v)
if t := strings.ToLower(v); supportTag[t] == 1 { if t := strings.ToLower(v); supportTag[t] == 1 {
attr[t] = true attrs[t] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
name := t[:i] name := t[:i]
if supportTag[name] == 2 { if supportTag[name] == 2 {
v = v[i+1 : len(v)-1] v = v[i+1 : len(v)-1]
tag[name] = v tags[name] = v
} }
} else {
DebugLog.Println("unsupport orm tag", v)
} }
} }
*attrs = attr return
*tags = tag
} }

View File

@ -104,7 +104,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
} }
name := getFullName(typ) name := getFullName(typ)
if mi, ok := modelCache.getByFN(name); ok { if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind return mi, ind
} }
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
@ -427,7 +427,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
} }
} else { } else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
if mi, ok := modelCache.getByFN(name); ok { if mi, ok := modelCache.getByFullName(name); ok {
qs = newQuerySet(o, mi) qs = newQuerySet(o, mi)
} }
} }

View File

@ -286,7 +286,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
structMode = true structMode = true
fn := getFullName(typ) fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok { if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi sMi = mi
} }
} else { } else {
@ -355,12 +355,9 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
for i := 0; i < ind.NumField(); i++ { for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i) f := ind.Field(i)
fe := ind.Type().Field(i) fe := ind.Type().Field(i)
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string var col string
if col = tags["column"]; len(col) == 0 { if col = tags["column"]; col == "" {
col = snakeString(fe.Name) col = snakeString(fe.Name)
} }
if v, ok := columnsMp[col]; ok { if v, ok := columnsMp[col]; ok {
@ -422,7 +419,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
structMode = true structMode = true
fn := getFullName(typ) fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok { if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi sMi = mi
} }
} else { } else {
@ -499,12 +496,9 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
for i := 0; i < ind.NumField(); i++ { for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i) f := ind.Field(i)
fe := ind.Type().Field(i) fe := ind.Type().Field(i)
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string var col string
if col = tags["column"]; len(col) == 0 { if col = tags["column"]; col == "" {
col = snakeString(fe.Name) col = snakeString(fe.Name)
} }
if v, ok := columnsMp[col]; ok { if v, ok := columnsMp[col]; ok {

View File

@ -227,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
user := &User{} user := &User{}
ind := reflect.ValueOf(user).Elem() ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type()) fn := getFullName(ind.Type())
mi, ok := modelCache.getByFN(fn) mi, ok := modelCache.getByFullName(fn)
throwFail(t, AssertIs(ok, true)) throwFail(t, AssertIs(ok, true))
mi, ok = modelCache.get("user") mi, ok = modelCache.get("user")