RayExecutorV2: Dynamic executor for elastic and static jobs
This resolves horovod#3190 with a new RayExecutor API for Horovod: `RayExecutorV2`.
This API supports both static (non-elastic) and elastic Horovod jobs.

Example of static job:
```python
import ray
import horovod.torch as hvd
from horovod.ray import RayExecutorV2, NonElasticParams

ray.init()
executor = RayExecutorV2(settings, NonElasticParams(
    np=num_workers,
    use_gpu=True
))

executor.start()

def simple_fn():
    hvd.init()
    print("hvd rank", hvd.rank())
    return hvd.rank()

result = executor.run(simple_fn)
assert len(set(result)) == num_workers

executor.shutdown()
```
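Before `shutdown`, `run` can also forward arguments to the target function; a minimal sketch, assuming the static path mirrors the `run(fn, args=None, kwargs=None)` signature that `ElasticAdapter` exposes later in this commit:
```python
def scaled_rank(multiplier):
    # Illustrative helper: each worker returns its rank times `multiplier`.
    return hvd.rank() * multiplier

results = executor.run(scaled_rank, args=[10])
```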
Example of an elastic job:
```python
import torch
import horovod.torch as hvd
from horovod.ray import RayExecutorV2, ElasticParams

def training_fn():
    hvd.init()
    model = Model()
    torch.cuda.set_device(hvd.local_rank())

    @hvd.elastic.run
    def train(state):
        for state.epoch in range(state.epoch, epochs):
            ...
            state.commit()

    state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
    state.register_reset_callbacks([on_state_reset])
    train(state)
    return

executor = RayExecutorV2(settings, ElasticParams(use_gpu=True, cpus_per_worker=2))
executor.start()
executor.run(training_fn)
```
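The elastic example registers `on_state_reset` without defining it; a hypothetical callback (assuming the `optimizer` and a `base_lr` from the surrounding training code) might rescale the learning rate after the worker set changes:
```python
def on_state_reset():
    # Called by Horovod Elastic after workers are added or removed:
    # adjust the learning rate to the new world size.
    for param_group in optimizer.param_groups:
        param_group["lr"] = base_lr * hvd.size()
```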
ashahab committed Oct 18, 2021
1 parent aa31114 commit d72be4a
Showing 6 changed files with 1,771 additions and 11 deletions.
22 changes: 12 additions & 10 deletions docs/ray.rst
@@ -5,7 +5,7 @@ Horovod on Ray

``horovod.ray`` allows users to leverage Horovod on `a Ray cluster <https://docs.ray.io/en/latest/cluster/index.html>`_.

-Currently, the Ray + Horovod integration provides a :ref:`RayExecutor API <horovod_ray_api>`.
+Currently, the Ray + Horovod integration provides a :ref:`RayExecutorV2 API <horovod_ray_api>`.

.. note:: The Ray + Horovod integration currently only supports a Gloo backend.

@@ -24,25 +24,25 @@ See the Ray documentation for `advanced installation instructions <https://docs.
Horovod Ray Executor
--------------------

-The Horovod Ray integration offers a ``RayExecutor`` abstraction (:ref:`docs <horovod_ray_api>`),
+The Horovod Ray integration offers a ``RayExecutorV2`` abstraction (:ref:`docs <horovod_ray_api>`),
which is a wrapper over a group of `Ray actors (stateful processes) <https://docs.ray.io/en/latest/walkthrough.html#remote-classes-actors>`_.

.. code-block:: python
-from horovod.ray import RayExecutor
+from horovod.ray import RayExecutorV2
# Start the Ray cluster or attach to an existing Ray cluster
ray.init()
# Start num_workers actors on the cluster
-executor = RayExecutor(
+executor = RayExecutorV2(
setting, num_workers=num_workers, use_gpu=True)
# This will launch `num_workers` actors on the Ray Cluster.
executor.start()
-All actors will be part of the Horovod ring, so ``RayExecutor`` invocations will be able to support arbitrary Horovod collective operations.
+All actors will be part of the Horovod ring, so ``RayExecutorV2`` invocations will be able to support arbitrary Horovod collective operations.

Note that there is an implicit assumption on the cluster being homogenous in shape (i.e., all machines have the same number of slots available). This is simply
an implementation detail and is not a fundamental limitation.
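Because every actor joins the same Horovod ring, a function submitted through ``run`` can call collective operations directly. A minimal sketch, assuming the ``executor`` from the snippet above has already been started:

.. code-block:: python

    import torch
    import horovod.torch as hvd

    def allreduce_fn():
        hvd.init()
        # Average a one-element tensor across every worker in the ring.
        value = torch.tensor([float(hvd.rank())])
        return hvd.allreduce(value, name="example").item()

    # Every worker returns the same averaged value.
    results = executor.run(allreduce_fn)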
@@ -74,7 +74,7 @@ A unique feature of Ray is its support for `stateful Actors <https://docs.ray.io
import torch
from horovod.torch import hvd
-from horovod.ray import RayExecutor
+from horovod.ray import RayExecutorV2, NonElasticParams
class MyModel:
def __init__(self, learning_rate):
@@ -93,7 +93,7 @@ A unique feature of Ray is its support for `stateful Actors <https://docs.ray.io
ray.init()
-executor = RayExecutor(...)
+executor = RayExecutorV2(...)
executor.start(executable_cls=MyModel)
# Run 5 training steps
@@ -153,10 +153,12 @@ You can then attach to the underlying Ray cluster and execute the training funct
.. code-block:: python
import ray
+from horovod.ray import RayExecutorV2, ElasticParams
ray.init(address="auto") # attach to the Ray cluster
-settings = ElasticRayExecutor.create_settings(verbose=True)
-executor = ElasticRayExecutor(
-    settings, use_gpu=True, cpus_per_slot=2)
+settings = RayExecutorV2.create_settings(verbose=True)
+executor = RayExecutorV2(
+    settings, ElasticParams(min_np=1, use_gpu=True, cpus_per_worker=2))
executor.start()
executor.run(training_fn)
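``ElasticParams`` mirrors the ``ElasticAdapter`` arguments introduced later in this commit (``min_np``, ``max_np``, ``cpus_per_worker``, ``gpus_per_worker``, ``reset_limit``, ``elastic_timeout``). A sketch of a more fully specified configuration, with illustrative values:

.. code-block:: python

    params = ElasticParams(
        min_np=2,             # keep training while at least 2 workers remain
        max_np=8,             # never scale beyond 8 workers
        use_gpu=True,
        cpus_per_worker=2,
        reset_limit=3,        # give up after 3 worker-set resets
        elastic_timeout=600,  # seconds to wait for the minimum number of slots
    )
    executor = RayExecutorV2(settings, params)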
3 changes: 2 additions & 1 deletion horovod/ray/__init__.py
@@ -1,5 +1,6 @@
from .worker import BaseHorovodWorker
from .runner import RayExecutor
+from .runner_v2 import RayExecutorV2, NonElasticParams, ElasticParams
from .elastic import ElasticRayExecutor

__all__ = ["RayExecutor", "BaseHorovodWorker", "ElasticRayExecutor"]
__all__ = ["RayExecutor", "BaseHorovodWorker", "ElasticRayExecutor", "RayExecutorV2", "NonElasticParams", "ElasticParams"]
311 changes: 311 additions & 0 deletions horovod/ray/elastic_v2.py
@@ -0,0 +1,311 @@
from typing import Dict, Callable, Any, Optional, List
import logging
import ray.exceptions
import socket

import time
import threading

from horovod.runner.http.http_server import RendezvousServer
from horovod.ray.utils import detect_nics
from horovod.runner.elastic.rendezvous import create_rendezvous_handler
from horovod.runner.gloo_run import (create_slot_env_vars, create_run_env_vars,
                                     _get_min_start_hosts)
from horovod.ray.worker import BaseHorovodWorker
from horovod.ray.elastic import RayHostDiscovery
from horovod.runner.elastic.driver import ElasticDriver

logger = logging.getLogger(__name__)

# Ray renamed its timeout exception across releases; alias whichever class this version provides.
if hasattr(ray.exceptions, "GetTimeoutError"):
    GetTimeoutError = ray.exceptions.GetTimeoutError
elif hasattr(ray.exceptions, "RayTimeoutError"):
    GetTimeoutError = ray.exceptions.RayTimeoutError
else:
    raise ImportError("Unable to find Ray Timeout Error class "
                      "(GetTimeoutError, RayTimeoutError). "
                      "This is likely due to the Ray version not "
                      "compatible with Horovod-Ray.")

class ElasticAdapter:
    """Adapter for executing Ray calls for elastic Horovod jobs."""
    def __init__(self,
                 settings,
                 min_np: int,
                 max_np: Optional[int] = None,
                 use_gpu: bool = False,
                 cpus_per_worker: int = 1,
                 gpus_per_worker: Optional[int] = None,
                 env_vars: dict = None,
                 override_discovery: bool = True,
                 reset_limit: int = None,
                 elastic_timeout: int = 600,
                 **kwargs: Optional[Dict]):
        self.settings = settings
        if override_discovery:
            settings.discovery = RayHostDiscovery(
                use_gpu=use_gpu,
                cpus_per_slot=cpus_per_worker,
                gpus_per_slot=gpus_per_worker)
        self.cpus_per_worker = cpus_per_worker
        self.gpus_per_worker = gpus_per_worker
        self.use_gpu = use_gpu
        # moved from settings
        self.min_np = min_np
        self.max_np = max_np
        self.np = min_np
        self.reset_limit = reset_limit
        self.elastic_timeout = elastic_timeout
        self.driver = None
        self.rendezvous = None
        self.env_vars = env_vars or {}

    def start(self,
              executable_cls: type = None,
              executable_args: Optional[List] = None,
              executable_kwargs: Optional[Dict] = None):

        self.rendezvous = RendezvousServer(self.settings.verbose)
        self.driver = ElasticDriver(
            rendezvous=self.rendezvous,
            discovery=self.settings.discovery,
            min_np=self.min_np,
            max_np=self.max_np,
            timeout=self.elastic_timeout,
            reset_limit=self.reset_limit,
            verbose=self.settings.verbose)
        handler = create_rendezvous_handler(self.driver)
        logger.debug("[ray] starting rendezvous")
        global_rendezv_port = self.rendezvous.start(handler)

        logger.debug(f"[ray] waiting for {self.np} to start.")
        self.driver.wait_for_available_slots(self.np)

        # Host-to-host common interface detection
        # requires at least 2 hosts in an elastic job.
        min_hosts = _get_min_start_hosts(self.settings)
        current_hosts = self.driver.wait_for_available_slots(
            self.np, min_hosts=min_hosts)
        logger.debug("[ray] getting common interfaces")
        nics = detect_nics(
            self.settings,
            all_host_names=current_hosts.host_assignment_order,
        )
        logger.debug("[ray] getting driver IP")
        server_ip = socket.gethostbyname(socket.gethostname())
        self.run_env_vars = create_run_env_vars(
            server_ip, nics, global_rendezv_port, elastic=True)

        self.executable_cls = executable_cls
        self.executable_args = executable_args
        self.executable_kwargs = executable_kwargs


    def _create_resources(self, hostname: str):
        # Request a sliver of the node's own custom resource
        # (``node:<hostname>``) so the worker is pinned to that host.
        resources = dict(
            num_cpus=self.cpus_per_worker,
            num_gpus=int(self.use_gpu) * self.gpus_per_worker,
            resources={f"node:{hostname}": 0.01})
        return resources

    def _create_remote_worker(self, slot_info, worker_env_vars):
        hostname = slot_info.hostname
        loaded_worker_cls = self.remote_worker_cls.options(
            **self._create_resources(hostname))

        worker = loaded_worker_cls.remote()
        worker.update_env_vars.remote(worker_env_vars)
        worker.update_env_vars.remote(create_slot_env_vars(slot_info))
        if self.use_gpu:
            visible_devices = ",".join(
                [str(i) for i in range(slot_info.local_size)])
            worker.update_env_vars.remote({
                "CUDA_VISIBLE_DEVICES":
                    visible_devices
            })
        return worker

    def _create_spawn_worker_fn(self, return_results: List,
                                worker_fn: Callable,
                                queue: "ray.util.Queue") -> Callable:
        self.remote_worker_cls = ray.remote(BaseHorovodWorker)
        # event = register_shutdown_event()
        worker_env_vars = {}
        worker_env_vars.update(self.run_env_vars.copy())
        worker_env_vars.update(self.env_vars.copy())
        worker_env_vars.update({"PYTHONUNBUFFERED": "1"})

        def worker_loop(slot_info, events):
            def ping_worker(worker):
                # There is an odd edge case where a node can be removed
                # before the remote worker is started, leading to a failure
                # in trying to create the horovod mesh.
                try:
                    ping = worker.execute.remote(lambda _: 1)
                    ray.get(ping, timeout=10)
                except Exception as e:
                    logger.error(f"{slot_info.hostname}: Ping failed - {e}")
                    return False
                return True

            worker = self._create_remote_worker(slot_info, worker_env_vars)
            if not ping_worker(worker):
                return 1, time.time()

            ray.get(worker.set_queue.remote(queue))
            future = worker.execute.remote(worker_fn)

            result = None
            while result is None:
                try:
                    # TODO: make this event driven at some point.
                    retval = ray.get(future, timeout=0.1)
                    return_results.append((slot_info.rank, retval))
                    # Success
                    result = 0, time.time()
                except GetTimeoutError:
                    # Timeout
                    if any(e.is_set() for e in events):
                        ray.kill(worker)
                        result = 1, time.time()
                except Exception as e:
                    logger.error(f"{slot_info.hostname}[{slot_info.rank}]:{e}")
                    ray.kill(worker)
                    result = 1, time.time()
            logger.debug(f"Worker ({slot_info}) routine is done!")
            return result

        return worker_loop


    def run(self,
            fn: Callable[[Any], Any],
            args: Optional[List] = None,
            kwargs: Optional[Dict] = None,
            callbacks: Optional[List[Callable]] = None) -> List[Any]:
        """Executes the provided function on all workers.
        Args:
            fn: Target function that can be executed with arbitrary
                args and keyword arguments.
            args: List of arguments to be passed into the target function.
            kwargs: Dictionary of keyword arguments to be
                passed into the target function.
            callbacks: List of callables. Each callback must either
                be a callable function or a class that implements __call__.
                Every callback will be invoked on every value logged
                by the rank 0 worker.
        Returns:
            Deserialized return values from the target function.
        """
        args = args or []
        kwargs = kwargs or {}
        f = lambda _: fn(*args, **kwargs)
        return self._run_remote(f, callbacks=callbacks)

    def _run_remote(self,
                    worker_fn: Callable,
                    callbacks: Optional[List[Callable]] = None) -> List[Any]:
        """Executes the provided function on all workers.
        Args:
            worker_fn: Target elastic function that can be executed.
            callbacks: List of callables. Each callback must either
                be a callable function or a class that implements __call__.
                Every callback will be invoked on every value logged
                by the rank 0 worker.
        Returns:
            List of return values from every completed worker.
        """
        return_values = []
        from ray.util.queue import Queue
        import inspect
        args = inspect.getfullargspec(Queue).args
        if "actor_options" not in args:
            # Ray 1.1 and less
            _queue = Queue()
        else:
            _queue = Queue(actor_options={
                "num_cpus": 0,
                "resources": {
                    ray.state.current_node_id(): 0.001
                }
            })
        self.driver.start(
            self.np,
            self._create_spawn_worker_fn(return_values, worker_fn, _queue))

        def _process_calls(queue, callbacks, event):
            if not callbacks:
                return
            while queue.actor:
                if not queue.empty():
                    result = queue.get_nowait()
                    for c in callbacks:
                        c(result)
                # avoid slamming the CI
                elif event.is_set():
                    break
                time.sleep(0.1)

        try:
            event = threading.Event()
            _callback_thread = threading.Thread(
                target=_process_calls,
                args=(_queue, callbacks, event),
                daemon=True)
            _callback_thread.start()
            res = self.driver.get_results()
            event.set()
            if _callback_thread:
                _callback_thread.join(timeout=60)
        finally:
            if hasattr(_queue, "shutdown"):
                _queue.shutdown()
            else:
                done_ref = _queue.actor.__ray_terminate__.remote()
                done, not_done = ray.wait([done_ref], timeout=5)
                if not_done:
                    ray.kill(_queue.actor)
            self.driver.stop()

        if res.error_message is not None:
            raise RuntimeError(res.error_message)

        for name, value in sorted(
                res.worker_results.items(), key=lambda item: item[1][1]):
            exit_code, timestamp = value
            if exit_code != 0:
                raise RuntimeError(
                    'Horovod detected that one or more processes '
                    'exited with non-zero '
                    'status, thus causing the job to be terminated. '
                    'The first process '
                    'to do so was:\nProcess name: {name}\nExit code: {code}\n'
                    .format(name=name, code=exit_code))

        return_values = [
            value for k, value in sorted(return_values, key=lambda kv: kv[0])
        ]
        return return_values

    def run_remote(self,
                   fn: Callable[[Any], Any]) -> List[Any]:
        raise NotImplementedError("ObjectRefs cannot be returned from Elastic runs as the workers are ephemeral")

    def execute(self, fn: Callable[["executable_cls"], Any],
                callbacks: Optional[List[Callable]] = None) -> List[Any]:
        """Executes the provided function on all workers.
        Args:
            fn: Target function to be invoked on every object.
            callbacks: List of callables. Each callback must either
                be a callable function or a class that implements __call__.
                Every callback will be invoked on every value logged
                by the rank 0 worker.
        Returns:
            Deserialized return values from the target function.
        """
        return ray.get(self._run_remote(fn, callbacks=callbacks))
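The `callbacks` argument documented on `run` above is invoked for every value the rank-0 worker logs to the queue installed via `set_queue`. A hypothetical usage sketch (`adapter` stands in for a started `ElasticAdapter`; the name is illustrative):
```python
def print_metrics(value):
    # Invoked once for each object the rank-0 worker pushes to the queue.
    print("rank 0 logged:", value)

results = adapter.run(training_fn, callbacks=[print_metrics])
```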
