call setup for common process_set in remote trainers
TJ committed Nov 3, 2021
1 parent 660f7ff commit 3681500
Showing 3 changed files with 21 additions and 0 deletions.
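All three remote trainers apply the same fix inside their train() functions: after hvd.init(), re-run the process_sets module's _setup() against the framework-specific mpi_ops basics object, so process-set operations resolve against the correct shared library. Below is a minimal sketch of the shared pattern using the torch variant; the standalone form is illustrative only, since in the actual diff the call is inlined in each train() (the Keras trainer uses the tensorflow basics object instead).

import horovod as _horovod
import horovod.torch as hvd  # the Keras trainer imports horovod.tensorflow instead

def train(serialized_model):
    # Horovod: initialize library.
    hvd.init()

    # Re-point the process-set module's shared lib at torch's mpi_ops module,
    # since the lib path may have been overwritten in the remote trainer.
    _horovod.common.process_sets._setup(_horovod.torch.mpi_ops._basics)

    # ... framework-specific training continues here ...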
6 changes: 6 additions & 0 deletions horovod/spark/keras/remote.py
@@ -227,6 +227,12 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
reader_factory = make_batch_reader
is_batch_reader = True

# Call _setup again in the process_sets module to point the shared lib to TensorFlow's module,
# since the lib path might be overwritten in the remote trainer.
_horovod.common.process_sets._setup(_horovod.tensorflow.mpi_ops._basics)
if verbose:
print(f"Set shared lib path to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

with reader_factory(remote_store.train_data_path,
num_epochs=1,
cur_shard=hvd.rank(),
8 changes: 8 additions & 0 deletions horovod/spark/lightning/remote.py
@@ -97,6 +97,8 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro

def train(serialized_model):
import horovod.torch as hvd
import horovod as _horovod

# Horovod: initialize library.
hvd.init()
_checkpoint_callback = None
@@ -216,6 +218,12 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

print(f"pytorch_lightning version={pl.__version__}")

# Call _setup again in the process_sets module to point the shared lib to torch's module,
# since the lib path might be overwritten in the remote trainer.
_horovod.common.process_sets._setup(_horovod.torch.mpi_ops._basics)
if verbose:
print(f"Set shared lib path to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

dataset = data_module(train_dir=remote_store.train_data_path,
val_dir=remote_store.val_data_path,
num_train_epochs=epochs,
7 changes: 7 additions & 0 deletions horovod/spark/torch/remote.py
@@ -103,6 +103,7 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized,
from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader
import torch
import horovod.torch as hvd
import horovod as _horovod

# Deserializing objects
model_opt_state = torch.load(model_opt_state_serialized)
@@ -227,6 +228,12 @@ def save_checkpoint():
else:
reader_factory = make_batch_reader

# Call _setup again in the process_sets module to point the shared lib to torch's module,
# since the lib path might be overwritten in the remote trainer.
_horovod.common.process_sets._setup(_horovod.torch.mpi_ops._basics)
if user_verbose:
print(f"Set shared lib path to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

# Petastorm: read data from the store with the correct shard for this rank
# setting num_epochs=None will cause an infinite iterator
# and enables ranks to perform training and validation with
