Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(firestore): adds COUNT aggregation query #6692

Merged
merged 13 commits into from Oct 7, 2022
50 changes: 50 additions & 0 deletions firestore/integration_test.go
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/api/option"
firestore "google.golang.org/genproto/googleapis/firestore/v1beta1"
"google.golang.org/genproto/googleapis/type/latlng"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -1773,3 +1774,52 @@ func TestIntegration_BulkWriter(t *testing.T) {
}
}
}

func TestIntegration_CountAggregationQuery(t *testing.T) {
docs := []*DocumentRef{
iColl.NewDoc(),
iColl.NewDoc(),
}

c := integrationClient(t)
ctx := context.Background()
bw := c.BulkWriter(ctx)
jobs := make([]*BulkWriterJob, 0)

// Populate the collection
f := integrationTestMap
for _, d := range docs {
j, err := bw.Create(d, f)
jobs = append(jobs, j)
if err != nil {
t.Fatal(err)
}
}
bw.End()

for _, j := range jobs {
_, err := j.Results()
if err != nil {
t.Fatal(err)
}
}

// [START firestore_count_query]
alias := "twos"
q := iColl.Where("str", "==", "two")
aq := q.NewAggregationQuery()
ar, err := aq.WithCount(alias).Get(ctx)
// [END firestore_count_query]
if err != nil {
t.Fatal(err)
}

count, ok := ar[alias]
if !ok {
t.Errorf("key %s not in response %v", alias, ar)
}
cv := count.(*firestore.Value)
if cv.GetIntegerValue() != 2 {
t.Errorf("COUNT aggregation query mismatch;\ngot: %d, want: %d", cv.GetIntegerValue(), 2)
}
}
21 changes: 21 additions & 0 deletions firestore/mock_test.go
Expand Up @@ -187,6 +187,27 @@ func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryS
return nil
}

func (s *mockServer) RunAggregationQuery(req *pb.RunAggregationQueryRequest, qs pb.Firestore_RunAggregationQueryServer) error {
res, err := s.popRPC(req)
if err != nil {
return err
}
responses := res.([]interface{})
for _, res := range responses {
switch res := res.(type) {
case *pb.RunAggregationQueryResponse:
if err := qs.Send(res); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return qs.Send(res)

With a return in each branch, you can also remove the return nil outside of the statement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compiler doesn't like it when I remove the return statement, even when I add a return nil after the if in the first branch.

return err
}
case error:
return res
default:
return fmt.Errorf("bad response type in RunAggregationQuery: %+v", res)
}
}
return nil
}

func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) {
res, err := s.popRPC(req)
if err != nil {
Expand Down
79 changes: 79 additions & 0 deletions firestore/query.go
Expand Up @@ -304,6 +304,14 @@ func (q Query) Deserialize(bytes []byte) (Query, error) {
return q.fromProto(&runQueryRequest)
}

// NewAggregationQuery returns an AggregationQuery with this query as its
// base query.
func (q *Query) NewAggregationQuery() *AggregationQuery {
return &AggregationQuery{
query: q,
}
}

// fromProto creates a new Query object from a RunQueryRequest. This can be used
// in combination with ToProto to serialize Query objects. This could be useful,
// for instance, if executing a query formed in one process in another.
Expand Down Expand Up @@ -1036,3 +1044,74 @@ func (it *btreeDocumentIterator) next() (*DocumentSnapshot, error) {
}

func (*btreeDocumentIterator) stop() {}

// AggregationQuery allows for generating aggregation results of an underlying
// basic query. A single AggregationQuery can contain multiple aggregations.
type AggregationQuery struct {
// aggregateQueries contains all of the queries for this request.
aggregateQueries []*pb.StructuredAggregationQuery_Aggregation
// query contains a reference pointer to the underlying structured query.
query *Query
}

// WithCount specifies that the aggregation query provide a count of results
// returned by the underlying Query.
func (a *AggregationQuery) WithCount(alias string) *AggregationQuery {
aq := &pb.StructuredAggregationQuery_Aggregation{
Alias: alias,
Operator: &pb.StructuredAggregationQuery_Aggregation_Count_{},
}

a.aggregateQueries = append(a.aggregateQueries, aq)

return a
}

// Get retrieves the aggregation query results from the service.
func (a *AggregationQuery) Get(ctx context.Context) (AggregationResult, error) {
enocom marked this conversation as resolved.
Show resolved Hide resolved

client := a.query.c.c
q, err := a.query.toProto()
if err != nil {
return nil, err
}

req := &pb.RunAggregationQueryRequest{
Parent: a.query.parentPath,
QueryType: &pb.RunAggregationQueryRequest_StructuredAggregationQuery{
StructuredAggregationQuery: &pb.StructuredAggregationQuery{
QueryType: &pb.StructuredAggregationQuery_StructuredQuery{
StructuredQuery: q,
},
Aggregations: a.aggregateQueries,
},
},
}
ctx = withResourceHeader(ctx, a.query.c.path())
stream, err := client.RunAggregationQuery(ctx, req)
if err != nil {
return nil, err
}

resp := make(AggregationResult)

for {
res, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}

f := res.Result.AggregateFields

for k, v := range f {
resp[k] = v
}
}
return resp, nil
}

// AggregationResult contains the results of an aggregation query.
type AggregationResult map[string]interface{}
32 changes: 32 additions & 0 deletions firestore/query_test.go
Expand Up @@ -923,3 +923,35 @@ func (b byQuery) Less(i, j int) bool {
}
return c < 0
}

func TestAggregationQuery(t *testing.T) {
ctx := context.Background()
c, srv, cleanup := newMock(t)
defer cleanup()

srv.addRPC(nil, []interface{}{
&pb.RunAggregationQueryResponse{
Result: &pb.AggregationResult{
AggregateFields: map[string]*pb.Value{
"testAlias": intval(1),
},
},
},
})

q := c.Collection("coll1").Where("f", "==", 2)
ar, err := q.NewAggregationQuery().WithCount("testAlias").Get(ctx)
if err != nil {
t.Fatal(err)
}

count, ok := ar["testAlias"]
if !ok {
t.Errorf("aggregation query key not found")
}

cv := count.(*pb.Value)
if cv.GetIntegerValue() != 1 {
t.Errorf("got: %v\nwant: %v\n; result: %v\n", cv.GetIntegerValue(), 1, count)
}
}