Engine: kwargs #1156

Merged · 10 commits · Apr 9, 2021
catalyst/engines/amp.py (93 changes: 70 additions & 23 deletions)
@@ -1,3 +1,5 @@
+from typing import Any, Dict, Union
+
 import torch
 from torch import nn
 import torch.cuda.amp as amp
@@ -10,6 +12,9 @@ class AMPEngine(DeviceEngine):

     Args:
         device: used device, default is `"cuda"`.
+        scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`.
+            Possible parameters:
+            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler

     Examples:

@@ -41,13 +46,19 @@ def get_engine(self):

     """

-    def __init__(self, device: str = "cuda"):
+    def __init__(self, device: str = "cuda", scaler_kwargs: Dict[str, Any] = None):
         """Init."""
         super().__init__(device)
-        self.scaler = amp.GradScaler()
+        if scaler_kwargs is None:
+            scaler_kwargs = {}
+        self.scaler_kwargs = scaler_kwargs
+        self.scaler = amp.GradScaler(**self.scaler_kwargs)

     def __repr__(self) -> str:  # noqa: D105
-        return f"{self.__class__.__name__}(device='{self.device}')"
+        return (
+            f"{self.__class__.__name__}(device='{self.device}', "
+            f"scaler_kwargs={self.scaler_kwargs})"
+        )

     def backward_loss(self, loss, model, optimizer) -> None:
         """Abstraction over ``loss.backward()`` step."""
@@ -67,6 +78,11 @@ def autocast(self):
 class DataParallelAMPEngine(AMPEngine):
     """AMP multi-gpu training device engine.

+    Args:
+        scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`.
+            Possible parameters:
+            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+
     Examples:

     .. code-block:: python
@@ -96,13 +112,16 @@ def get_engine(self):

     """

-    def __init__(self):
+    def __init__(self, scaler_kwargs: Dict[str, Any] = None):
         """Init."""
-        super().__init__(f"cuda:{torch.cuda.current_device()}")
+        super().__init__(f"cuda:{torch.cuda.current_device()}", scaler_kwargs)
         self.device_count = torch.cuda.device_count()

     def __repr__(self) -> str:  # noqa: D105
-        return f"{self.__class__.__name__}(device='{self.device}')"
+        return (
+            f"{self.__class__.__name__}(device='{self.device}', "
+            f"scaler_kwargs={self.scaler_kwargs})"
+        )

     def init_components(
         self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
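
The same pattern applies to the multi-GPU variant. A hedged usage sketch (values illustrative): `DataParallelAMPEngine` resolves the device itself from `torch.cuda.current_device()`, so only the scaler configuration is exposed:

.. code-block:: python

    from catalyst import dl

    class MyRunner(dl.IRunner):
        # ...
        def get_engine(self):
            # device is chosen internally; growth_factor and backoff_factor
            # are standard GradScaler knobs controlling loss-scale adjustment
            return dl.DataParallelAMPEngine(
                scaler_kwargs={"growth_factor": 1.5, "backoff_factor": 0.25},
            )
        # ...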
@@ -133,13 +152,17 @@ class DistributedDataParallelAMPEngine(DistributedDataParallelEngine):
     """Distributed AMP multi-gpu training device engine.

     Args:
-        address: process address to use
-            (required for PyTorch backend), default is `"localhost"`.
-        port: process port to listen
-            (required for PyTorch backend), default is `"12345"`.
-        backend: multiprocessing backend to use,
-            default is `"nccl"`.
-        world_size: number of processes.
+        address: address to use for backend.
+        port: port to use for backend.
+        ddp_kwargs: parameters for `torch.nn.parallel.DistributedDataParallel`.
+            More info here:
+            https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel
+        process_group_kwargs: parameters for `torch.distributed.init_process_group`.
+            More info here:
+            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`.
+            Possible parameters:
+            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler

     Examples:

@@ -150,7 +173,13 @@ class DistributedDataParallelAMPEngine(DistributedDataParallelEngine):
         class MyRunner(dl.IRunner):
             # ...
             def get_engine(self):
-                return dl.DistributedDataParallelAMPEngine(port=12345)
+                return dl.DistributedDataParallelAMPEngine(
+                    address="0.0.0.0",
+                    port=23234,
+                    ddp_kwargs={"find_unused_parameters": False},
+                    process_group_kwargs={"port": 12345},
+                    scaler_kwargs={"growth_factor": 1.5}
+                )
             # ...

     .. code-block:: yaml
@@ -164,7 +193,14 @@ def get_engine(self):

         engine:
             _target_: DistributedDataParallelAMPEngine
-            port: 12345
+            address: 0.0.0.0
+            port: 23234
+            ddp_kwargs:
+                find_unused_parameters: false
+            process_group_kwargs:
+                port: 12345
+            scaler_kwargs:
+                growth_factor: 1.5

         stages:
             ...
@@ -173,20 +209,31 @@ def get_engine(self):

     def __init__(
         self,
-        address: str = "localhost",
-        port: str = "12345",
-        backend: str = "nccl",
-        world_size: int = None,
+        address: str = None,
+        port: Union[str, int] = None,
+        ddp_kwargs: Dict[str, Any] = None,
+        process_group_kwargs: Dict[str, Any] = None,
+        scaler_kwargs: Dict[str, Any] = None,
     ):
         """Init."""
-        super().__init__(address, port, backend, world_size)
-        self.scaler = amp.GradScaler()
+        super().__init__(
+            address=address,
+            port=port,
+            ddp_kwargs=ddp_kwargs,
+            process_group_kwargs=process_group_kwargs,
+        )
+        if scaler_kwargs is None:
+            scaler_kwargs = {}
+        self.scaler_kwargs = scaler_kwargs
+        self.scaler = amp.GradScaler(**self.scaler_kwargs)

     def __repr__(self):  # noqa: D105
         return (
             f"{self.__class__.__name__}(address={self.address}, "
-            f"port={self.port}, backend='{self.backend}',"
-            f"rank={self._rank}, world_size={self._world_size})"
+            f"port={self.port}, "
+            f"ddp_kwargs={self.ddp_kwargs}, "
+            f"process_group_kwargs={self.process_group_kwargs}, "
+            f"scaler_kwargs={self.scaler_kwargs})"
         )

     def backward_loss(self, loss, model, optimizer) -> None:
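
Finally, a small standalone sketch (plain PyTorch, no catalyst; `make_scaler` is a hypothetical helper, not part of this PR) of the `None`-default pattern the PR applies in every `__init__`, which avoids the mutable-default-argument pitfall:

.. code-block:: python

    from typing import Any, Dict

    import torch.cuda.amp as amp

    def make_scaler(scaler_kwargs: Dict[str, Any] = None) -> amp.GradScaler:
        # None is the sentinel: a literal `{}` default would be shared
        # across calls and could be mutated in place
        if scaler_kwargs is None:
            scaler_kwargs = {}
        return amp.GradScaler(**scaler_kwargs)

    # on a CUDA-enabled build, the kwargs are visible on the scaler:
    scaler = make_scaler({"growth_factor": 1.5})
    assert scaler.get_growth_factor() == 1.5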