
data_loader #3239

Merged
merged 10 commits on Oct 26, 2021
17 changes: 10 additions & 7 deletions horovod/data/data_loader_base.py
@@ -55,7 +55,7 @@ class AsyncDataLoaderMixin(object):
      class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
      """

-     def __init__(self, async_loader_queue_size=64, *args, **kwargs):
+     def __init__(self, async_loader_queue_size=1, *args, **kwargs):
          """
          initialize the async data loader. Need to add this in the __init__() of the implementation
          """
@@ -92,12 +92,12 @@ def _async_worker(self):
User need to implement self._iterate() to read the data.
"""
      try:
-         # Only need to iterate once because data loader will be re-created in each epoch.
-         for batch in self._iterate():
-             if self.finished_event.is_set():
-                 break
-             self.queue.put(batch)
-         self.queue.put(None)
+         while not self.finished_event.is_set():
+             for batch in self._iterate():
+                 if self.finished_event.is_set():
+                     break
+                 self.queue.put(batch)
+             self.queue.put(None)
      except Exception as ex:
          self.queue.put(ex)
          self.queue.put(None)
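The worker above is the producer half of a producer/consumer pair: it pushes batches onto self.queue, puts None as an end-of-epoch sentinel after each pass of the while loop, and forwards exceptions through the same queue. A minimal sketch of the consuming side such a worker would pair with; the method name _consume_from_queue is assumed for illustration and is not part of this diff:

    def _consume_from_queue(self):
        # Drain one epoch's worth of batches produced by _async_worker().
        while True:
            batch = self.queue.get()
            if isinstance(batch, Exception):
                # The worker forwards exceptions through the queue; re-raise here.
                raise batch
            if batch is None:
                # None is the end-of-epoch sentinel the worker puts after each pass.
                break
            yield batch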
@@ -125,3 +125,6 @@ def __iter__(self):
          else:
              for batch in self._iterate():
                  yield self._process_batch(batch)
+
+     def __del__(self):
+         self.close_async_loader()
Collaborator

I don't think there is any place where `del` is actually called on the dataloader. Also, garbage collection in Python is delayed, so it is preferable to call close_async_loader() explicitly rather than relying on garbage collection.
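A minimal sketch of the kind of explicit cleanup this comment is suggesting instead of relying on __del__; the construction and training step below are hypothetical, only close_async_loader() comes from this PR:

    # Hypothetical usage; only close_async_loader() is from the PR.
    loader = PytorchInfiniteAsyncDataLoader(**kwargs)
    try:
        for batch in loader:
            train_step(batch)   # placeholder for the real training step
    finally:
        # Explicit, deterministic cleanup instead of waiting for __del__/GC.
        loader.close_async_loader()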

Collaborator Author

Where else do you think we could call close_async_loader(), if not in __del__? Is there a close/teardown hook in the lightning module?

Collaborator

Collaborator Author

Got it, removed it.
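For reference, one natural home for such an explicit close is the data module's teardown hook. A minimal sketch assuming the pytorch_lightning LightningDataModule.teardown(stage) hook and the train_dl/val_dl attributes visible in the diff below; this is illustrative only, not necessarily what the PR ended up doing:

    import pytorch_lightning as pl

    class DataModule(pl.LightningDataModule):
        # train_dataloader()/val_dataloader() create self.train_dl / self.val_dl ...

        def teardown(self, stage=None):
            # Close the async loaders explicitly when training/validation ends,
            # rather than relying on __del__ being triggered by garbage collection.
            for dl in (getattr(self, 'train_dl', None), getattr(self, 'val_dl', None)):
                if dl is not None and hasattr(dl, 'close_async_loader'):
                    dl.close_async_loader()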

6 changes: 6 additions & 0 deletions horovod/spark/lightning/datamodule.py
@@ -94,6 +94,9 @@ def train_dataloader(self):
          else:
              dataloader_class = PytorchInfiniteAsyncDataLoader
              kwargs['shuffling_queue_capacity'] = self.shuffle_size
+             # To avoid holding too much data in memory, cap the queue size (in batches) so the
+             # buffer holds at most ~10000 rows (batch_size * queue_size) and no more than 1/4
+             # of the steps per epoch.
+             kwargs['async_loader_queue_size'] = max(1, min(10000 // kwargs['batch_size'], kwargs['limit_step_per_epoch'] // 4))
Collaborator

In which case do we see the OOM issue?

Collaborator Author

I suspect it is related to the hanging issue we saw when running deepETA on all of the data.


self.train_dl = dataloader_class(**kwargs)
return self.train_dl
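To make the sizing formula concrete, a quick worked example with made-up values (batch_size=256 and limit_step_per_epoch=1000 are purely illustrative):

    batch_size = 256
    limit_step_per_epoch = 1000
    queue_size = max(1, min(10000 // batch_size, limit_step_per_epoch // 4))
    # 10000 // 256 = 39 and 1000 // 4 = 250, so queue_size = 39:
    # at most 39 batches (39 * 256 = 9984 rows) are buffered in memory at a time.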
@@ -115,6 +118,9 @@ def val_dataloader(self):
          else:
              dataloader_class = PytorchInfiniteAsyncDataLoader
              kwargs['shuffling_queue_capacity'] = 0
+             # To avoid holding too much data in memory, cap the queue size (in batches) so the
+             # buffer holds at most ~10000 rows (batch_size * queue_size) and no more than 1/4
+             # of the steps per epoch.
+             kwargs['async_loader_queue_size'] = max(1, min(10000 // kwargs['batch_size'], kwargs['limit_step_per_epoch'] // 4))

self.val_dl = dataloader_class(**kwargs)
return self.val_dl