Spark/Lightning: reduce memory footprint of async dataloader (horovod#3239)

Limit async data loader queue size.

Signed-off-by: Peng Zhang <pengz@uber.com>
Signed-off-by: weihanmines <weihan13@amd.com>
irasit authored and weihanmines committed Dec 10, 2021
1 parent 0e61f3c commit e0d794e
Showing 3 changed files with 12 additions and 3 deletions.
3 changes: 2 additions & 1 deletion horovod/data/data_loader_base.py
@@ -55,7 +55,7 @@ class AsyncDataLoaderMixin(object):
     class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
     """

-    def __init__(self, async_loader_queue_size=64, *args, **kwargs):
+    def __init__(self, async_loader_queue_size=5, *args, **kwargs):
         """
         initialize the async data loader. Need to add this in the __init__() of the implementation
         """
@@ -115,6 +115,7 @@ def __iter__(self):
         if not self.started:
             self.started = True
             self.thread.start()
+
         while True:
             batch = self.queue.get()
             if batch is None:
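For context, the memory saving comes from bounding the prefetch queue: as the __iter__ hunk above shows, AsyncDataLoaderMixin loads batches on a background thread into a fixed-size queue, the iterator consumes from that queue, and a None sentinel marks the end of each epoch. Shrinking the default queue size from 64 to 5 therefore caps how many prefetched batches can sit in memory at once. Below is a minimal sketch of this bounded-queue pattern; the class name and arguments are illustrative, not Horovod's actual API.

import threading
from queue import Queue

class BoundedAsyncLoader:
    """Toy single-epoch producer/consumer loader; a stand-in for the real mixin."""

    def __init__(self, batches, queue_size=5):
        # queue_size bounds how many prefetched batches are held in memory.
        self.queue = Queue(maxsize=queue_size)
        self.batches = batches
        self.thread = threading.Thread(target=self._produce, daemon=True)
        self.started = False

    def _produce(self):
        for batch in self.batches:
            self.queue.put(batch)  # blocks once the queue is full
        self.queue.put(None)       # sentinel marking the end of the epoch

    def __iter__(self):
        if not self.started:
            self.started = True
            self.thread.start()
        while True:
            batch = self.queue.get()
            if batch is None:
                break
            yield batch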
10 changes: 10 additions & 0 deletions horovod/spark/lightning/datamodule.py
@@ -78,6 +78,8 @@ def teardown(self, stage=None):
             if self.has_val:
                 self.val_reader.stop()
                 self.val_reader.join()
+            if self.verbose:
+                print("Tear down: async dataloaders closed.")

     def train_dataloader(self):
         if self.verbose:
@@ -94,6 +96,10 @@ def train_dataloader(self):
         else:
             dataloader_class = PytorchInfiniteAsyncDataLoader
             kwargs['shuffling_queue_capacity'] = self.shuffle_size
+            # To avoid loading too much data into memory, calculate the queue size
+            # dynamically and limit how much data is held in the queue.
+            # Add 1 to the size for the None sentinel at the end of each epoch.
+            kwargs['async_loader_queue_size'] = max(1, min(100000 // kwargs['batch_size'], kwargs['limit_step_per_epoch'] // 4)) + 1

         self.train_dl = dataloader_class(**kwargs)
         return self.train_dl
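As a concrete illustration, the training queue is capped at roughly 100,000 samples' worth of batches or a quarter of the steps per epoch, whichever is smaller, plus one slot for the end-of-epoch sentinel. The numbers below are hypothetical, not taken from the commit:

batch_size = 256
limit_step_per_epoch = 1000
queue_size = max(1, min(100000 // batch_size, limit_step_per_epoch // 4)) + 1
# min(390, 250) = 250, so queue_size == 251: the steps-per-epoch term is the
# binding cap here, plus one slot for the None sentinel.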
@@ -115,6 +121,10 @@ def val_dataloader(self):
         else:
             dataloader_class = PytorchInfiniteAsyncDataLoader
             kwargs['shuffling_queue_capacity'] = 0
+            # To avoid loading too much data into memory, calculate the queue size
+            # dynamically and limit how much data is held in the queue.
+            # Add 1 to the size for the None sentinel at the end of each epoch.
+            kwargs['async_loader_queue_size'] = max(1, min(10000 // kwargs['batch_size'], kwargs['limit_step_per_epoch'] // 4)) + 1

         self.val_dl = dataloader_class(**kwargs)
         return self.val_dl
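The validation path uses the same formula but with a 10,000-sample budget instead of 100,000, so its queue is bounded much more tightly. With the same hypothetical values as above:

batch_size = 256
limit_step_per_epoch = 1000
queue_size = max(1, min(10000 // batch_size, limit_step_per_epoch // 4)) + 1
# min(39, 250) = 39, so queue_size == 40: here the sample budget, not the
# steps-per-epoch term, is the binding limit.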
2 changes: 0 additions & 2 deletions horovod/spark/lightning/remote.py
@@ -193,8 +193,6 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         'gpus': _num_gpus,
         'callbacks': callbacks,
         'max_epochs': epochs,
-        'limit_train_batches': _train_steps_per_epoch,
-        'limit_val_batches': _val_steps_per_epoch,
         'logger': train_logger,
         'log_every_n_steps': log_every_n_steps,
         'num_sanity_val_steps': 0,
