diff --git a/catalyst/engines/amp.py b/catalyst/engines/amp.py index 1d0f850e3b..871e02cc52 100644 --- a/catalyst/engines/amp.py +++ b/catalyst/engines/amp.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Union + import torch from torch import nn import torch.cuda.amp as amp @@ -10,6 +12,9 @@ class AMPEngine(DeviceEngine): Args: device: used device, default is `"cuda"`. + scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`. + Possible parameters: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler Examples: @@ -41,13 +46,19 @@ def get_engine(self): """ - def __init__(self, device: str = "cuda"): + def __init__(self, device: str = "cuda", scaler_kwargs: Dict[str, Any] = None): """Init.""" super().__init__(device) - self.scaler = amp.GradScaler() + if scaler_kwargs is None: + scaler_kwargs = {} + self.scaler_kwargs = scaler_kwargs + self.scaler = amp.GradScaler(**self.scaler_kwargs) def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(device='{self.device}')" + return ( + f"{self.__class__.__name__}(device='{self.device}', " + f"scaler_kwargs={self.scaler_kwargs})" + ) def backward_loss(self, loss, model, optimizer) -> None: """Abstraction over ``loss.backward()`` step.""" @@ -67,6 +78,11 @@ def autocast(self): class DataParallelAMPEngine(AMPEngine): """AMP multi-gpu training device engine. + Args: + scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`. + Possible parameters: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler + Examples: .. code-block:: python @@ -96,13 +112,16 @@ def get_engine(self): """ - def __init__(self): + def __init__(self, scaler_kwargs: Dict[str, Any] = None): """Init.""" - super().__init__(f"cuda:{torch.cuda.current_device()}") + super().__init__(f"cuda:{torch.cuda.current_device()}", scaler_kwargs) self.device_count = torch.cuda.device_count() def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(device='{self.device}')" + return ( + f"{self.__class__.__name__}(device='{self.device}', " + f"scaler_kwargs={self.scaler_kwargs})" + ) def init_components( self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None, @@ -133,13 +152,17 @@ class DistributedDataParallelAMPEngine(DistributedDataParallelEngine): """Distributed AMP multi-gpu training device engine. Args: - address: process address to use - (required for PyTorch backend), default is `"localhost"`. - port: process port to listen - (required for PyTorch backend), default is `"12345"`. - backend: multiprocessing backend to use, - default is `"nccl"`. - world_size: number of processes. + address: address to use for backend. + port: port to use for backend. + ddp_kwargs: parameters for `torch.nn.parallel.DistributedDataParallel`. + More info here: + https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel + process_group_kwargs: parameters for `torch.distributed.init_process_group`. + More info here: + https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group + scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`. + Possible parameters: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler Examples: @@ -150,7 +173,13 @@ class DistributedDataParallelAMPEngine(DistributedDataParallelEngine): class MyRunner(dl.IRunner): # ... 
def get_engine(self): - return dl.DistributedDataParallelAMPEngine(port=12345) + return dl.DistributedDataParallelAMPEngine( + address="0.0.0.0", + port=23234, + ddp_kwargs={"find_unused_parameters": False}, + process_group_kwargs={"port": 12345}, + scaler_kwargs={"growth_factor": 1.5} + ) # ... .. code-block:: yaml @@ -164,7 +193,14 @@ def get_engine(self): engine: _target_: DistributedDataParallelAMPEngine - port: 12345 + address: 0.0.0.0 + port: 23234 + ddp_kwargs: + find_unused_parameters: false + process_group_kwargs: + port: 12345 + scaler_kwargs: + growth_factor: 1.5 stages: ... @@ -173,20 +209,31 @@ def get_engine(self): def __init__( self, - address: str = "localhost", - port: str = "12345", - backend: str = "nccl", - world_size: int = None, + address: str = None, + port: Union[str, int] = None, + ddp_kwargs: Dict[str, Any] = None, + process_group_kwargs: Dict[str, Any] = None, + scaler_kwargs: Dict[str, Any] = None, ): """Init.""" - super().__init__(address, port, backend, world_size) - self.scaler = amp.GradScaler() + super().__init__( + address=address, + port=port, + ddp_kwargs=ddp_kwargs, + process_group_kwargs=process_group_kwargs, + ) + if scaler_kwargs is None: + scaler_kwargs = {} + self.scaler_kwargs = scaler_kwargs + self.scaler = amp.GradScaler(**self.scaler_kwargs) def __repr__(self): # noqa: D105 return ( f"{self.__class__.__name__}(address={self.address}, " - f"port={self.port}, backend='{self.backend}'," - f"rank={self._rank}, world_size={self._world_size})" + f"port={self.port}, " + f"ddp_kwargs={self.ddp_kwargs}, " + f"process_group_kwargs={self.process_group_kwargs}, " + f"scaler_kwargs={self.scaler_kwargs})" ) def backward_loss(self, loss, model, optimizer) -> None: diff --git a/catalyst/engines/apex.py b/catalyst/engines/apex.py index c46609afad..eb8cb80c88 100644 --- a/catalyst/engines/apex.py +++ b/catalyst/engines/apex.py @@ -1,8 +1,10 @@ -from typing import Dict, Union +from typing import Any, Dict, Union from collections import OrderedDict +import os import torch from torch import nn +import torch.distributed as dist from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine from catalyst.settings import SETTINGS @@ -126,25 +128,11 @@ class APEXEngine(DeviceEngine): Args: device: use device, default is `"cuda"`. - opt_level: optimization level, should be one of ``"O0"``, - ``"O1"``, ``"O2"`` or ``"O3"``. - - - ``"O0"`` - no-op training - - ``"O1"`` - mixed precision (FP16) training (default) - - ``"O2"`` - "almost" mixed precision training - - ``"O3"`` - another implementation of mixed precision training - - Details about levels can be found here: - https://nvidia.github.io/apex/amp.html#opt-levels - keep_batchnorm_fp32: To enhance precision and enable CUDNN batchnorm - (which improves performance), - it’s often beneficial to keep batchnorm weights in FP32 even - if the rest of the model is FP16. - loss_scale: If loss_scale is a float value, - use this value as the static (fixed) loss scale - If loss_scale is the string "dynamic", - adaptively adjust the loss scale over time. - Dynamic loss scale adjustments are performed by Amp automatically. + apex_kwargs: parameters for `apex.amp.initialize` + except models and optimizers (they will be forwared automatically). 
+ + Docs for `apex.amp.initialize`: + https://nvidia.github.io/apex/amp.html#apex.amp.initialize Examples: @@ -177,21 +165,16 @@ def get_engine(self): """ - def __init__( - self, - device: str = "cuda", - opt_level: str = "O1", - keep_batchnorm_fp32: bool = None, - loss_scale: Union[float, str] = None, - ): + def __init__(self, device: str = "cuda", apex_kwargs: Dict[str, Any] = None): """Init.""" super().__init__(device) - self.opt_level = opt_level - self.keep_batchnorm_fp32 = keep_batchnorm_fp32 - self.loss_scale = loss_scale + if apex_kwargs is None: + apex_kwargs = {} + self.apex_kwargs = apex_kwargs def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(device='{self.device}',opt_level='{self.opt_level}')" + args_list = [f"device='{self.device}'", f"apex_kwargs={self.apex_kwargs}"] + return f"{self.__class__.__name__}(" + ",".join(args_list) + ")" def init_components( self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None, @@ -212,13 +195,7 @@ def init_components( # from official docs: # https://nvidia.github.io/apex/amp.html#opt-levels-and-properties - model, optimizer = _initialize_apex( - model, - optimizer, - opt_level=self.opt_level, - keep_batchnorm_fp32=self.keep_batchnorm_fp32, - loss_scale=self.loss_scale, - ) + model, optimizer = _initialize_apex(model, optimizer, **self.apex_kwargs) # scheduler scheduler = scheduler_fn() @@ -299,16 +276,11 @@ class DataParallelApexEngine(APEXEngine): """Apex multi-gpu training device engine. Args: - opt_level: optimization level, should be one of ``"O0"``, - ``"O1"``, ``"O2"`` or ``"O3"``. - - - ``"O0"`` - no-op training - - ``"O1"`` - mixed precision (FP16) training (default) - - ``"O2"`` - "almost" mixed precision training - - ``"O3"`` - another implementation of mixed precision training + apex_kwargs: parameters for `apex.amp.initialize` + except models and optimizers (they will be forwared automatically). - Details about levels can be found here: - https://nvidia.github.io/apex/amp.html#opt-levels + Docs for `apex.amp.initialize`: + https://nvidia.github.io/apex/amp.html#apex.amp.initialize Examples: @@ -340,13 +312,13 @@ def get_engine(self): """ - def __init__(self, opt_level: str = "O1"): + def __init__(self, apex_kwargs: Dict[str, Any]): """Init.""" - super().__init__(f"cuda:{torch.cuda.current_device()}", opt_level) + super().__init__(f"cuda:{torch.cuda.current_device()}", apex_kwargs) self.device_count = torch.cuda.device_count() def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(device='{self.device}',opt_level='{self.opt_level}')" + return f"{self.__class__.__name__}(device='{self.device}', apex_kwargs={self.apex_kwargs})" def init_components( self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None, @@ -364,7 +336,7 @@ def init_components( optimizer = self.sync_device(optimizer) model, optimizer = _wrap_into_data_parallel_with_apex( - model, optimizer, distributed_params={"opt_level": self.opt_level} + model, optimizer, distributed_params=self.apex_kwargs ) # scheduler @@ -377,35 +349,19 @@ class DistributedDataParallelApexEngine(DistributedDataParallelEngine): """Distributed Apex MultiGPU training device engine. Args: - address: process address to use - (required for PyTorch backend), default is `"localhost"`. - port: process port to listen - (required for PyTorch backend), default is `"12345"`. - backend: multiprocessing backend to use, - default is `"nccl"`. - world_size: number of processes. 
- opt_level: optimization level, should be one of ``"O0"``, - ``"O1"``, ``"O2"`` or ``"O3"``. - - - ``"O0"`` - no-op training - - ``"O1"`` - mixed precision (FP16) training (default) - - ``"O2"`` - "almost" mixed precision training - - ``"O3"`` - another implementation of mixed precision training - - Details about levels can be found here: - https://nvidia.github.io/apex/amp.html#opt-levels - - keep_batchnorm_fp32: To enhance precision and - enable CUDNN batchnorm (which improves performance), - it’s often beneficial to keep batchnorm weights in FP32 even - if the rest of the model is FP16. - loss_scale: If loss_scale is a float value, - use this value as the static (fixed) loss scale. - If loss_scale is the string "dynamic", - adaptively adjust the loss scale over time. - Dynamic loss scale adjustments are performed by Amp automatically. - delay_all_reduce (bool): boolean flag for delayed all reduce, - default is `True`. + address: address to use for backend. + port: port to use for backend. + ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`. + More info here: + https://nvidia.github.io/apex/parallel.html#apex.parallel.DistributedDataParallel + process_group_kwargs: parameters for `torch.distributed.init_process_group`. + More info here: + https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group + apex_kwargs: parameters for `apex.amp.initialize` + except models and optimizers (they will be forwared automatically). + + Docs for `apex.amp.initialize`: + https://nvidia.github.io/apex/amp.html#apex.amp.initialize Examples: @@ -417,8 +373,11 @@ class MyRunner(dl.IRunner): # ... def get_engine(self): return dl.DistributedDataParallelApexEngine( - port=12345, - opt_level="O1" + address="0.0.0.0", + port=23234, + ddp_kwargs={"allreduce_always_fp32": True}, + process_group_kwargs={"backend": "nccl"}, + apex_kwargs={"opt_level": "O1"}, ) # ... @@ -433,8 +392,14 @@ def get_engine(self): engine: _target_: DistributedDataParallelApexEngine - port: 12345 - opt_level: O1 + address: 0.0.0.0 + port: 23234 + ddp_kwargs: + allreduce_always_fp32: true + process_group_kwargs: + backend: nccl + apex_kwargs: + opt_level: O1 stages: ... 
@@ -442,36 +407,53 @@ def get_engine(self): def __init__( self, - address: str = "localhost", - port: str = "12345", - backend: str = "nccl", - world_size: int = None, - opt_level: str = "O1", - keep_batchnorm_fp32: bool = None, - loss_scale: Union[float, str] = None, - delay_all_reduce: bool = True, + address: str = None, + port: Union[str, int] = None, + ddp_kwargs: Dict[str, Any] = None, + process_group_kwargs: Dict[str, Any] = None, + apex_kwargs: Dict[str, Any] = None, ): """Init.""" - super().__init__() - self.address = address - self.port = port - self.backend = backend - self._rank = 0 - self._world_size = world_size or torch.cuda.device_count() - self.device = None - self.opt_level = opt_level - self.delay_all_reduce = delay_all_reduce - self.keep_batchnorm_fp32 = keep_batchnorm_fp32 - self.loss_scale = loss_scale + super().__init__( + address=address, port=port, ddp_kwargs=None, process_group_kwargs=process_group_kwargs + ) + if ddp_kwargs is None: + ddp_kwargs = {} + self.ddp_kwargs = ddp_kwargs + if apex_kwargs is None: + apex_kwargs = {} + self.apex_kwargs = apex_kwargs def __repr__(self): # noqa: D105 return ( f"{self.__class__.__name__}(address={self.address}, " - f"port={self.port}, backend='{self.backend}', " - f"rank={self._rank}, world_size={self._world_size}, " - f"opt_level='{self.opt_level}')" + f"port={self.port}, " + f"ddp_kwargs={self.ddp_kwargs}, " + f"process_group_kwargs={self.process_group_kwargs}, " + f"apex_kwargs={self.apex_kwargs})" ) + def setup_process(self, rank: int = -1, world_size: int = 1): + """Initialize DDP variables and processes. + + Args: + rank: process rank. Default is `-1`. + world_size: number of devices in netwok to expect for train. + Default is `1`. + """ + self._rank = rank + self._world_size = world_size + + self.process_group_kwargs["rank"] = rank + self.process_group_kwargs["world_size"] = world_size + os.environ["MASTER_ADDR"] = str(self.address) + os.environ["MASTER_PORT"] = str(self.port) + + dist.init_process_group(**self.process_group_kwargs) + + torch.cuda.set_device(int(self._rank)) + self.device = f"cuda:{int(self._rank)}" + def init_components( self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None, ): @@ -485,14 +467,8 @@ def init_components( optimizer = optimizer_fn() optimizer = self.sync_device(optimizer) - model, optimizer = amp.initialize( - model, - optimizer, - opt_level=self.opt_level, - keep_batchnorm_fp32=self.keep_batchnorm_fp32, - loss_scale=self.loss_scale, - ) - model = ApexDistributedDataParallel(model, delay_allreduce=self.delay_all_reduce) + model, optimizer = amp.initialize(model, optimizer, **self.apex_kwargs) + model = ApexDistributedDataParallel(model, **self.ddp_kwargs) scheduler = scheduler_fn() scheduler = self.sync_device(scheduler) diff --git a/catalyst/engines/tests/test_apex.py b/catalyst/engines/tests/test_apex.py index 83be6b8b5a..be9ba35505 100644 --- a/catalyst/engines/tests/test_apex.py +++ b/catalyst/engines/tests/test_apex.py @@ -45,7 +45,7 @@ def __init__(self, logdir, device, opt_level): self._opt_level = opt_level def get_engine(self): - return APEXEngine(self._device, self._opt_level) + return APEXEngine(self._device, apex_kwargs=dict(opt_level=self._opt_level)) def get_callbacks(self, stage: str): return { @@ -116,7 +116,7 @@ def train_from_config(device, opt_level): "engine": { "_target_": "APEXEngine", "device": device, - "opt_level": opt_level.upper(), + "apex_kwargs": {"opt_level": opt_level.upper()}, }, "args": {"logdir": logdir}, "stages": { diff --git 
a/catalyst/engines/tests/test_distributed.py b/catalyst/engines/tests/test_distributed.py index 7e328f7346..7165c1c9c7 100644 --- a/catalyst/engines/tests/test_distributed.py +++ b/catalyst/engines/tests/test_distributed.py @@ -41,7 +41,9 @@ def __init__(self, logdir): self._logdir = logdir def get_engine(self): - return DistributedDataParallelEngine(port=DDP_ADDRESS + random.randint(1, 100)) + return DistributedDataParallelEngine( + port=DDP_ADDRESS + random.randint(1, 100), process_group_kwargs={"backend": "nccl"} + ) def get_callbacks(self, stage: str): return { @@ -125,6 +127,7 @@ def test_config_ddp_engine(): "engine": { "_target_": "DistributedDataParallelEngine", "port": DDP_ADDRESS + random.randint(100, 200), + "process_group_kwargs": {"backend": "nccl"}, }, "loggers": {"console": {"_target_": "ConsoleLogger"}}, "stages": { diff --git a/catalyst/engines/tests/test_distributed_amp.py b/catalyst/engines/tests/test_distributed_amp.py index 562000765e..2441f325a0 100644 --- a/catalyst/engines/tests/test_distributed_amp.py +++ b/catalyst/engines/tests/test_distributed_amp.py @@ -43,7 +43,9 @@ def __init__(self, logdir): self._logdir = logdir def get_engine(self): - return DistributedDataParallelAMPEngine(port=DDP_ADDRESS + random.randint(1, 100)) + return DistributedDataParallelAMPEngine( + port=DDP_ADDRESS + random.randint(1, 100), process_group_kwargs={"backend": "nccl"} + ) def get_callbacks(self, stage: str): return { @@ -127,6 +129,7 @@ def test_train_with_config_experiment_distributed_parallel_amp_device(): "engine": { "_target_": "DistributedDataParallelAMPEngine", "port": DDP_ADDRESS + random.randint(100, 200), + "process_group_kwargs": {"backend": "nccl"}, }, "loggers": {"console": {"_target_": "ConsoleLogger"}}, "stages": { diff --git a/catalyst/engines/tests/test_distributed_apex.py b/catalyst/engines/tests/test_distributed_apex.py index 4aa5260331..7de4dbfc9a 100644 --- a/catalyst/engines/tests/test_distributed_apex.py +++ b/catalyst/engines/tests/test_distributed_apex.py @@ -46,7 +46,11 @@ def __init__(self, logdir, opt_level, port="12345"): self._port = port def get_engine(self): - return DistributedDataParallelApexEngine(port=self._port, opt_level=self._opt_level) + return DistributedDataParallelApexEngine( + port=DDP_ADDRESS + random.randint(1, 100), + process_group_kwargs={"backend": "nccl"}, + apex_kwargs=dict(opt_level=self._opt_level), + ) def get_callbacks(self, stage: str): return { @@ -116,8 +120,9 @@ def train_from_config(port, logdir, opt_lvl): "model": {"_target_": "DummyModel", "in_features": 4, "out_features": 2}, "engine": { "_target_": "DistributedDataParallelApexEngine", - "port": port, - "opt_level": opt, + "port": DDP_ADDRESS + random.randint(100, 200), + "process_group_kwargs": {"backend": "nccl"}, + "apex_kwargs": {"opt_level": opt}, }, "loggers": {"console": {"_target_": "ConsoleLogger"}}, "stages": { diff --git a/catalyst/engines/tests/test_parallel_apex.py b/catalyst/engines/tests/test_parallel_apex.py index 6262ed8e6f..c19aed5c13 100644 --- a/catalyst/engines/tests/test_parallel_apex.py +++ b/catalyst/engines/tests/test_parallel_apex.py @@ -44,7 +44,7 @@ def __init__(self, logdir, opt_level): self._opt_level = opt_level def get_engine(self): - return DataParallelApexEngine(self._opt_level) + return DataParallelApexEngine(apex_kwargs=dict(opt_level=self._opt_level)) def get_callbacks(self, stage: str): return { @@ -111,7 +111,10 @@ def train_from_config(opt_level): config={ "args": {"logdir": logdir}, "model": {"_target_": "DummyModel", 
"in_features": 4, "out_features": 2}, - "engine": {"_target_": "DataParallelApexEngine", "opt_level": opt_level}, + "engine": { + "_target_": "DataParallelApexEngine", + "apex_kwargs": {"opt_level": opt_level}, + }, "args": {"logdir": logdir}, "stages": { "stage1": { diff --git a/catalyst/engines/torch.py b/catalyst/engines/torch.py index 176c9d1c34..99f9daecad 100644 --- a/catalyst/engines/torch.py +++ b/catalyst/engines/torch.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Mapping, Union +import copy import os import torch @@ -264,13 +265,14 @@ class DistributedDataParallelEngine(DeviceEngine): """Distributed MultiGPU training device engine. Args: - address: process address to use - (required for PyTorch backend), default is `"localhost"`. - port: process port to listen - (required for PyTorch backend), default is `"12345"`. - backend: multiprocessing backend to use, - default is `"nccl"`. - world_size: number of processes. + address: address to use for backend. + port: port to use for backend. + ddp_kwargs: parameters for `torch.nn.parallel.DistributedDataParallel`. + More info here: + https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel + process_group_kwargs: parameters for `torch.distributed.init_process_group`. + More info here: + https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group Examples: @@ -281,7 +283,12 @@ class DistributedDataParallelEngine(DeviceEngine): class MyRunner(dl.IRunner): # ... def get_engine(self): - return dl.DistributedDataParallelEngine(port=12345) + return dl.DistributedDataParallelEngine( + address="0.0.0.0", + port=23234, + ddp_kwargs={"find_unused_parameters": False}, + process_group_kwargs={"backend": "nccl"}, + ) # ... .. code-block:: yaml @@ -295,7 +302,12 @@ def get_engine(self): engine: _target_: DistributedDataParallelEngine - port: 12345 + address: 0.0.0.0 + port: 23234 + ddp_kwargs: + find_unused_parameters: false + process_group_kwargs: + backend: nccl stages: ... 
@@ -304,25 +316,41 @@ def get_engine(self): def __init__( self, - address: str = "localhost", - port: str = "12345", - backend: str = "nccl", - world_size: int = None, + address: str = None, + port: Union[str, int] = None, + ddp_kwargs: Dict[str, Any] = None, + process_group_kwargs: Dict[str, Any] = None, ): """Init.""" super().__init__() - self.address = address - self.port = port - self.backend = backend + self.address = address or "localhost" + self.port = port or 12345 self._rank = 0 - self._world_size = world_size or torch.cuda.device_count() self.device = None + if ddp_kwargs is None: + ddp_kwargs = {} + self.ddp_kwargs = copy.deepcopy(ddp_kwargs) + + if process_group_kwargs is None: + process_group_kwargs = {} + self.process_group_kwargs = copy.deepcopy(process_group_kwargs) + # add missing arguments + if "backend" not in self.process_group_kwargs: + self.process_group_kwargs["backend"] = "nccl" + if "world_size" not in self.process_group_kwargs: + self.process_group_kwargs["world_size"] = torch.cuda.device_count() + + self._world_size = ( + self.process_group_kwargs.get("world_size", None) or torch.cuda.device_count() + ) + def __repr__(self): # noqa: D105 return ( f"{self.__class__.__name__}(address={self.address}, " - f"port={self.port}, backend='{self.backend}'," - f"rank={self._rank}, world_size={self._world_size})" + f"port={self.port}, " + f"ddp_kwargs={self.ddp_kwargs}, " + f"process_group_kwargs={self.process_group_kwargs})" ) @property @@ -365,11 +393,18 @@ def setup_process(self, rank: int = -1, world_size: int = 1): """ self._rank = rank self._world_size = world_size + + self.process_group_kwargs["rank"] = rank + self.process_group_kwargs["world_size"] = world_size os.environ["MASTER_ADDR"] = str(self.address) os.environ["MASTER_PORT"] = str(self.port) - dist.init_process_group(self.backend, rank=self.rank, world_size=self.world_size) + + dist.init_process_group(**self.process_group_kwargs) + torch.cuda.set_device(int(self._rank)) self.device = f"cuda:{int(self._rank)}" + if "device_ids" not in self.ddp_kwargs: + self.ddp_kwargs["device_ids"] = [self.device] def cleanup_process(self): """Clean DDP variables and processes.""" @@ -404,13 +439,10 @@ def init_components( """Inits the runs components.""" model = model_fn() model = self.sync_device(model) - # NOTE: do not forget to wrap a model in DDP if isinstance(model, nn.Module): - model = DistributedDataParallel(model, device_ids=[self.device]) + model = DistributedDataParallel(model, **self.ddp_kwargs) elif isinstance(model, dict): - model = { - k: DistributedDataParallel(v, device_ids=[self.device]) for k, v in model.items() - } + model = {k: DistributedDataParallel(v, **self.ddp_kwargs) for k, v in model.items()} # criterion criterion = criterion_fn() criterion = self.sync_device(criterion)
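Usage note (a minimal sketch, not part of the patch): the snippet below illustrates how the kwargs-style constructors introduced in this diff are intended to be called. It assumes a CUDA-capable machine with this branch installed (and NVIDIA apex for the apex engine); all parameter values are illustrative only.

.. code-block:: python

    from catalyst.engines.amp import AMPEngine
    from catalyst.engines.apex import APEXEngine
    from catalyst.engines.torch import DistributedDataParallelEngine

    # scaler_kwargs is forwarded as-is to torch.cuda.amp.GradScaler
    amp_engine = AMPEngine(
        device="cuda",
        scaler_kwargs={"init_scale": 2.0 ** 12, "growth_factor": 1.5},
    )
    print(amp_engine)  # repr now reports scaler_kwargs as well

    # apex_kwargs is forwarded to apex.amp.initialize (models and optimizers
    # are passed in by the engine itself); requires NVIDIA apex to be installed
    apex_engine = APEXEngine(device="cuda", apex_kwargs={"opt_level": "O1"})

    # ddp_kwargs goes to torch.nn.parallel.DistributedDataParallel and
    # process_group_kwargs to torch.distributed.init_process_group;
    # omitted keys fall back to the defaults set in __init__
    # (backend="nccl", world_size=torch.cuda.device_count())
    ddp_engine = DistributedDataParallelEngine(
        address="0.0.0.0",
        port=23234,
        ddp_kwargs={"find_unused_parameters": False},
        process_group_kwargs={"backend": "nccl"},
    )
    print(ddp_engine)

In practice the engines are returned from ``IRunner.get_engine()`` (as in the docstring examples above) rather than instantiated by hand; constructing them directly like this is mainly useful for checking the forwarded kwargs via ``__repr__``.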