Engine docs #1141

Merged
merged 8 commits into from Mar 28, 2021
42 changes: 28 additions & 14 deletions catalyst/core/engine.py
@@ -6,7 +6,7 @@


@contextmanager
def nullcontext(enter_result=None):
def nullcontext(enter_result: Any = None):
"""Context handler."""
yield enter_result
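
For context, ``nullcontext`` simply yields its argument unchanged, so an engine that does no mixed-precision work can hand it back from ``autocast`` and the training loop can always use the same ``with``-block. A tiny illustration using the stdlib equivalent (``contextlib.nullcontext``, Python 3.7+):

.. code-block:: python

    from contextlib import nullcontext

    # nothing is scaled or altered inside this block
    with nullcontext() as ctx:
        assert ctx is None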

@@ -37,8 +37,7 @@ def rank(self) -> int:
@property
@abstractmethod
def world_size(self) -> int:
"""Process world size for distributed training."""
# only for ddp
"""Process world size for distributed training."""
pass

@property
@@ -49,26 +48,28 @@ def is_ddp(self) -> bool:
@property
def is_master_process(self) -> bool:
"""Checks if a process is master process.
Should be implemented only for DDP setup in other cases should always return True.
Should be implemented only for distributed training (ddp).
For non-distributed training it should always return `True`.

Returns:
`True` if current process is a master process, otherwise `False`.
`True` if the current process is the master process, otherwise `False`.
"""
return True

@property
def is_worker_process(self) -> bool:
"""Checks if a process is worker process.
Should be implemented only for DDP setup in other cases should always return False.
Should be implemented only for distributed training (ddp).
For non-distributed training it should always return `False`.

Returns:
`True` if current process is a worker process, otherwise `False`.
`True` if the current process is a worker process, otherwise `False`.
"""
return False
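
For intuition, a distributed engine would typically derive both flags from the process rank; a minimal sketch under that assumption (hypothetical override, not the actual ddp engine code):

.. code-block:: python

    # assumes `self.rank` is -1 for non-distributed runs
    # and 0..world_size-1 inside a ddp process group
    @property
    def is_master_process(self) -> bool:
        return self.rank <= 0  # rank 0 (or a non-ddp run) acts as master

    @property
    def is_worker_process(self) -> bool:
        return self.rank > 0  # every non-zero rank is a worker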

@abstractmethod
def sync_device(self, tensor_or_module: Any) -> Any:
"""Moves ``tensor_or_module`` to Engine's deivce.
"""Moves ``tensor_or_module`` to Engine's device.

Args:
tensor_or_module: tensor or module to move
@@ -89,23 +90,35 @@ def init_components(

@abstractmethod
def deinit_components(self):
"""Deinits the runs components."""
# only for ddp
"""Deinits the runs components.
In distributed mode should destroy process group.
"""
pass
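
As a rough illustration of the distributed case, ``deinit_components`` would usually tear down the default process group via ``torch.distributed``; a minimal sketch, assuming plain PyTorch ddp:

.. code-block:: python

    import torch.distributed as dist

    def deinit_components(self):
        # hypothetical ddp variant: wait for all ranks to finish,
        # then destroy the default process group
        dist.barrier()
        dist.destroy_process_group()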

@abstractmethod
def zero_grad(self, loss, model, optimizer) -> None:
"""Abstraction over ``model.zero_grad()`` step."""
"""Abstraction over ``model.zero_grad()`` step.
Should be overloaded in cases when required to set arguments
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
for ``model.zero_grad()`` like `set_to_none=True` or
you need to use custom scheme which replaces/improves
`.zero_grad()` method.
"""
pass
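
For example, an engine could override ``zero_grad`` to reset gradients to ``None`` instead of zeroing them in place; a minimal sketch, assuming the ``set_to_none`` argument available in PyTorch 1.7+:

.. code-block:: python

    def zero_grad(self, loss, model, optimizer) -> None:
        # hypothetical override on an engine subclass: setting grads to
        # None skips the memset and can reduce memory traffic slightly
        model.zero_grad(set_to_none=True)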

@abstractmethod
def backward_loss(self, loss, model, optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
"""Abstraction over ``loss.backward()`` step.
Should be overloaded in cases when required loss scaling.
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
Examples - APEX and AMP.
"""
pass
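
As an illustration of the loss-scaling case, an AMP-style engine would route the backward pass through a ``torch.cuda.amp.GradScaler``; a minimal sketch, assuming the engine keeps a ``self.scaler`` instance:

.. code-block:: python

    def backward_loss(self, loss, model, optimizer) -> None:
        # hypothetical AMP variant: scale the loss before backward
        # so small fp16 gradients do not underflow
        self.scaler.scale(loss).backward()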

@abstractmethod
def optimizer_step(self, loss, model, optimizer) -> None:
"""Abstraction over ``optimizer.step()`` step."""
"""Abstraction over ``optimizer.step()`` step.
Should be overloaded in cases when required gradient scaling.
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
Example - AMP.
"""
pass
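
In the same spirit, the gradient-scaling case would step the optimizer through the scaler; a minimal sketch, again assuming a ``self.scaler`` attribute:

.. code-block:: python

    def optimizer_step(self, loss, model, optimizer) -> None:
        # hypothetical AMP variant: step through the scaler (which
        # unscales gradients first), then update the scale factor
        self.scaler.step(optimizer)
        self.scaler.update()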

@abstractmethod
@@ -174,7 +187,8 @@ def load_checkpoint(self, path: str) -> Dict:
pass

def autocast(self, *args, **kwargs):
"""AMP scaling context. Default autocast context does not scale anything.
"""AMP scaling context.
Default autocast context does not scale anything.

Args:
*args: some args
98 changes: 94 additions & 4 deletions catalyst/engines/amp.py
@@ -10,6 +10,35 @@ class AMPEngine(DeviceEngine):

Args:
device: used device, default is `"cuda"`.

Examples:

.. code-block:: python

from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.AMPEngine("cuda:1")
    # ...

.. code-block:: yaml

args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: AMPEngine
    device: cuda:1

stages:
    ...

"""

def __init__(self, device: str = "cuda"):
@@ -36,7 +65,36 @@ def autocast(self):


class DataParallelAMPEngine(AMPEngine):
"""AMP multi-gpu training device engine."""
"""AMP multi-gpu training device engine.

Examples:

.. code-block:: python

from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DataParallelAMPEngine()
    # ...

.. code-block:: yaml

args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DataParallelAMPEngine

stages:
    ...

"""

def __init__(self):
"""Init."""
@@ -75,10 +133,42 @@ class DistributedDataParallelAMPEngine(DistributedDataParallelEngine):
"""Distributed AMP multi-gpu training device engine.

Args:
address: process address to use (required for PyTorch backend), default is `"localhost"`.
port: process port to listen (required for PyTorch backend), default is `"12345"`.
backend: multiprocessing backend to use, default is `"nccl"`.
address: process address to use
(required for PyTorch backend), default is `"localhost"`.
port: process port to listen
(required for PyTorch backend), default is `"12345"`.
backend: multiprocessing backend to use,
default is `"nccl"`.
world_size: number of processes.

Examples:

.. code-block:: python

from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DistributedDataParallelAMPEngine(port=12345)
    # ...

.. code-block:: yaml

args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DistributedDataParallelAMPEngine
    port: 12345

stages:
    ...

"""

def __init__(