From e10ed3d0b9c63ce005adebd8afd911239b0261e2 Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Thu, 10 Nov 2022 16:47:41 -0800 Subject: [PATCH 1/7] feat: Support "limit" in count query. --- google/cloud/datastore/aggregation.py | 11 ++++++++--- tests/system/test_aggregation_query.py | 17 +++++++++++++++++ tests/unit/test_aggregation.py | 18 +++++++++++------- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index bb75d94e..e4090f18 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -58,8 +58,9 @@ class CountAggregation(BaseAggregation): """ - def __init__(self, alias=None): + def __init__(self, alias=None, limit=None): self.alias = alias + self.limit = limit def _to_pb(self): """ @@ -67,6 +68,7 @@ def _to_pb(self): """ aggregation_pb = query_pb2.AggregationQuery.Aggregation() aggregation_pb.count = query_pb2.AggregationQuery.Aggregation.Count() + aggregation_pb.count.up_to = self.limit aggregation_pb.alias = self.alias return aggregation_pb @@ -143,14 +145,17 @@ def _to_pb(self): pb.aggregations.append(aggregation_pb) return pb - def count(self, alias=None): + def count(self, alias=None, limit=None): """ Adds a count over the nested query :type alias: str :param alias: (Optional) The alias for the count + + :type limit: int + :param limit: (Optional) The limit for the count """ - count_aggregation = CountAggregation(alias=alias) + count_aggregation = CountAggregation(alias=alias, limit=limit) self._aggregations.append(count_aggregation) return self diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index 3e5120da..f7769388 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -93,6 +93,23 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query): assert r.value > 0 +def test_aggregation_query_with_limit(aggregation_query_client, nested_query): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count() # count without limit + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.value > 1 + + aggregation_query.count(limit=1) # count with limit = 1 + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.value == 1 + + def test_aggregation_query_multiple_aggregations( aggregation_query_client, nested_query ): diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index 8b28a908..a99f3083 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -25,10 +25,12 @@ def test_count_aggregation_to_pb(): from google.cloud.datastore_v1.types import query as query_pb2 - count_aggregation = CountAggregation(alias="total") + count_aggregation = CountAggregation(alias="total", limit=5) expected_aggregation_query_pb = query_pb2.AggregationQuery.Aggregation() - expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count() + expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count( + up_to=5 + ) expected_aggregation_query_pb.alias = count_aggregation.alias assert count_aggregation._to_pb() == expected_aggregation_query_pb @@ -54,11 +56,11 @@ def test_pb_over_query_with_count(client): query = _make_query(client) aggregation_query = _make_aggregation_query(client=client, query=query) - aggregation_query.count(alias="total") + aggregation_query.count(alias="total", limit=5) pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) assert len(pb.aggregations) == 1 - assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() + assert pb.aggregations[0] == CountAggregation(alias="total", limit=5)._to_pb() def test_pb_over_query_with_add_aggregation(client): @@ -67,11 +69,11 @@ def test_pb_over_query_with_add_aggregation(client): query = _make_query(client) aggregation_query = _make_aggregation_query(client=client, query=query) - aggregation_query.add_aggregation(CountAggregation(alias="total")) + aggregation_query.add_aggregation(CountAggregation(alias="total", limit=5)) pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) assert len(pb.aggregations) == 1 - assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() + assert pb.aggregations[0] == CountAggregation(alias="total", limit=5)._to_pb() def test_pb_over_query_with_add_aggregations(client): @@ -80,6 +82,7 @@ def test_pb_over_query_with_add_aggregations(client): aggregations = [ CountAggregation(alias="total"), CountAggregation(alias="all"), + CountAggregation(limit=5), ] query = _make_query(client) @@ -88,9 +91,10 @@ def test_pb_over_query_with_add_aggregations(client): aggregation_query.add_aggregations(aggregations) pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) - assert len(pb.aggregations) == 2 + assert len(pb.aggregations) == 3 assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() assert pb.aggregations[1] == CountAggregation(alias="all")._to_pb() + assert pb.aggregations[2] == CountAggregation(limit=5)._to_pb() def test_query_fetch_defaults_w_client_attr(client): From 1caa958bfa6668585ea7fb94dc91382406427b44 Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Thu, 10 Nov 2022 18:23:43 -0800 Subject: [PATCH 2/7] Only add limit if exists --- google/cloud/datastore/aggregation.py | 6 ++++-- tests/system/test_aggregation_query.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index e4090f18..400c93e4 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -68,7 +68,8 @@ def _to_pb(self): """ aggregation_pb = query_pb2.AggregationQuery.Aggregation() aggregation_pb.count = query_pb2.AggregationQuery.Aggregation.Count() - aggregation_pb.count.up_to = self.limit + if self.limit is not None and self.limit > 0: + aggregation_pb.count.up_to = self.limit aggregation_pb.alias = self.alias return aggregation_pb @@ -209,7 +210,7 @@ def fetch( >>> client.put_multi([andy, sally, bobby]) >>> query = client.query(kind='Andy') >>> aggregation_query = client.aggregation_query(query) - >>> result = aggregation_query.count(alias="total").fetch() + >>> result = aggregation_query.count(alias="total", limit=5).fetch() >>> result @@ -363,6 +364,7 @@ def _next_page(self): return None query_pb = self._build_protobuf() + breakpoint() transaction = self.client.current_transaction if transaction is None: transaction_id = None diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index f7769388..475d6b8b 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -97,11 +97,12 @@ def test_aggregation_query_with_limit(aggregation_query_client, nested_query): query = nested_query aggregation_query = aggregation_query_client.aggregation_query(query) - aggregation_query.count() # count without limit + aggregation_query.count(alias="total") result = _do_fetch(aggregation_query) assert len(result) == 1 for r in result[0]: - assert r.value > 1 + assert r.alias == "total" + assert r.value > 0 aggregation_query.count(limit=1) # count with limit = 1 result = _do_fetch(aggregation_query) From 2d833ae650e8680fce59d8db752d352618b1676f Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Thu, 10 Nov 2022 18:30:40 -0800 Subject: [PATCH 3/7] remove breakpoint --- google/cloud/datastore/aggregation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index 400c93e4..a7bf7bb5 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -364,7 +364,6 @@ def _next_page(self): return None query_pb = self._build_protobuf() - breakpoint() transaction = self.client.current_transaction if transaction is None: transaction_id = None From 2417d5765f3126bb0a3efeb2161624722987b5b5 Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Thu, 10 Nov 2022 18:53:07 -0800 Subject: [PATCH 4/7] Fix system test for count with limit --- tests/system/test_aggregation_query.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index 475d6b8b..fb6ee559 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -104,11 +104,13 @@ def test_aggregation_query_with_limit(aggregation_query_client, nested_query): assert r.alias == "total" assert r.value > 0 - aggregation_query.count(limit=1) # count with limit = 1 + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total_up_to", limit=2) # count with limit = 1 result = _do_fetch(aggregation_query) assert len(result) == 1 for r in result[0]: - assert r.value == 1 + assert r.alias == "total_up_to" + assert r.value == 2 def test_aggregation_query_multiple_aggregations( From 88f68ece8638badb068d698c2560e0614353ec6d Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Mon, 14 Nov 2022 14:37:55 -0800 Subject: [PATCH 5/7] Add test for count without limit, and with limit --- tests/system/test_aggregation_query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index fb6ee559..020cbe27 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -98,14 +98,14 @@ def test_aggregation_query_with_limit(aggregation_query_client, nested_query): aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total") - result = _do_fetch(aggregation_query) + result = _do_fetch(aggregation_query) # count without limit assert len(result) == 1 for r in result[0]: assert r.alias == "total" - assert r.value > 0 + assert r.value == 8 aggregation_query = aggregation_query_client.aggregation_query(query) - aggregation_query.count(alias="total_up_to", limit=2) # count with limit = 1 + aggregation_query.count(alias="total_up_to", limit=2) # count with limit = 2 result = _do_fetch(aggregation_query) assert len(result) == 1 for r in result[0]: From 3d71cf8b2b6ee7060ff5e4e4a071d666b0b79237 Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Tue, 15 Nov 2022 12:28:47 -0800 Subject: [PATCH 6/7] Move the limit to aggregation_query.fetch --- google/cloud/datastore/aggregation.py | 21 +++++++++++---------- tests/system/test_aggregation_query.py | 4 ++-- tests/unit/test_aggregation.py | 18 +++++++----------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index a7bf7bb5..6d68246c 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -58,9 +58,8 @@ class CountAggregation(BaseAggregation): """ - def __init__(self, alias=None, limit=None): + def __init__(self, alias=None): self.alias = alias - self.limit = limit def _to_pb(self): """ @@ -68,8 +67,6 @@ def _to_pb(self): """ aggregation_pb = query_pb2.AggregationQuery.Aggregation() aggregation_pb.count = query_pb2.AggregationQuery.Aggregation.Count() - if self.limit is not None and self.limit > 0: - aggregation_pb.count.up_to = self.limit aggregation_pb.alias = self.alias return aggregation_pb @@ -146,17 +143,14 @@ def _to_pb(self): pb.aggregations.append(aggregation_pb) return pb - def count(self, alias=None, limit=None): + def count(self, alias=None): """ Adds a count over the nested query :type alias: str :param alias: (Optional) The alias for the count - - :type limit: int - :param limit: (Optional) The limit for the count """ - count_aggregation = CountAggregation(alias=alias, limit=limit) + count_aggregation = CountAggregation(alias=alias) self._aggregations.append(count_aggregation) return self @@ -180,6 +174,7 @@ def add_aggregations(self, aggregations): def fetch( self, client=None, + limit=None, eventual=False, retry=None, timeout=None, @@ -210,7 +205,7 @@ def fetch( >>> client.put_multi([andy, sally, bobby]) >>> query = client.query(kind='Andy') >>> aggregation_query = client.aggregation_query(query) - >>> result = aggregation_query.count(alias="total", limit=5).fetch() + >>> result = aggregation_query.count(alias="total").fetch(fetch=5) >>> result @@ -254,6 +249,7 @@ def fetch( return AggregationResultIterator( self, client, + limit=limit, eventual=eventual, retry=retry, timeout=timeout, @@ -299,6 +295,7 @@ def __init__( self, aggregation_query, client, + limit=None, eventual=False, retry=None, timeout=None, @@ -314,6 +311,7 @@ def __init__( self._retry = retry self._timeout = timeout self._read_time = read_time + self._limit = limit # The attributes below will change over the life of the iterator. self._more_results = True @@ -328,6 +326,9 @@ def _build_protobuf(self): state of the iterator. """ pb = self._aggregation_query._to_pb() + if self._limit is not None and self._limit > 0: + for aggregation in pb.aggregations: + aggregation.count.up_to = self._limit return pb def _process_query_results(self, response_pb): diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index 020cbe27..b912e96b 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -105,8 +105,8 @@ def test_aggregation_query_with_limit(aggregation_query_client, nested_query): assert r.value == 8 aggregation_query = aggregation_query_client.aggregation_query(query) - aggregation_query.count(alias="total_up_to", limit=2) # count with limit = 2 - result = _do_fetch(aggregation_query) + aggregation_query.count(alias="total_up_to") + result = _do_fetch(aggregation_query, limit=2) # count with limit = 2 assert len(result) == 1 for r in result[0]: assert r.alias == "total_up_to" diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index a99f3083..8b28a908 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -25,12 +25,10 @@ def test_count_aggregation_to_pb(): from google.cloud.datastore_v1.types import query as query_pb2 - count_aggregation = CountAggregation(alias="total", limit=5) + count_aggregation = CountAggregation(alias="total") expected_aggregation_query_pb = query_pb2.AggregationQuery.Aggregation() - expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count( - up_to=5 - ) + expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count() expected_aggregation_query_pb.alias = count_aggregation.alias assert count_aggregation._to_pb() == expected_aggregation_query_pb @@ -56,11 +54,11 @@ def test_pb_over_query_with_count(client): query = _make_query(client) aggregation_query = _make_aggregation_query(client=client, query=query) - aggregation_query.count(alias="total", limit=5) + aggregation_query.count(alias="total") pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) assert len(pb.aggregations) == 1 - assert pb.aggregations[0] == CountAggregation(alias="total", limit=5)._to_pb() + assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() def test_pb_over_query_with_add_aggregation(client): @@ -69,11 +67,11 @@ def test_pb_over_query_with_add_aggregation(client): query = _make_query(client) aggregation_query = _make_aggregation_query(client=client, query=query) - aggregation_query.add_aggregation(CountAggregation(alias="total", limit=5)) + aggregation_query.add_aggregation(CountAggregation(alias="total")) pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) assert len(pb.aggregations) == 1 - assert pb.aggregations[0] == CountAggregation(alias="total", limit=5)._to_pb() + assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() def test_pb_over_query_with_add_aggregations(client): @@ -82,7 +80,6 @@ def test_pb_over_query_with_add_aggregations(client): aggregations = [ CountAggregation(alias="total"), CountAggregation(alias="all"), - CountAggregation(limit=5), ] query = _make_query(client) @@ -91,10 +88,9 @@ def test_pb_over_query_with_add_aggregations(client): aggregation_query.add_aggregations(aggregations) pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) - assert len(pb.aggregations) == 3 + assert len(pb.aggregations) == 2 assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() assert pb.aggregations[1] == CountAggregation(alias="all")._to_pb() - assert pb.aggregations[2] == CountAggregation(limit=5)._to_pb() def test_query_fetch_defaults_w_client_attr(client): From c9430e0e72879653b95dcd92b0b7421411cdcada Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Tue, 15 Nov 2022 13:57:18 -0800 Subject: [PATCH 7/7] Add test coverage --- google/cloud/datastore/aggregation.py | 2 +- tests/unit/test_aggregation.py | 31 ++++++++++++++++++++++----- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index 6d68246c..24d2abcc 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -205,7 +205,7 @@ def fetch( >>> client.put_multi([andy, sally, bobby]) >>> query = client.query(kind='Andy') >>> aggregation_query = client.aggregation_query(query) - >>> result = aggregation_query.count(alias="total").fetch(fetch=5) + >>> result = aggregation_query.count(alias="total").fetch(limit=5) >>> result diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index 8b28a908..afa9dc53 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -127,6 +127,22 @@ def test_query_fetch_w_explicit_client_w_retry_w_timeout(client): assert iterator._timeout == timeout +def test_query_fetch_w_explicit_client_w_limit(client): + from google.cloud.datastore.aggregation import AggregationResultIterator + + other_client = _make_client() + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + limit = 2 + + iterator = aggregation_query.fetch(client=other_client, limit=limit) + + assert isinstance(iterator, AggregationResultIterator) + assert iterator._aggregation_query is aggregation_query + assert iterator.client is other_client + assert iterator._limit == limit + + def test_iterator_constructor_defaults(): query = object() client = object() @@ -149,12 +165,10 @@ def test_iterator_constructor_explicit(): aggregation_query = AggregationQuery(client=client, query=query) retry = mock.Mock() timeout = 100000 + limit = 2 iterator = _make_aggregation_iterator( - aggregation_query, - client, - retry=retry, - timeout=timeout, + aggregation_query, client, retry=retry, timeout=timeout, limit=limit ) assert not iterator._started @@ -165,6 +179,7 @@ def test_iterator_constructor_explicit(): assert iterator._more_results assert iterator._retry == retry assert iterator._timeout == timeout + assert iterator._limit == limit def test_iterator__build_protobuf_empty(): @@ -186,14 +201,20 @@ def test_iterator__build_protobuf_all_values(): client = _Client(None) query = _make_query(client) + alias = "total" + limit = 2 aggregation_query = AggregationQuery(client=client, query=query) + aggregation_query.count(alias) - iterator = _make_aggregation_iterator(aggregation_query, client) + iterator = _make_aggregation_iterator(aggregation_query, client, limit=limit) iterator.num_results = 4 pb = iterator._build_protobuf() expected_pb = query_pb2.AggregationQuery() expected_pb.nested_query = query_pb2.Query() + expected_count_pb = query_pb2.AggregationQuery.Aggregation(alias=alias) + expected_count_pb.count.up_to = limit + expected_pb.aggregations.append(expected_count_pb) assert pb == expected_pb