diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a57607722662..29122e7ca7a1e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,6 +66,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))
 
+- Added support for no pre-fetching to `DataFetcher` ([#11606](https://github.com/PyTorchLightning/pytorch-lightning/pull/11606))
+
+
 - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
 
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index 5b8012468dfef..e523bf7a350a1 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -201,19 +201,15 @@ def _no_op_batch_to_device(batch: Any) -> Any:
 
 
 class DataFetcher(AbstractDataFetcher):
-
-    """This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
-    batch in advance to detect the end of the iteration.
+    """This class is used to control batch fetching flow.
 
     Args:
-        prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch
-            at least 1 batch for tracking the latest batch.
+        prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
+            whether a batch is the last one (available with :attr:`self.done`).
         store_on_device: Whether to store the pre-fetched batches on device.
     """
 
     def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
-        if prefetch_batches < 1:
-            raise MisconfigurationException("`prefetch_batches` should at least be 1.")
         super().__init__(prefetch_batches=prefetch_batches)
         self.store_on_device = store_on_device
         self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
@@ -240,19 +236,31 @@ def prefetching(self) -> None:
                 break
 
     def fetching_function(self) -> Tuple[Any, bool]:
+        assert self.dataloader_iter is not None
         if self.batches:
+            # there are pre-fetched batches already from a previous `prefetching` call.
+            # consume one
             batch = self.batches.pop(0)
-        else:
-            # empty iterator, no prefetching done
-            raise StopIteration
-
-        if not self.done:
-            assert self.dataloader_iter is not None
             try:
+                # refill the consumed batch
                 self._fetch_next_batch(self.dataloader_iter)
             except StopIteration:
+                # no more batches to fetch. we are done only if all pre-fetched batches were returned
+                self.done = not self.batches
+        elif not self.done:
+            # this will run only when no pre-fetching was done.
+            try:
+                self._fetch_next_batch(self.dataloader_iter)
+                # consume the batch we just fetched
+                batch = self.batches.pop(0)
+            except StopIteration as e:
                 self.done = True
+                raise e
+        else:
+            # the iterator is empty
+            raise StopIteration
         self.wait()
-        return self.move_to_device(batch), len(self.batches) == 0
+        return self.move_to_device(batch), self.done
 
     def _fetch_next_batch(self, iterator: Iterator) -> None:
         start_output = self.on_fetch_start()
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index 0d62f35ee79c7..aa01c87041532 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -40,28 +40,32 @@ def __iter__(self):
             yield 2
             yield 3
 
-    for prefetch_batches in range(1, 5):
+    for prefetch_batches in range(5):
+        iterator = DataFetcher(prefetch_batches=prefetch_batches)
+        assert iterator.prefetch_batches == prefetch_batches
+
         if use_combined_loader:
             loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
-            expected = [
-                ([tensor([1]), tensor([1])], False),
-                ([tensor([2]), tensor([2])], False),
-                ([tensor([3]), tensor([3])], True),
-            ]
         else:
             loader = DataLoader(IterDataset())
-            expected = [(1, False), (2, False), (3, True)]
-        iterator = DataFetcher(prefetch_batches=prefetch_batches)
-        assert iterator.prefetch_batches == prefetch_batches
         iterator.setup(loader)
 
         def generate():
-            generated = []
-            for idx, data in enumerate(iterator, prefetch_batches + 1):
-                assert iterator.fetched == 3 if iterator.done else idx
-                generated.append(data)
+            generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
+            assert iterator.fetched == 3
+            assert iterator.done
             return generated
 
+        is_last_batch = [False, False, prefetch_batches > 0]
+        fetched = list(range(prefetch_batches + 1, 4))
+        fetched += [3] * (3 - len(fetched))
+        if use_combined_loader:
+            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
+        else:
+            batches = [1, 2, 3]
+        expected = list(zip(fetched, batches, is_last_batch))
+        assert len(expected) == 3
+
         assert generate() == expected
         # validate reset works properly.
         assert generate() == expected
@@ -71,9 +75,9 @@ class EmptyIterDataset(IterableDataset):
         def __iter__(self):
             return iter([])
 
-    dataloader = DataLoader(EmptyIterDataset())
+    loader = DataLoader(EmptyIterDataset())
     iterator = DataFetcher()
-    iterator.setup(dataloader)
+    iterator.setup(loader)
     assert not list(iterator)
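
For reviewers, a minimal usage sketch of the behavior this diff enables. It is an illustration only, assuming the `DataFetcher` API as modified above; the `IterDataset` helper mirrors the one in the updated test.

```python
# Illustrative sketch only (not part of the diff): how `DataFetcher`
# behaves with the new `prefetch_batches=0` option, assuming the API
# as modified above.
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import DataFetcher


class IterDataset(IterableDataset):
    def __iter__(self):
        yield 1
        yield 2
        yield 3


# Previously this raised `MisconfigurationException`; now it disables
# pre-fetching entirely, and batches are pulled lazily, one at a time.
fetcher = DataFetcher(prefetch_batches=0)
fetcher.setup(DataLoader(IterDataset()))

for batch, is_last in fetcher:
    # Without lookahead the fetcher cannot know a batch is the last one,
    # so `is_last` stays False here; `fetcher.done` flips to True only
    # once the underlying iterator raises StopIteration.
    print(batch, is_last)

assert fetcher.done
assert fetcher.fetched == 3
```

This also reflects the design choice in `fetching_function`: the "is last batch" signal is now `self.done` rather than `len(self.batches) == 0`, so with `prefetch_batches=0` the end of the iteration is only detected after the iterator is exhausted, which is why the updated test expects `is_last_batch = [False, False, prefetch_batches > 0]`.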