Support no pre-fetching #11606

Merged · 10 commits · Feb 5, 2022

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -66,6 +66,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))


- Added support for no pre-fetching to `DataFetcher` ([#11606](https://github.com/PyTorchLightning/pytorch-lightning/pull/11606))


- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))


34 changes: 21 additions & 13 deletions pytorch_lightning/utilities/fetching.py
@@ -201,19 +201,15 @@ def _no_op_batch_to_device(batch: Any) -> Any:


class DataFetcher(AbstractDataFetcher):

"""This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
batch in advance to detect the end of the iteration.
"""This class is used to control batch fetching flow.

Args:
prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch
at least 1 batch for tracking the latest batch.
prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
whether a batch is the last one (available with :attr:`self.done`).
store_on_device: Whether to store the pre-fetched batches on device.
"""

def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
if prefetch_batches < 1:
raise MisconfigurationException("`prefetch_batches` should at least be 1.")
super().__init__(prefetch_batches=prefetch_batches)
self.store_on_device = store_on_device
self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
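
As an aside on the constructor change above: with the `prefetch_batches < 1` guard removed, a value of `0` is now accepted. A minimal usage sketch, assuming the `DataFetcher` API shown in this diff (the `(batch, is_last)` tuple shape is inferred from the tests in this PR, and the toy `DataLoader(range(3))` input is only illustrative):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.fetching import DataFetcher

# 0 no longer raises MisconfigurationException; it simply disables look-ahead.
fetcher = DataFetcher(prefetch_batches=0, store_on_device=True)
fetcher.setup(DataLoader(range(3)))

for batch, is_last in fetcher:
    # Without pre-fetching, the end of the iterator can only be observed once
    # StopIteration is raised, so `is_last` stays False for every batch.
    print(batch, is_last)
```
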
@@ -240,19 +236,31 @@ def prefetching(self) -> None:
break

def fetching_function(self) -> Tuple[Any, bool]:
assert self.dataloader_iter is not None
if self.batches:
# there are pre-fetched batches already from a previous `prefetching` call.
# consume one
batch = self.batches.pop(0)
else:
# empty iterator, no prefetching done
raise StopIteration
if not self.done:
assert self.dataloader_iter is not None
try:
# refill the consumed batch
self._fetch_next_batch(self.dataloader_iter)
except StopIteration:
# no more batches to fetch. we are done only if all pre-fetched batches were returned
self.done = not self.batches
elif not self.done:
# this will run only when no pre-fetching was done.
try:
self._fetch_next_batch(self.dataloader_iter)
# consume the batch we just fetched
batch = self.batches.pop(0)
except StopIteration as e:
self.done = True
raise e
else:
# the iterator is empty
raise StopIteration
self.wait()
return self.move_to_device(batch), len(self.batches) == 0
return self.move_to_device(batch), self.done

def _fetch_next_batch(self, iterator: Iterator) -> None:
start_output = self.on_fetch_start()
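
To make the branching in `fetching_function` easier to follow, here is a simplified, standalone sketch of the same control flow (an illustration of the idea, not the library code; it ignores `wait()`, device moves, and the `fetched` counter):

```python
from typing import Any, Iterator, List, Tuple


def fetch_loop(iterable, prefetch_batches: int) -> Iterator[Tuple[Any, bool]]:
    """Toy stand-in for DataFetcher's prefetching + fetching_function pair."""
    it = iter(iterable)
    batches: List[Any] = []
    done = False
    # prefetching phase: fill the queue before the first batch is requested
    for _ in range(prefetch_batches):
        try:
            batches.append(next(it))
        except StopIteration:
            break
    while True:
        if batches:
            # consume a pre-fetched batch, then try to refill its slot
            batch = batches.pop(0)
            try:
                batches.append(next(it))
            except StopIteration:
                # done only once every pre-fetched batch has been returned
                done = not batches
        elif not done:
            # no pre-fetching: fetch on demand; the end only shows up as StopIteration
            try:
                batch = next(it)
            except StopIteration:
                return
        else:
            # the iterator is empty
            return
        yield batch, done


# list(fetch_loop([1, 2, 3], prefetch_batches=1)) -> [(1, False), (2, False), (3, True)]
# list(fetch_loop([1, 2, 3], prefetch_batches=0)) -> [(1, False), (2, False), (3, False)]
```
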
34 changes: 19 additions & 15 deletions tests/utilities/test_fetching.py
@@ -40,28 +40,32 @@ def __iter__(self):
yield 2
yield 3

for prefetch_batches in range(1, 5):
for prefetch_batches in range(5):
iterator = DataFetcher(prefetch_batches=prefetch_batches)
assert iterator.prefetch_batches == prefetch_batches

if use_combined_loader:
loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
expected = [
([tensor([1]), tensor([1])], False),
([tensor([2]), tensor([2])], False),
([tensor([3]), tensor([3])], True),
]
else:
loader = DataLoader(IterDataset())
expected = [(1, False), (2, False), (3, True)]
iterator = DataFetcher(prefetch_batches=prefetch_batches)
assert iterator.prefetch_batches == prefetch_batches
iterator.setup(loader)

def generate():
generated = []
for idx, data in enumerate(iterator, prefetch_batches + 1):
assert iterator.fetched == 3 if iterator.done else idx
generated.append(data)
generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
assert iterator.fetched == 3
assert iterator.done
return generated

is_last_batch = [False, False, prefetch_batches > 0]
fetched = list(range(prefetch_batches + 1, 4))
fetched += [3] * (3 - len(fetched))
if use_combined_loader:
batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
else:
batches = [1, 2, 3]
expected = list(zip(fetched, batches, is_last_batch))
assert len(expected) == 3

assert generate() == expected
# validate reset works properly.
assert generate() == expected
@@ -71,9 +75,9 @@ class EmptyIterDataset(IterableDataset):
def __iter__(self):
return iter([])

dataloader = DataLoader(EmptyIterDataset())
loader = DataLoader(EmptyIterDataset())
iterator = DataFetcher()
iterator.setup(dataloader)
iterator.setup(loader)
assert not list(iterator)


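
For readers checking the arithmetic in the updated test, a worked example of the `fetched`/`is_last_batch` bookkeeping above (purely illustrative; it mirrors the construction in the test for the 3-batch `IterDataset`):

```python
# With 3 batches total, the expected per-item (fetched, is_last) values are:
#   prefetch_batches=0 -> fetched [1, 2, 3], is_last [False, False, False]
#   prefetch_batches=1 -> fetched [2, 3, 3], is_last [False, False, True]
#   prefetch_batches=3 -> fetched [3, 3, 3], is_last [False, False, True]
for prefetch_batches in (0, 1, 3):
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    is_last_batch = [False, False, prefetch_batches > 0]
    print(prefetch_batches, fetched, is_last_batch)
```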