Skip to content

Commit

Permalink
feat(firestore): adds COUNT aggregation query (#6692)
Browse files Browse the repository at this point in the history
* feat(firestore): adds COUNT aggregation query
  • Loading branch information
Eric Schmidt committed Oct 7, 2022
1 parent 09e05f9 commit 31ac692
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 0 deletions.
50 changes: 50 additions & 0 deletions firestore/integration_test.go
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/api/option"
firestore "google.golang.org/genproto/googleapis/firestore/v1beta1"
"google.golang.org/genproto/googleapis/type/latlng"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -1773,3 +1774,52 @@ func TestIntegration_BulkWriter(t *testing.T) {
}
}
}

func TestIntegration_CountAggregationQuery(t *testing.T) {
docs := []*DocumentRef{
iColl.NewDoc(),
iColl.NewDoc(),
}

c := integrationClient(t)
ctx := context.Background()
bw := c.BulkWriter(ctx)
jobs := make([]*BulkWriterJob, 0)

// Populate the collection
f := integrationTestMap
for _, d := range docs {
j, err := bw.Create(d, f)
jobs = append(jobs, j)
if err != nil {
t.Fatal(err)
}
}
bw.End()

for _, j := range jobs {
_, err := j.Results()
if err != nil {
t.Fatal(err)
}
}

// [START firestore_count_query]
alias := "twos"
q := iColl.Where("str", "==", "two")
aq := q.NewAggregationQuery()
ar, err := aq.WithCount(alias).Get(ctx)
// [END firestore_count_query]
if err != nil {
t.Fatal(err)
}

count, ok := ar[alias]
if !ok {
t.Errorf("key %s not in response %v", alias, ar)
}
cv := count.(*firestore.Value)
if cv.GetIntegerValue() != 2 {
t.Errorf("COUNT aggregation query mismatch;\ngot: %d, want: %d", cv.GetIntegerValue(), 2)
}
}
21 changes: 21 additions & 0 deletions firestore/mock_test.go
Expand Up @@ -187,6 +187,27 @@ func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryS
return nil
}

func (s *mockServer) RunAggregationQuery(req *pb.RunAggregationQueryRequest, qs pb.Firestore_RunAggregationQueryServer) error {
res, err := s.popRPC(req)
if err != nil {
return err
}
responses := res.([]interface{})
for _, res := range responses {
switch res := res.(type) {
case *pb.RunAggregationQueryResponse:
if err := qs.Send(res); err != nil {
return err
}
case error:
return res
default:
return fmt.Errorf("bad response type in RunAggregationQuery: %+v", res)
}
}
return nil
}

func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) {
res, err := s.popRPC(req)
if err != nil {
Expand Down
79 changes: 79 additions & 0 deletions firestore/query.go
Expand Up @@ -304,6 +304,14 @@ func (q Query) Deserialize(bytes []byte) (Query, error) {
return q.fromProto(&runQueryRequest)
}

// NewAggregationQuery returns an AggregationQuery with this query as its
// base query.
func (q *Query) NewAggregationQuery() *AggregationQuery {
return &AggregationQuery{
query: q,
}
}

// fromProto creates a new Query object from a RunQueryRequest. This can be used
// in combination with ToProto to serialize Query objects. This could be useful,
// for instance, if executing a query formed in one process in another.
Expand Down Expand Up @@ -1036,3 +1044,74 @@ func (it *btreeDocumentIterator) next() (*DocumentSnapshot, error) {
}

func (*btreeDocumentIterator) stop() {}

// AggregationQuery allows for generating aggregation results of an underlying
// basic query. A single AggregationQuery can contain multiple aggregations.
type AggregationQuery struct {
// aggregateQueries contains all of the queries for this request.
aggregateQueries []*pb.StructuredAggregationQuery_Aggregation
// query contains a reference pointer to the underlying structured query.
query *Query
}

// WithCount specifies that the aggregation query provide a count of results
// returned by the underlying Query.
func (a *AggregationQuery) WithCount(alias string) *AggregationQuery {
aq := &pb.StructuredAggregationQuery_Aggregation{
Alias: alias,
Operator: &pb.StructuredAggregationQuery_Aggregation_Count_{},
}

a.aggregateQueries = append(a.aggregateQueries, aq)

return a
}

// Get retrieves the aggregation query results from the service.
func (a *AggregationQuery) Get(ctx context.Context) (AggregationResult, error) {

client := a.query.c.c
q, err := a.query.toProto()
if err != nil {
return nil, err
}

req := &pb.RunAggregationQueryRequest{
Parent: a.query.parentPath,
QueryType: &pb.RunAggregationQueryRequest_StructuredAggregationQuery{
StructuredAggregationQuery: &pb.StructuredAggregationQuery{
QueryType: &pb.StructuredAggregationQuery_StructuredQuery{
StructuredQuery: q,
},
Aggregations: a.aggregateQueries,
},
},
}
ctx = withResourceHeader(ctx, a.query.c.path())
stream, err := client.RunAggregationQuery(ctx, req)
if err != nil {
return nil, err
}

resp := make(AggregationResult)

for {
res, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}

f := res.Result.AggregateFields

for k, v := range f {
resp[k] = v
}
}
return resp, nil
}

// AggregationResult contains the results of an aggregation query.
type AggregationResult map[string]interface{}
32 changes: 32 additions & 0 deletions firestore/query_test.go
Expand Up @@ -923,3 +923,35 @@ func (b byQuery) Less(i, j int) bool {
}
return c < 0
}

func TestAggregationQuery(t *testing.T) {
ctx := context.Background()
c, srv, cleanup := newMock(t)
defer cleanup()

srv.addRPC(nil, []interface{}{
&pb.RunAggregationQueryResponse{
Result: &pb.AggregationResult{
AggregateFields: map[string]*pb.Value{
"testAlias": intval(1),
},
},
},
})

q := c.Collection("coll1").Where("f", "==", 2)
ar, err := q.NewAggregationQuery().WithCount("testAlias").Get(ctx)
if err != nil {
t.Fatal(err)
}

count, ok := ar["testAlias"]
if !ok {
t.Errorf("aggregation query key not found")
}

cv := count.(*pb.Value)
if cv.GetIntegerValue() != 1 {
t.Errorf("got: %v\nwant: %v\n; result: %v\n", cv.GetIntegerValue(), 1, count)
}
}

0 comments on commit 31ac692

Please sign in to comment.