Skip to content

Commit

Permalink
feat(datastore): adds COUNT aggregation query (#6714)
Browse files Browse the repository at this point in the history
* feat(datastore): adds COUNT aggregation query
  • Loading branch information
Eric Schmidt committed Oct 26, 2022
1 parent 78621ce commit 27363ca
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 35 deletions.
2 changes: 1 addition & 1 deletion datastore/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/google/go-cmp v0.5.9
github.com/googleapis/gax-go/v2 v2.6.0
google.golang.org/api v0.99.0
google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e
google.golang.org/genproto v0.0.0-20221014213838-99cd37c6964a
google.golang.org/grpc v1.50.1
google.golang.org/protobuf v1.28.1
)
Expand Down
4 changes: 2 additions & 2 deletions datastore/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e h1:halCgTFuLWDRD61piiNSxPsARANGD3Xl16hPrLgLiIg=
google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e/go.mod h1:3526vdqwhZAwq4wsRUaVG555sVgsNmIjRtO7t/JH29U=
google.golang.org/genproto v0.0.0-20221014213838-99cd37c6964a h1:GH6UPn3ixhWcKDhpnEC55S75cerLPdpp3hrhfKYjZgw=
google.golang.org/genproto v0.0.0-20221014213838-99cd37c6964a/go.mod h1:1vXfmgAz9N9Jx0QA82PqRVauvCz1SGSz739p0f183jM=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
Expand Down
165 changes: 137 additions & 28 deletions datastore/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,28 @@ func (q *Query) End(c Cursor) *Query {
return q
}

