Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support "limit" in count query. #384

Merged
merged 9 commits into from Nov 30, 2022
14 changes: 10 additions & 4 deletions google/cloud/datastore/aggregation.py
Expand Up @@ -58,15 +58,18 @@ class CountAggregation(BaseAggregation):

"""

def __init__(self, alias=None):
def __init__(self, alias=None, limit=None):
self.alias = alias
self.limit = limit

def _to_pb(self):
"""
Convert this instance to the protobuf representation
"""
aggregation_pb = query_pb2.AggregationQuery.Aggregation()
aggregation_pb.count = query_pb2.AggregationQuery.Aggregation.Count()
if self.limit is not None and self.limit > 0:
aggregation_pb.count.up_to = self.limit
aggregation_pb.alias = self.alias
return aggregation_pb

Expand Down Expand Up @@ -143,14 +146,17 @@ def _to_pb(self):
pb.aggregations.append(aggregation_pb)
return pb

def count(self, alias=None):
def count(self, alias=None, limit=None):
Mariatta marked this conversation as resolved.
Show resolved Hide resolved
"""
Adds a count over the nested query

:type alias: str
:param alias: (Optional) The alias for the count

:type limit: int
:param limit: (Optional) The limit for the count
"""
count_aggregation = CountAggregation(alias=alias)
count_aggregation = CountAggregation(alias=alias, limit=limit)
self._aggregations.append(count_aggregation)
return self

Expand Down Expand Up @@ -204,7 +210,7 @@ def fetch(
>>> client.put_multi([andy, sally, bobby])
>>> query = client.query(kind='Andy')
>>> aggregation_query = client.aggregation_query(query)
>>> result = aggregation_query.count(alias="total").fetch()
>>> result = aggregation_query.count(alias="total", limit=5).fetch()
>>> result
<google.cloud.datastore.aggregation.AggregationResultIterator object at ...>

Expand Down
20 changes: 20 additions & 0 deletions tests/system/test_aggregation_query.py
Expand Up @@ -93,6 +93,26 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
assert r.value > 0


def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "total"
assert r.value > 0

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total_up_to", limit=2) # count with limit = 1
Mariatta marked this conversation as resolved.
Show resolved Hide resolved
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "total_up_to"
assert r.value == 2


def test_aggregation_query_multiple_aggregations(
aggregation_query_client, nested_query
):
Expand Down
18 changes: 11 additions & 7 deletions tests/unit/test_aggregation.py
Expand Up @@ -25,10 +25,12 @@
def test_count_aggregation_to_pb():
from google.cloud.datastore_v1.types import query as query_pb2

count_aggregation = CountAggregation(alias="total")
count_aggregation = CountAggregation(alias="total", limit=5)

expected_aggregation_query_pb = query_pb2.AggregationQuery.Aggregation()
expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count()
expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count(
up_to=5
)
expected_aggregation_query_pb.alias = count_aggregation.alias
assert count_aggregation._to_pb() == expected_aggregation_query_pb

Expand All @@ -54,11 +56,11 @@ def test_pb_over_query_with_count(client):
query = _make_query(client)
aggregation_query = _make_aggregation_query(client=client, query=query)

aggregation_query.count(alias="total")
aggregation_query.count(alias="total", limit=5)
pb = aggregation_query._to_pb()
assert pb.nested_query == _pb_from_query(query)
assert len(pb.aggregations) == 1
assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb()
assert pb.aggregations[0] == CountAggregation(alias="total", limit=5)._to_pb()


def test_pb_over_query_with_add_aggregation(client):
Expand All @@ -67,11 +69,11 @@ def test_pb_over_query_with_add_aggregation(client):
query = _make_query(client)
aggregation_query = _make_aggregation_query(client=client, query=query)

aggregation_query.add_aggregation(CountAggregation(alias="total"))
aggregation_query.add_aggregation(CountAggregation(alias="total", limit=5))
pb = aggregation_query._to_pb()
assert pb.nested_query == _pb_from_query(query)
assert len(pb.aggregations) == 1
assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb()
assert pb.aggregations[0] == CountAggregation(alias="total", limit=5)._to_pb()


def test_pb_over_query_with_add_aggregations(client):
Expand All @@ -80,6 +82,7 @@ def test_pb_over_query_with_add_aggregations(client):
aggregations = [
CountAggregation(alias="total"),
CountAggregation(alias="all"),
CountAggregation(limit=5),
]

query = _make_query(client)
Expand All @@ -88,9 +91,10 @@ def test_pb_over_query_with_add_aggregations(client):
aggregation_query.add_aggregations(aggregations)
pb = aggregation_query._to_pb()
assert pb.nested_query == _pb_from_query(query)
assert len(pb.aggregations) == 2
assert len(pb.aggregations) == 3
assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb()
assert pb.aggregations[1] == CountAggregation(alias="all")._to_pb()
assert pb.aggregations[2] == CountAggregation(limit=5)._to_pb()


def test_query_fetch_defaults_w_client_attr(client):
Expand Down