From 4e06ce78a5752979d49a89b3ab313c339641c8ed Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 03:23:44 +0100 Subject: [PATCH 1/8] Add support for no prefetching --- pytorch_lightning/utilities/fetching.py | 22 ++++++++++++++++------ tests/utilities/test_fetching.py | 19 ++++++++++--------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index ccbbce79829de..fd1c591d5e076 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -237,18 +237,28 @@ def prefetching(self) -> None: def fetching_function(self) -> Tuple[Any, bool]: if self.batches: + # we pre-fetched, consume one batch = self.batches.pop(0) - else: - # empty iterator, no prefetching done - raise StopIteration - if not self.done: - assert self.dataloader_iter is not None try: + # and replace it self._fetch_next_batch(self.dataloader_iter) except StopIteration: + self.done = not self.batches + + elif not self.done: + # we did not prefetch at all + try: + self._fetch_next_batch(self.dataloader_iter) + batch = self.batches.pop(0) + except StopIteration as e: self.done = True + raise e + + else: + raise StopIteration + self.wait() - return self.move_to_device(batch), len(self.batches) == 0 + return self.move_to_device(batch), self.done def _fetch_next_batch(self, iterator: Iterator) -> None: start_output = self.on_fetch_start() diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 0d62f35ee79c7..6e1ab2a1a26d2 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -40,19 +40,14 @@ def __iter__(self): yield 2 yield 3 - for prefetch_batches in range(1, 5): + for prefetch_batches in range(5): + iterator = DataFetcher(prefetch_batches=prefetch_batches) + assert iterator.prefetch_batches == prefetch_batches + if use_combined_loader: loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())]) - expected = [ - ([tensor([1]), tensor([1])], False), - ([tensor([2]), tensor([2])], False), - ([tensor([3]), tensor([3])], True), - ] else: loader = DataLoader(IterDataset()) - expected = [(1, False), (2, False), (3, True)] - iterator = DataFetcher(prefetch_batches=prefetch_batches) - assert iterator.prefetch_batches == prefetch_batches iterator.setup(loader) def generate(): @@ -62,6 +57,12 @@ def generate(): generated.append(data) return generated + is_last_batch = [False, False, prefetch_batches > 0] + if use_combined_loader: + batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]] + else: + batches = [1, 2, 3] + expected = list(zip(batches, is_last_batch)) assert generate() == expected # validate reset works properly. assert generate() == expected From 4f4f9ac22b7e8b2d9006f2134e17cfafaf973540 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 04:02:14 +0100 Subject: [PATCH 2/8] Refactor test --- tests/utilities/test_fetching.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 6e1ab2a1a26d2..aa01c87041532 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -51,18 +51,21 @@ def __iter__(self): iterator.setup(loader) def generate(): - generated = [] - for idx, data in enumerate(iterator, prefetch_batches + 1): - assert iterator.fetched == 3 if iterator.done else idx - generated.append(data) + generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)] + assert iterator.fetched == 3 + assert iterator.done return generated is_last_batch = [False, False, prefetch_batches > 0] + fetched = list(range(prefetch_batches + 1, 4)) + fetched += [3] * (3 - len(fetched)) if use_combined_loader: batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]] else: batches = [1, 2, 3] - expected = list(zip(batches, is_last_batch)) + expected = list(zip(fetched, batches, is_last_batch)) + assert len(expected) == 3 + assert generate() == expected # validate reset works properly. assert generate() == expected @@ -72,9 +75,9 @@ class EmptyIterDataset(IterableDataset): def __iter__(self): return iter([]) - dataloader = DataLoader(EmptyIterDataset()) + loader = DataLoader(EmptyIterDataset()) iterator = DataFetcher() - iterator.setup(dataloader) + iterator.setup(loader) assert not list(iterator) From 9c0b903281de42677a8dfe43516121dea6b87481 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 04:15:36 +0100 Subject: [PATCH 3/8] Update docs and CHANGELOG --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/fetching.py | 8 +++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa7c4f9b056bc..7d31b0ec72b2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249)) +- Added support for no pre-fetching to `class DataFetcher(AbstractDataFetcher)` + + - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index fd1c591d5e076..9938cc30ee2ab 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -203,13 +203,11 @@ def teardown(self) -> None: class DataFetcher(AbstractDataFetcher): - - """This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a - batch in advance to detect the end of the iteration. + """This class is used to control batch fetching flow. Args: - prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch - at least 1 batch for tracking the latest batch. + prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessasry to properly track + whether a batch is the last one (available with :attr:`self.done`). store_on_device: Whether to store the pre-fetched batches on device. """ From 53ef47cb57155b55a66038e45709843b72ed6e5e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 04:16:15 +0100 Subject: [PATCH 4/8] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d31b0ec72b2d..cb7df769821b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,7 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249)) -- Added support for no pre-fetching to `class DataFetcher(AbstractDataFetcher)` +- Added support for no pre-fetching to `DataFetcher` ([#11606](https://github.com/PyTorchLightning/pytorch-lightning/pull/11606)) - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) From 9f5ad532c56a7542bebb983e53478581542d262f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 25 Jan 2022 04:17:57 +0100 Subject: [PATCH 5/8] Mypy --- pytorch_lightning/utilities/fetching.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 9938cc30ee2ab..357e30d8c0b67 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -234,6 +234,7 @@ def prefetching(self) -> None: break def fetching_function(self) -> Tuple[Any, bool]: + assert self.dataloader_iter is not None if self.batches: # we pre-fetched, consume one batch = self.batches.pop(0) @@ -242,7 +243,6 @@ def fetching_function(self) -> Tuple[Any, bool]: self._fetch_next_batch(self.dataloader_iter) except StopIteration: self.done = not self.batches - elif not self.done: # we did not prefetch at all try: @@ -251,10 +251,8 @@ def fetching_function(self) -> Tuple[Any, bool]: except StopIteration as e: self.done = True raise e - else: raise StopIteration - self.wait() return self.move_to_device(batch), self.done From f83fff0e6e325121200ba6f879f1660cddd0e1be Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 27 Jan 2022 16:06:47 +0100 Subject: [PATCH 6/8] Add comments --- pytorch_lightning/utilities/fetching.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 357e30d8c0b67..8e18bc3e85187 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -236,22 +236,26 @@ def prefetching(self) -> None: def fetching_function(self) -> Tuple[Any, bool]: assert self.dataloader_iter is not None if self.batches: - # we pre-fetched, consume one + # there are pre-fetched batches already from a previous `prefetching` call. + # consume one batch = self.batches.pop(0) try: - # and replace it + # refill the consumed batch self._fetch_next_batch(self.dataloader_iter) except StopIteration: + # no more batches to fetch. we are done only if all pre-fetched batches were returned self.done = not self.batches elif not self.done: - # we did not prefetch at all + # this will run only when no pre-fetching was done. try: self._fetch_next_batch(self.dataloader_iter) + # consume the batch we just fetched batch = self.batches.pop(0) except StopIteration as e: self.done = True raise e else: + # the iterator is empty raise StopIteration self.wait() return self.move_to_device(batch), self.done From c4d3b6c219d6fc62f1cce8a80bff31008bc41d57 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 27 Jan 2022 16:07:47 +0100 Subject: [PATCH 7/8] Typo --- pytorch_lightning/utilities/fetching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 8e18bc3e85187..39d3268a751b1 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -206,7 +206,7 @@ class DataFetcher(AbstractDataFetcher): """This class is used to control batch fetching flow. Args: - prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessasry to properly track + prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track whether a batch is the last one (available with :attr:`self.done`). store_on_device: Whether to store the pre-fetched batches on device. """ From 8a0942b6305da1cb7c66cdfc45e4017550328841 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Feb 2022 22:03:52 +0100 Subject: [PATCH 8/8] Remove check --- pytorch_lightning/utilities/fetching.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index fb930dc7a44da..e523bf7a350a1 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -210,8 +210,6 @@ class DataFetcher(AbstractDataFetcher): """ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None: - if prefetch_batches < 1: - raise MisconfigurationException("`prefetch_batches` should at least be 1.") super().__init__(prefetch_batches=prefetch_batches) self.store_on_device = store_on_device self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device