// toProto converts the query to a protocol buffer.
func (q *Query) toProto(req *pb.RunQueryRequest) error {
// toRunQueryRequest converts the query to a protocol buffer.
func (q *Query) toRunQueryRequest(req *pb.RunQueryRequest) error {
dst, err := q.toProto()
if err != nil {
return err
}

req.ReadOptions, err = parseReadOptions(q)
if err != nil {
return err
}

req.QueryType = &pb.RunQueryRequest_Query{Query: dst}
return nil
}

func (q *Query) toProto() (*pb.Query, error) {
if len(q.projection) != 0 && q.keysOnly {
return errors.New("datastore: query cannot both project and be keys-only")
return nil, errors.New("datastore: query cannot both project and be keys-only")
}
if len(q.distinctOn) != 0 && q.distinct {
return errors.New("datastore: query cannot be both distinct and distinct-on")
return nil, errors.New("datastore: query cannot be both distinct and distinct-on")
}
dst := &pb.Query{}
if q.kind != "" {
Expand All @@ -394,19 +409,18 @@ func (q *Query) toProto(req *pb.RunQueryRequest) error {
if q.keysOnly {
dst.Projection = []*pb.Projection{{Property: &pb.PropertyReference{Name: keyFieldName}}}
}

var filters []*pb.Filter
for _, qf := range q.filter {
if qf.FieldName == "" {
return errors.New("datastore: empty query filter field name")
return nil, errors.New("datastore: empty query filter field name")
}
v, err := interfaceToProto(reflect.ValueOf(qf.Value).Interface(), false)
if err != nil {
return fmt.Errorf("datastore: bad query filter value type: %v", err)
return nil, fmt.Errorf("datastore: bad query filter value type: %v", err)
}
op, ok := operatorToProto[qf.Op]
if !ok {
return errors.New("datastore: unknown query filter operator")
return nil, errors.New("datastore: unknown query filter operator")
}
xf := &pb.PropertyFilter{
Op: op,
Expand Down Expand Up @@ -438,7 +452,7 @@ func (q *Query) toProto(req *pb.RunQueryRequest) error {

for _, qo := range q.order {
if qo.FieldName == "" {
return errors.New("datastore: empty query order field name")
return nil, errors.New("datastore: empty query order field name")
}
xo := &pb.PropertyOrder{
Property: &pb.PropertyReference{Name: qo.FieldName},
Expand All @@ -453,24 +467,7 @@ func (q *Query) toProto(req *pb.RunQueryRequest) error {
dst.StartCursor = q.start
dst.EndCursor = q.end

if t := q.trans; t != nil {
if t.id == nil {
return errExpiredTransaction
}
if q.eventual {
return errors.New("datastore: cannot use EventualConsistency query in a transaction")
}
req.ReadOptions = &pb.ReadOptions{
ConsistencyType: &pb.ReadOptions_Transaction{Transaction: t.id},
}
}

if q.eventual {
req.ReadOptions = &pb.ReadOptions{ConsistencyType: &pb.ReadOptions_ReadConsistency_{ReadConsistency: pb.ReadOptions_EVENTUAL}}
}

req.QueryType = &pb.RunQueryRequest_Query{Query: dst}
return nil
return dst, nil
}

// Count returns the number of results for the given query.
Expand All @@ -479,6 +476,8 @@ func (q *Query) toProto(req *pb.RunQueryRequest) error {
// the sum of the query's offset and limit. Unless the result count is
// expected to be small, it is best to specify a limit; otherwise Count will
// continue until it finishes counting or the provided context expires.
//
// Deprecated. Use Client.RunAggregationQuery() instead.
func (c *Client) Count(ctx context.Context, q *Query) (n int, err error) {
ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.Count")
defer func() { trace.EndSpan(ctx, err) }()
Expand Down Expand Up @@ -620,12 +619,85 @@ func (c *Client) Run(ctx context.Context, q *Query) *Iterator {
}
}

if err := q.toProto(t.req); err != nil {
if err := q.toRunQueryRequest(t.req); err != nil {
t.err = err
}
return t
}

// RunAggregationQuery gets aggregation query (e.g. COUNT) results from the service.
func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) (AggregationResult, error) {
if len(aq.aggregationQueries) == 0 {
return nil, errors.New("datastore: aggregation query must contain one or more operators (e.g. count)")
}

q, err := aq.query.toProto()
if err != nil {
return nil, err
}

req := &pb.RunAggregationQueryRequest{
ProjectId: c.dataset,
QueryType: &pb.RunAggregationQueryRequest_AggregationQuery{
AggregationQuery: &pb.AggregationQuery{
QueryType: &pb.AggregationQuery_NestedQuery{
NestedQuery: q,
},
Aggregations: aq.aggregationQueries,
},
},
}

if aq.query.namespace != "" {
req.PartitionId = &pb.PartitionId{
NamespaceId: aq.query.namespace,
}
}

// Parse the read options.
req.ReadOptions, err = parseReadOptions(aq.query)
if err != nil {
return nil, err
}

res, err := c.client.RunAggregationQuery(ctx, req)
if err != nil {
return nil, err
}

ar := make(AggregationResult)

// TODO(developer): change batch parsing logic if other aggregations are supported.
for _, a := range res.Batch.AggregationResults {
for k, v := range a.AggregateProperties {
ar[k] = v
}
}

return ar, nil
}

// parseReadOptions translates Query read options into protobuf format.
func parseReadOptions(q *Query) (*pb.ReadOptions, error) {
if t := q.trans; t != nil {
if t.id == nil {
return nil, errExpiredTransaction
}
if q.eventual {
return nil, errors.New("datastore: cannot use EventualConsistency query in a transaction")
}
return &pb.ReadOptions{
ConsistencyType: &pb.ReadOptions_Transaction{Transaction: t.id},
}, nil
}

if q.eventual {
return &pb.ReadOptions{ConsistencyType: &pb.ReadOptions_ReadConsistency_{ReadConsistency: pb.ReadOptions_EVENTUAL}}, nil
}

return nil, nil
}

// Iterator is the result of running a query.
//
// It is not safe for concurrent use.
Expand Down Expand Up @@ -819,3 +891,40 @@ func DecodeCursor(s string) (Cursor, error) {
}
return Cursor{b}, nil
}

// NewAggregationQuery returns an AggregationQuery with this query as its
// base query.
func (q *Query) NewAggregationQuery() *AggregationQuery {
q.eventual = true
return &AggregationQuery{
query: q,
aggregationQueries: make([]*pb.AggregationQuery_Aggregation, 0),
}
}

// AggregationQuery allows for generating aggregation results of an underlying
// basic query. A single AggregationQuery can contain multiple aggregations.
type AggregationQuery struct {
query *Query // query contains a reference pointer to the underlying structured query.
aggregationQueries []*pb.AggregationQuery_Aggregation // aggregateQueries contains all of the queries for this request.
}

// WithCount specifies that the aggregation query provide a count of results
// returned by the underlying Query.
func (aq *AggregationQuery) WithCount(alias string) *AggregationQuery {
if alias == "" {
alias = fmt.Sprintf("%s_%s", "count", aq.query.kind)
}

aqpb := &pb.AggregationQuery_Aggregation{
Alias: alias,
Operator: &pb.AggregationQuery_Aggregation_Count_{},
}

aq.aggregationQueries = append(aq.aggregationQueries, aqpb)

return aq
}

// AggregationResult contains the results of an aggregation query.
type AggregationResult map[string]interface{}
88 changes: 84 additions & 4 deletions datastore/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ var (
},
},
}
countAlias = "count"
)

type fakeClient struct {
pb.DatastoreClient
queryFn func(*pb.RunQueryRequest) (*pb.RunQueryResponse, error)
commitFn func(*pb.CommitRequest) (*pb.CommitResponse, error)
queryFn func(*pb.RunQueryRequest) (*pb.RunQueryResponse, error)
commitFn func(*pb.CommitRequest) (*pb.CommitResponse, error)
aggQueryFn func(*pb.RunAggregationQueryRequest) (*pb.RunAggregationQueryResponse, error)
}

func (c *fakeClient) RunQuery(_ context.Context, req *pb.RunQueryRequest, _ ...grpc.CallOption) (*pb.RunQueryResponse, error) {
Expand All @@ -66,6 +68,10 @@ func (c *fakeClient) Commit(_ context.Context, req *pb.CommitRequest, _ ...grpc.
return c.commitFn(req)
}

func (c *fakeClient) RunAggregationQuery(_ context.Context, req *pb.RunAggregationQueryRequest, _ ...grpc.CallOption) (*pb.RunAggregationQueryResponse, error) {
return c.aggQueryFn(req)
}

func fakeRunQuery(in *pb.RunQueryRequest) (*pb.RunQueryResponse, error) {
expectedIn := &pb.RunQueryRequest{
QueryType: &pb.RunQueryRequest_Query{Query: &pb.Query{
Expand Down Expand Up @@ -103,6 +109,47 @@ func fakeRunQuery(in *pb.RunQueryRequest) (*pb.RunQueryResponse, error) {
}, nil
}

func fakeRunAggregationQuery(req *pb.RunAggregationQueryRequest) (*pb.RunAggregationQueryResponse, error) {
expectedIn := &pb.RunAggregationQueryRequest{
QueryType: &pb.RunAggregationQueryRequest_AggregationQuery{
AggregationQuery: &pb.AggregationQuery{
QueryType: &pb.AggregationQuery_NestedQuery{
NestedQuery: &pb.Query{
Kind: []*pb.KindExpression{{Name: "Gopher"}},
},
},
Aggregations: []*pb.AggregationQuery_Aggregation{
{
Operator: &pb.AggregationQuery_Aggregation_Count_{},
Alias: countAlias,
},
},
},
},
ReadOptions: &pb.ReadOptions{
ConsistencyType: &pb.ReadOptions_ReadConsistency_{
ReadConsistency: pb.ReadOptions_EVENTUAL,
},
},
}
if !proto.Equal(req, expectedIn) {
return nil, fmt.Errorf("unsupported argument: got %v want %v", req, expectedIn)
}
return &pb.RunAggregationQueryResponse{
Batch: &pb.AggregationResultBatch{
AggregationResults: []*pb.AggregationResult{
{
AggregateProperties: map[string]*pb.Value{
"count": {
ValueType: &pb.Value_IntegerValue{IntegerValue: 1},
},
},
},
},
},
}, nil
}

type StructThatImplementsPLS struct{}

func (StructThatImplementsPLS) Load(p []Property) error { return nil }
Expand Down Expand Up @@ -600,7 +647,7 @@ func TestReadOptions(t *testing.T) {
},
} {
req := &pb.RunQueryRequest{}
if err := test.q.toProto(req); err != nil {
if err := test.q.toRunQueryRequest(req); err != nil {
t.Fatalf("%+v: got %v, want no error", test.q, err)
}
if got := req.ReadOptions; !proto.Equal(got, test.want) {
Expand All @@ -613,7 +660,7 @@ func TestReadOptions(t *testing.T) {
NewQuery("").Transaction(&Transaction{id: tid}).EventualConsistency(),
} {
req := &pb.RunQueryRequest{}
if err := q.toProto(req); err == nil {
if err := q.toRunQueryRequest(req); err == nil {
t.Errorf("%+v: got nil, wanted error", q)
}
}
Expand Down Expand Up @@ -641,3 +688,36 @@ func TestInvalidFilters(t *testing.T) {
}
}
}

func TestAggregationQuery(t *testing.T) {
client := &Client{
client: &fakeClient{
aggQueryFn: func(req *pb.RunAggregationQueryRequest) (*pb.RunAggregationQueryResponse, error) {
return fakeRunAggregationQuery(req)
},
},
}

q := NewQuery("Gopher")
aq := q.NewAggregationQuery()
aq.WithCount(countAlias)

res, err := client.RunAggregationQuery(context.Background(), aq)
if err != nil {
t.Fatal(err)
}

count, ok := res[countAlias]
if !ok {
t.Errorf("%s key does not exist in return aggregation result", countAlias)
}

want := &pb.Value{
ValueType: &pb.Value_IntegerValue{IntegerValue: 1},
}

cv := count.(*pb.Value)
if !proto.Equal(want, cv) {
t.Errorf("want: %v\ngot: %v\n", want, cv)
}
}

0 comments on commit 27363ca

Please sign in to comment.