From cbbd56541e6823c797694088c8e8a5a7babfff46 Mon Sep 17 00:00:00 2001 From: Olivier Bellone Date: Fri, 11 Oct 2019 17:11:46 -0700 Subject: [PATCH] Add support for reverse pagination --- stripe/api_resources/list_object.py | 131 ++++++++++++++---- tests/api_resources/test_list_object.py | 177 ++++++++++++++++++++++-- 2 files changed, 274 insertions(+), 34 deletions(-) diff --git a/stripe/api_resources/list_object.py b/stripe/api_resources/list_object.py index 7efb1b130..af2af033f 100644 --- a/stripe/api_resources/list_object.py +++ b/stripe/api_resources/list_object.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, division, print_function -from stripe import api_requestor, util +from stripe import api_requestor, six, util from stripe.stripe_object import StripeObject from stripe.six.moves.urllib.parse import quote_plus @@ -23,22 +23,6 @@ def list( stripe_object._retrieve_params = params return stripe_object - def auto_paging_iter(self): - page = self - params = dict(self._retrieve_params) - - while True: - item_id = None - for item in page: - item_id = item.get("id", None) - yield item - - if not getattr(page, "has_more", False) or item_id is None: - return - - params["starting_after"] = item_id - page = self.list(**params) - def create( self, api_key=None, @@ -75,16 +59,10 @@ def retrieve( **params ) - def __iter__(self): - return getattr(self, "data", []).__iter__() - - def __len__(self): - return getattr(self, "data", []).__len__() - def _request( self, - method, - url, + method_, + url_, api_key=None, idempotency_key=None, stripe_version=None, @@ -99,8 +77,109 @@ def _request( api_key, api_version=stripe_version, account=stripe_account ) headers = util.populate_headers(idempotency_key) - response, api_key = requestor.request(method, url, params, headers) + response, api_key = requestor.request(method_, url_, params, headers) stripe_object = util.convert_to_stripe_object( response, api_key, stripe_version, stripe_account ) return stripe_object + + def __getitem__(self, k): + if isinstance(k, six.string_types): + return super(ListObject, self).__getitem__(k) + else: + raise KeyError( + "You tried to access the %s index, but ListObject types only " + "support string keys. (HINT: List calls return an object with " + "a 'data' (which is the data array). You likely want to call " + ".data[%s])" % (repr(k), repr(k)) + ) + + def __iter__(self): + return getattr(self, "data", []).__iter__() + + def __len__(self): + return getattr(self, "data", []).__len__() + + def __reversed__(self): + return getattr(self, "data", []).__reversed__() + + def auto_paging_iter(self): + page = self + + while True: + if ( + "ending_before" in self._retrieve_params + and "starting_after" not in self._retrieve_params + ): + for item in reversed(page): + yield item + page = page.previous_page() + else: + for item in page: + yield item + page = page.next_page() + + if page.is_empty: + break + + @classmethod + def empty_list( + cls, api_key=None, stripe_version=None, stripe_account=None + ): + return cls.construct_from( + {"data": []}, + key=api_key, + stripe_version=stripe_version, + stripe_account=stripe_account, + last_response=None, + ) + + @property + def is_empty(self): + return not self.data + + def next_page( + self, api_key=None, stripe_version=None, stripe_account=None, **params + ): + if not self.has_more: + return self.empty_list( + api_key=api_key, + stripe_version=stripe_version, + stripe_account=stripe_account, + ) + + last_id = self.data[-1].id + + params_with_filters = self._retrieve_params.copy() + params_with_filters.update({"starting_after": last_id}) + params_with_filters.update(params) + + return self.list( + api_key=api_key, + stripe_version=stripe_version, + stripe_account=stripe_account, + **params_with_filters + ) + + def previous_page( + self, api_key=None, stripe_version=None, stripe_account=None, **params + ): + if not self.has_more: + return self.empty_list( + api_key=api_key, + stripe_version=stripe_version, + stripe_account=stripe_account, + ) + + first_id = self.data[0].id + + params_with_filters = self._retrieve_params.copy() + params_with_filters.update({"ending_before": first_id}) + params_with_filters.update(params) + + return self.list( + api_key=api_key, + stripe_version=stripe_version, + stripe_account=stripe_account, + **params_with_filters + ) diff --git a/tests/api_resources/test_list_object.py b/tests/api_resources/test_list_object.py index c74478696..ce5e1bb87 100644 --- a/tests/api_resources/test_list_object.py +++ b/tests/api_resources/test_list_object.py @@ -14,14 +14,6 @@ def list_object(self): {"object": "list", "url": "/my/path", "data": ["foo"]}, "mykey" ) - def test_for_loop(self, list_object): - seen = [] - - for item in list_object: - seen.append(item) - - assert seen == ["foo"] - def test_list(self, request_mock, list_object): request_mock.stub_request( "get", @@ -87,6 +79,26 @@ def test_retrieve(self, request_mock, list_object): assert res.foo == "bar" assert res.stripe_account == "acct_123" + def test_is_empty(self): + lo = stripe.ListObject.construct_from({"data": []}, None) + assert lo.is_empty is True + + def test_empty_list(self): + lo = stripe.ListObject.empty_list() + assert lo.is_empty + + def test_iter(self): + arr = [{"id": 1}, {"id": 2}, {"id": 3}] + expected = stripe.util.convert_to_stripe_object(arr) + lo = stripe.ListObject.construct_from({"data": arr}, None) + assert list(lo) == expected + + def test_iter_reversed(self): + arr = [{"id": 1}, {"id": 2}, {"id": 3}] + expected = stripe.util.convert_to_stripe_object(list(reversed(arr))) + lo = stripe.ListObject.construct_from({"data": arr}, None) + assert list(reversed(lo)) == expected + def test_len(self, list_object): assert len(list_object) == 1 @@ -98,6 +110,132 @@ def test_bool(self, list_object): ) assert bool(empty) is False + def test_next_page(self, request_mock): + lo = stripe.ListObject.construct_from( + { + "object": "list", + "data": [{"id": 1}], + "has_more": True, + "url": "/things", + }, + None, + ) + + request_mock.stub_request( + "get", + "/things", + { + "object": "list", + "data": [{"id": 2}], + "has_more": False, + "url": "/things", + }, + ) + + next_lo = lo.next_page() + assert not next_lo.is_empty + assert next_lo.data[0].id == 2 + + def test_next_page_with_filters(self, request_mock): + lo = stripe.ListObject.construct_from( + { + "object": "list", + "data": [{"id": 1}], + "has_more": True, + "url": "/things", + }, + None, + ) + lo._retrieve_params = {"expand": ["data.source"], "limit": 3} + + request_mock.stub_request( + "get", + "/things", + { + "object": "list", + "data": [{"id": 2}], + "has_more": False, + "url": "/things", + }, + ) + + next_lo = lo.next_page() + assert next_lo._retrieve_params == { + "expand": ["data.source"], + "limit": 3, + "starting_after": 1, + } + + def test_next_page_empty_list(self): + lo = stripe.ListObject.construct_from( + { + "object": "list", + "data": [{"id": 1}], + "has_more": False, + "url": "/things", + }, + None, + ) + + next_lo = lo.next_page() + assert next_lo == stripe.ListObject.empty_list() + + def test_prev_page(self, request_mock): + lo = stripe.ListObject.construct_from( + { + "object": "list", + "data": [{"id": 2}], + "has_more": True, + "url": "/things", + }, + None, + ) + + request_mock.stub_request( + "get", + "/things", + { + "object": "list", + "data": [{"id": 1}], + "has_more": False, + "url": "/things", + }, + ) + + previous_lo = lo.previous_page() + assert not previous_lo.is_empty + assert previous_lo.data[0].id == 1 + + def test_prev_page_with_filters(self, request_mock): + lo = stripe.ListObject.construct_from( + { + "object": "list", + "data": [{"id": 2}], + "has_more": True, + "url": "/things", + }, + None, + ) + lo._retrieve_params = {"expand": ["data.source"], "limit": 3} + + request_mock.stub_request( + "get", + "/things", + { + "object": "list", + "data": [{"id": 1}], + "has_more": False, + "url": "/things", + }, + ) + + previous_lo = lo.previous_page() + assert previous_lo._retrieve_params == { + "expand": ["data.source"], + "limit": 3, + "ending_before": 2, + } + def test_serialize_empty_list(self): empty = stripe.ListObject.construct_from( {"object": "list", "data": []}, "mykey" @@ -166,6 +304,29 @@ def test_iter_two_pages(self, request_mock): assert seen == ["pm_123", "pm_124", "pm_125", "pm_126"] + def test_iter_reverse(self, request_mock): + lo = stripe.ListObject.construct_from( + self.pageable_model_response(["pm_125", "pm_126"], True), "mykey" + ) + lo._retrieve_params = {"foo": "bar", "ending_before": "pm_127"} + + request_mock.stub_request( + "get", + "/v1/pageablemodels", + self.pageable_model_response(["pm_123", "pm_124"], False), + ) + + seen = [item["id"] for item in lo.auto_paging_iter()] + + request_mock.assert_requested( + "get", + "/v1/pageablemodels", + {"ending_before": "pm_125", "foo": "bar"}, + None, + ) + + assert seen == ["pm_126", "pm_125", "pm_124", "pm_123"] + def test_class_method_two_pages(self, request_mock): request_mock.stub_request( "get",