diff --git a/orm/db_tables.go b/orm/db_tables.go index 972077c2..5a78cf21 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -112,7 +112,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { names = append(names, fi.name) mmi = fi.relModelInfo - if fi.null { + if fi.null || t.skipEnd { inner = false } @@ -189,6 +189,8 @@ func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { var ( jtl *dbTable + fi *fieldInfo + fiN *fieldInfo mmi = mi ) @@ -197,9 +199,24 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string inner := true +loopFor: for i, ex := range exprs { - fi, ok := mmi.fields.GetByAny(ex) + var ok, okN bool + + if fiN != nil { + fi = fiN + ok = true + fiN = nil + } + + if i == 0 { + fi, ok = mmi.fields.GetByAny(ex) + } + + // fmt.Println(ex, fi.name, fiN) + + _ = okN if ok { @@ -217,13 +234,20 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string mmi = fi.reverseFieldInfo.mi } + if i < num { + fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + } + if isRel && (fi.mi.isThrough == false || num != i) { - if fi.null { + if fi.null || t.skipEnd { inner = false } - if num == i && t.skipEnd { - } else { + if t.skipEnd && okN || !t.skipEnd { + if t.skipEnd && okN && fiN.pk { + goto loopEnd + } + jt, _ := t.add(names, mmi, fi, inner) jt.jtl = jtl jtl = jt @@ -231,34 +255,40 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string } - if num == i { - if i == 0 || jtl == nil { - index = "T0" - } else { + if num != i { + continue + } + + loopEnd: + + if i == 0 || jtl == nil { + index = "T0" + } else { + index = jtl.index + } + + info = fi + + if jtl == nil { + name = fi.name + } else { + name = jtl.name + ExprSep + fi.name + } + + switch { + case fi.rel: + + case fi.reverse: + switch fi.reverseFieldInfo.fieldType { + case RelOneToOne, RelForeignKey: index = jtl.index - } - - info = fi - - if jtl == nil { - name = fi.name - } else { - name = jtl.name + ExprSep + fi.name - } - - switch { - case fi.rel: - - case fi.reverse: - switch fi.reverseFieldInfo.fieldType { - case RelOneToOne, RelForeignKey: - index = jtl.index - info = fi.reverseFieldInfo.mi.fields.pk - name = info.name - } + info = fi.reverseFieldInfo.mi.fields.pk + name = info.name } } + break loopFor + } else { index = "" name = "" diff --git a/orm/orm_test.go b/orm/orm_test.go index bd4b6972..f5101811 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1561,6 +1561,32 @@ func TestDelete(t *testing.T) { num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + fmt.Println("...") + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) } func TestTransaction(t *testing.T) {