Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call process_set._setup in init() to point to the correct native lib path #3258

Merged
merged 6 commits into from Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- fix the example of pytorch_lightning_mnist.py ([#3245](https://github.com/horovod/horovod/pull/3245))

- Call _setup in remote trainers to point to the correct shared lib path ([#3258](https://github.com/horovod/horovod/pull/3258))
## [v0.23.0] - 2021-10-06

### Added
Expand Down
6 changes: 5 additions & 1 deletion horovod/mxnet/mpi_ops.py
Expand Up @@ -39,7 +39,6 @@
check_installed_version('mxnet', mx.__version__)

# import basic methods
init = _basics.init
shutdown = _basics.shutdown
is_initialized = _basics.is_initialized
start_timeline = _basics.start_timeline
Expand All @@ -61,6 +60,11 @@
cuda_built = _basics.cuda_built
rocm_built = _basics.rocm_built

def init(*args, **kwargs):
_basics.init(*args, **kwargs)
# Call set up again to make sure the basics is in sync
_setup_process_sets(_basics)

dll_path = os.path.join(os.path.dirname(__file__),
'mpi_lib' + get_ext_suffix())
MPI_MXNET_LIB_CTYPES = ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL)
Expand Down
4 changes: 4 additions & 0 deletions horovod/spark/keras/remote.py
Expand Up @@ -109,6 +109,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):

hvd = get_horovod()
hvd.init()

pin_gpu(hvd, tf, k)

if not user_shuffle_buffer_size:
Expand All @@ -129,6 +130,9 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
# Verbose mode 1 will print a progress bar
verbose = user_verbose if hvd.rank() == 0 else 0

if verbose:
print(f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

transform_spec = None
if transformation:
transform_spec = TransformSpec(transformation)
Expand Down
6 changes: 6 additions & 0 deletions horovod/spark/lightning/remote.py
Expand Up @@ -97,8 +97,14 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro

def train(serialized_model):
import horovod.torch as hvd

# Horovod: initialize library.
hvd.init()

if verbose:
import horovod as _horovod
print(f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

_checkpoint_callback = None
require_checkpoint = False

Expand Down
4 changes: 4 additions & 0 deletions horovod/spark/torch/remote.py
Expand Up @@ -117,6 +117,10 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized,
# Horovod: initialize library.
hvd.init()

if user_verbose:
import horovod as _horovod
print(f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

if not user_shuffle_buffer_size:
shuffle_buffer_size = \
calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size())
Expand Down
6 changes: 5 additions & 1 deletion horovod/tensorflow/mpi_ops.py
Expand Up @@ -57,7 +57,6 @@ def _load_library(name):
_basics = _HorovodBasics(__file__, 'mpi_lib')

# import basic methods
init = _basics.init
shutdown = _basics.shutdown
is_initialized = _basics.is_initialized
start_timeline = _basics.start_timeline
Expand All @@ -84,6 +83,11 @@ def _load_library(name):
Sum = _basics.Sum
Adasum = _basics.Adasum

def init(*args, **kwargs):
_basics.init(*args, **kwargs)
# Call set up again to make sure the basics is in sync
_setup_process_sets(_basics)

is_homogeneous = _basics.is_homogeneous

handle_average_backwards_compatibility = get_average_backwards_compatibility_fun(_basics)
Expand Down
4 changes: 3 additions & 1 deletion horovod/torch/mpi_ops.py
Expand Up @@ -69,7 +69,9 @@ def shutdown(*args, **kwargs):
def init(*args, **kwargs):
global _handle_map
_handle_map = {}
return _basics.init(*args, **kwargs)
_basics.init(*args, **kwargs)
# Call set up again to make sure the basics is in sync
_setup_process_sets(_basics)

# import reduction op values
Average = _basics.Average
Expand Down