diff --git a/orm/db.go b/orm/db.go index 10f65fee..677a20ef 100644 --- a/orm/db.go +++ b/orm/db.go @@ -800,6 +800,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi tables.parseRelated(qs.related, qs.relDepth) where, args := tables.getCondSql(cond, false, tz) + groupBy := tables.getGroupSql(qs.groups) orderBy := tables.getOrderSql(qs.orders) limit := tables.getLimitSql(mi, offset, rlimit) join := tables.getJoinSql() @@ -812,7 +813,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } } - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) @@ -936,12 +937,13 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition tables.parseRelated(qs.related, qs.relDepth) where, args := tables.getCondSql(cond, false, tz) + groupBy := tables.getGroupSql(qs.groups) tables.getOrderSql(qs.orders) join := tables.getJoinSql() Q := d.ins.TableQuote() - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s", Q, mi.table, Q, join, where) + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) d.ins.ReplaceMarks(&query) @@ -1442,13 +1444,14 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond } where, args := tables.getCondSql(cond, false, tz) + groupBy := tables.getGroupSql(qs.groups) orderBy := tables.getOrderSql(qs.orders) limit := tables.getLimitSql(mi, qs.offset, qs.limit) join := tables.getJoinSql() sels := strings.Join(cols, ", ") - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) diff --git a/orm/db_tables.go b/orm/db_tables.go index a9aa10ab..3677932a 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -390,6 +390,30 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe return } +// generate group sql. +func (t *dbTables) getGroupSql(groups []string) (groupSql string) { + if len(groups) == 0 { + return + } + + Q := t.base.TableQuote() + + groupSqls := make([]string, 0, len(groups)) + for _, group := range groups { + exprs := strings.Split(group, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + } + + groupSql = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) + return +} + // generate order sql. func (t *dbTables) getOrderSql(orders []string) (orderSql string) { if len(orders) == 0 { diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 4f5d5485..26e82379 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -60,6 +60,7 @@ type querySet struct { relDepth int limit int64 offset int64 + groups []string orders []string orm *orm } @@ -105,6 +106,12 @@ func (o querySet) Offset(offset interface{}) QuerySeter { return &o } +// add GROUP expression. +func (o querySet) GroupBy(exprs ...string) QuerySeter { + o.groups = exprs + return &o +} + // add ORDER expression. // "column" means ASC, "-column" means DESC. func (o querySet) OrderBy(exprs ...string) QuerySeter { diff --git a/orm/orm_test.go b/orm/orm_test.go index e1c8e0f0..c223a463 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -845,6 +845,27 @@ func TestOffset(t *testing.T) { throwFail(t, AssertIs(num, 2)) } +func TestGroupBy(t *testing.T) { + var users []*User + var maps []Params + qs := dORM.QueryTable("user") + num, err := qs.GroupBy("is_staff").Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.GroupBy("is_staff").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.GroupBy("is_staff").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.GroupBy("profile__age").Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + func TestOrderBy(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() diff --git a/orm/types.go b/orm/types.go index c342e1c2..d308642b 100644 --- a/orm/types.go +++ b/orm/types.go @@ -67,6 +67,7 @@ type QuerySeter interface { SetCond(*Condition) QuerySeter Limit(interface{}, ...interface{}) QuerySeter Offset(interface{}) QuerySeter + GroupBy(...string) QuerySeter OrderBy(...string) QuerySeter RelatedSel(...interface{}) QuerySeter Count() (int64, error)