Engines refactoring (#1243)
* +

* +
Scitator committed Jun 22, 2021
1 parent dfb8d32 commit 7f63545
Showing 10 changed files with 153 additions and 265 deletions.
42 changes: 27 additions & 15 deletions catalyst/core/engine.py
@@ -1,8 +1,12 @@
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Tuple, Union
from abc import ABC, abstractmethod
from contextlib import contextmanager

from catalyst.typing import Criterion, Model, Optimizer, Scheduler
import numpy as np
import torch
from torch import nn

from catalyst.typing import Criterion, Device, Model, Optimizer, Scheduler


@contextmanager
@@ -29,10 +33,11 @@ class IEngine(ABC):
- :py:mod:`catalyst.engines.torch.DeviceEngine`
"""

# @property
# @abstractmethod
# def device(self) -> Device:
# pass
@property
@abstractmethod
def device(self) -> Device:
"""Pytorch device."""
pass

@property
@abstractmethod
@@ -60,7 +65,9 @@ def is_master_process(self) -> bool:
Returns:
`True` if current process is a master process in other cases return `False`.
"""
return True
# -1 for non-distributed setup
# 0 for distributed setup
return self.rank <= 0

@property
def is_worker_process(self) -> bool:
@@ -71,7 +78,7 @@ def is_worker_process(self) -> bool:
Returns:
`True` if current process is a worker process in other cases return `False`.
"""
return False
return self.rank > 0

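The master/worker checks above are now derived from the engine's `rank`: `-1` means a non-distributed run, `0` is the master of the distributed group, and any positive rank is a worker. A minimal standalone sketch of that convention (illustrative helpers, not the `IEngine` implementation):

```python
def is_master_process(rank: int) -> bool:
    # rank == -1 -> non-distributed run, treated as master; rank == 0 -> DDP master
    return rank <= 0


def is_worker_process(rank: int) -> bool:
    # any positive rank is a worker in the distributed group
    return rank > 0


assert is_master_process(-1) and is_master_process(0)
assert is_worker_process(2) and not is_worker_process(0)
```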
def setup_process(self, rank: int = -1, world_size: int = 1):
"""Initialize DDP variables and processes.
@@ -88,7 +95,9 @@ def cleanup_process(self):
pass

@abstractmethod
def sync_device(self, tensor_or_module: Any) -> Any:
def sync_device(
self, tensor_or_module: Union[Dict, List, Tuple, np.ndarray, torch.Tensor, nn.Module]
) -> Union[Dict, List, Tuple, torch.Tensor, nn.Module]:
"""Moves ``tensor_or_module`` to Engine's device.
Args:
@@ -97,25 +106,28 @@ def sync_device(self, tensor_or_module: Any) -> Any:
pass

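`sync_device` now advertises that it accepts nested containers as well as tensors and modules. A hedged sketch of the recursive move-to-device behaviour the new signature implies (a plain helper written under that assumption, not the engines' actual implementation):

```python
from typing import Any

import torch
from torch import nn


def sync_device(obj: Any, device: torch.device) -> Any:
    """Recursively move tensors/modules inside dicts, lists and tuples to `device`."""
    if isinstance(obj, (torch.Tensor, nn.Module)):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: sync_device(value, device) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(sync_device(value, device) for value in obj)
    return obj  # scalars, numpy arrays, etc. are left as-is in this sketch
```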
@abstractmethod
def sync_tensor(self, tensor: Any, mode: str) -> Any:
def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
"""Syncs ``tensor`` over ``world_size`` in distributed mode."""
pass

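In distributed mode `sync_tensor` reduces a tensor across the whole process group; `mode` typically selects between a sum and a mean. A hedged sketch built on `torch.distributed.all_reduce` (assumes an initialized default process group; the shipped engines may differ in detail):

```python
import torch
import torch.distributed as dist


def sync_tensor(tensor: torch.Tensor, mode: str, world_size: int) -> torch.Tensor:
    """All-reduce `tensor` over the process group; `mode` is 'sum' or 'mean'."""
    if mode not in ("sum", "mean"):
        raise ValueError(f"unknown sync mode: {mode}")
    synced = tensor.clone()
    dist.all_reduce(synced, op=dist.ReduceOp.SUM)
    if mode == "mean":
        synced = synced / world_size
    return synced
```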
@abstractmethod
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
self,
model_fn: Callable = None,
criterion_fn: Callable = None,
optimizer_fn: Callable = None,
scheduler_fn: Callable = None,
):
"""Inits the runs components."""
pass

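The component factories are now annotated as `Callable`s; an engine calls them so that the model, criterion, optimizer and scheduler are built on the right device and process. A usage sketch with hypothetical factories (the exact signatures the engines expect from these callables are an assumption here):

```python
import torch
from torch import nn


def model_fn() -> nn.Module:
    return nn.Linear(10, 2)


def criterion_fn() -> nn.Module:
    return nn.CrossEntropyLoss()


def optimizer_fn(model: nn.Module) -> torch.optim.Optimizer:
    # assumes the engine hands the freshly built model to the optimizer factory
    return torch.optim.Adam(model.parameters(), lr=1e-3)


# model, criterion, optimizer, scheduler = engine.init_components(
#     model_fn=model_fn, criterion_fn=criterion_fn, optimizer_fn=optimizer_fn,
# )
```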
# @TODO: create RunnerLike type for .model, .criterion, .optimizer, .scheduler
@abstractmethod
def deinit_components(self, runner=None):
"""Deinits the runs components. In distributed mode should destroy process group."""
pass

@abstractmethod
def zero_grad(self, loss, model, optimizer) -> None:
def zero_grad(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``model.zero_grad()`` step.
Should be overloaded in cases when required to set arguments
for ``model.zero_grad()`` like `set_to_none=True` or
@@ -130,7 +142,7 @@ def zero_grad(self, loss, model, optimizer) -> None:
pass

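`zero_grad` stays an abstract step precisely so an engine can override what gets cleared and how, e.g. by passing `set_to_none=True`. A small sketch of the default and an overridden variant (plain functions mirroring the engine signature, not engine methods):

```python
import torch
from torch import nn


def zero_grad_default(loss: torch.Tensor, model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    # default behaviour: zero-fill the model's gradients
    model.zero_grad()


def zero_grad_set_to_none(loss: torch.Tensor, model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    # overridden behaviour: drop gradient tensors entirely (saves memory, PyTorch >= 1.7)
    optimizer.zero_grad(set_to_none=True)
```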
@abstractmethod
def backward_loss(self, loss, model, optimizer) -> None:
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step.
Should be overloaded in cases when required loss scaling.
Examples - APEX and AMP.
@@ -143,7 +155,7 @@ def backward_loss(self, loss, model, optimizer) -> None:
pass

@abstractmethod
def optimizer_step(self, loss, model, optimizer) -> None:
def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``optimizer.step()`` step.
Should be overloaded in cases when required gradient scaling.
Example - AMP.
102 changes: 32 additions & 70 deletions catalyst/core/runner.py
@@ -286,6 +286,28 @@ def hparams(self) -> OrderedDict:
"""
return {}

@property
def _log_defaults(self) -> Dict:
return {
# experiment info
"run_key": self.run_key,
"global_sample_step": self.global_sample_step,
"global_batch_step": self.global_batch_step,
"global_epoch_step": self.global_epoch_step,
# stage info
"stage_key": self.stage_key,
"stage_epoch_len": self.stage_epoch_len,
"stage_epoch_step": self.stage_epoch_step,
"stage_batch_step": self.stage_batch_step,
"stage_sample_step": self.stage_sample_step,
# loader info
"loader_key": self.loader_key,
"loader_batch_len": self.loader_batch_len,
"loader_sample_len": self.loader_sample_len,
"loader_batch_step": self.loader_batch_step,
"loader_sample_step": self.loader_sample_step,
}

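The new `_log_defaults` property collects the run/stage/loader counters that every logging call previously repeated inline, so each `log_*` method below can simply unpack it. A toy sketch of the pattern (placeholder fields, not the real runner):

```python
from typing import Dict


class ToyRunner:
    """Toy illustration of the _log_defaults pattern; fields are placeholders."""

    def __init__(self, loggers: Dict):
        self.loggers = loggers
        self.run_key = "experiment-0"
        self.global_batch_step = 0

    @property
    def _log_defaults(self) -> Dict:
        # shared context attached to every logging call
        return {"run_key": self.run_key, "global_batch_step": self.global_batch_step}

    def log_metrics(self, *args, **kwargs) -> None:
        # note: passing run_key or global_batch_step through kwargs as well would raise
        # a "multiple values for keyword argument" TypeError, since both dicts are unpacked
        for logger in self.loggers.values():
            logger.log_metrics(*args, **kwargs, **self._log_defaults)
```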
@property
@abstractmethod
def stages(self) -> Iterable[str]:
@@ -514,91 +536,31 @@ def get_callbacks(self, stage: str) -> "OrderedDict[str, ICallback]":
"""
return {}

def log_metrics(self, *args, **kwargs) -> None:
"""Logs batch, loader and epoch metrics to available loggers."""
def log_hparams(self, *args, **kwargs) -> None:
"""Logs hyperparameters to available loggers."""
for logger in self.loggers.values():
logger.log_metrics(
logger.log_hparams(
*args,
**kwargs,
# experiment info
run_key=self.run_key,
global_sample_step=self.global_sample_step,
global_batch_step=self.global_batch_step,
global_epoch_step=self.global_epoch_step,
# stage info
stage_key=self.stage_key,
stage_epoch_len=self.stage_epoch_len,
stage_epoch_step=self.stage_epoch_step,
stage_batch_step=self.stage_batch_step,
stage_sample_step=self.stage_sample_step,
# loader info
loader_key=self.loader_key,
loader_batch_len=self.loader_batch_len,
loader_sample_len=self.loader_sample_len,
loader_batch_step=self.loader_batch_step,
loader_sample_step=self.loader_sample_step,
)

def log_image(self, *args, **kwargs) -> None:
"""Logs image to available loggers."""
def log_metrics(self, *args, **kwargs) -> None:
"""Logs batch, loader and epoch metrics to available loggers."""
for logger in self.loggers.values():
logger.log_image(
*args,
**kwargs,
# experiment info
run_key=self.run_key,
global_sample_step=self.global_sample_step,
global_batch_step=self.global_batch_step,
global_epoch_step=self.global_epoch_step,
# stage info
stage_key=self.stage_key,
stage_epoch_len=self.stage_epoch_len,
stage_epoch_step=self.stage_epoch_step,
stage_batch_step=self.stage_batch_step,
stage_sample_step=self.stage_sample_step,
# loader info
loader_key=self.loader_key,
loader_batch_len=self.loader_batch_len,
loader_sample_len=self.loader_sample_len,
loader_batch_step=self.loader_batch_step,
loader_sample_step=self.loader_sample_step,
)
logger.log_metrics(*args, **kwargs, **self._log_defaults)

def log_hparams(self, *args, **kwargs) -> None:
"""Logs hyperparameters to available loggers."""
def log_image(self, *args, **kwargs) -> None:
"""Logs image to available loggers."""
for logger in self.loggers.values():
logger.log_hparams(
*args,
**kwargs,
# experiment info
run_key=self.run_key,
stage_key=self.stage_key,
)
logger.log_image(*args, **kwargs, **self._log_defaults)

def log_artifact(self, *args, **kwargs) -> None:
"""Logs artifact (file like audio, video, csv, etc.) to available loggers."""
for logger in self.loggers.values():
logger.log_artifact(
*args,
**kwargs,
# experiment info
run_key=self.run_key,
global_sample_step=self.global_sample_step,
global_batch_step=self.global_batch_step,
global_epoch_step=self.global_epoch_step,
# stage info
stage_key=self.stage_key,
stage_epoch_len=self.stage_epoch_len,
stage_epoch_step=self.stage_epoch_step,
stage_batch_step=self.stage_batch_step,
stage_sample_step=self.stage_sample_step,
# loader info
loader_key=self.loader_key,
loader_batch_len=self.loader_batch_len,
loader_sample_len=self.loader_sample_len,
loader_batch_step=self.loader_batch_step,
loader_sample_step=self.loader_sample_step,
)
logger.log_artifact(*args, **kwargs, **self._log_defaults)

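On the receiving side, every `log_*` call now arrives with the same context keys, so a logger can read the run/stage/loader counters straight from its keyword arguments. A hypothetical console logger consuming that context (not a Catalyst class, and its positional parameters are assumptions):

```python
class ConsoleLogger:
    """Hypothetical logger showing how the forwarded _log_defaults context can be used."""

    def log_metrics(self, metrics: dict, scope: str = None, **context) -> None:
        prefix = (
            f"[{context.get('stage_key')}] "
            f"epoch {context.get('stage_epoch_step')}/{context.get('stage_epoch_len')} "
            f"batch {context.get('loader_batch_step')}"
        )
        print(prefix, scope or "", metrics)

    def log_hparams(self, hparams: dict, **context) -> None:
        print(f"[{context.get('run_key')}] hparams:", hparams)
```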
def flush_log(self) -> None:
"""Flushes the loggers."""
13 changes: 7 additions & 6 deletions catalyst/engines/amp.py
@@ -5,6 +5,7 @@
import torch.cuda.amp as amp

from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
from catalyst.typing import Model, Optimizer


class AMPEngine(DeviceEngine):
@@ -66,15 +67,15 @@ def __init__(self, device: str = "cuda", scaler_kwargs: Dict[str, Any] = None):

def __repr__(self) -> str: # noqa: D105
return (
f"{self.__class__.__name__}(device='{self.device}', "
f"{self.__class__.__name__}(device='{self._device}', "
f"scaler_kwargs={self.scaler_kwargs})"
)

def backward_loss(self, loss, model, optimizer) -> None:
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
self.scaler.scale(loss).backward()

def optimizer_step(self, loss, model, optimizer) -> None:
def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``optimizer.step()`` step."""
self.scaler.step(optimizer)
self.scaler.update()
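Both AMP engines route the backward pass and the optimizer step through a `torch.cuda.amp.GradScaler`, as in the plain training step sketched below (assumes a CUDA device is available; this is the generic PyTorch AMP pattern, not the engine code itself):

```python
import torch
from torch import nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 10, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # backward_loss: scale the loss, then backpropagate
scaler.step(optimizer)         # optimizer_step: unscale gradients and step
scaler.update()                # adapt the scale factor for the next iteration
```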
@@ -139,7 +140,7 @@ def __init__(self, scaler_kwargs: Dict[str, Any] = None):

def __repr__(self) -> str: # noqa: D105
return (
f"{self.__class__.__name__}(device='{self.device}', "
f"{self.__class__.__name__}(device='{self._device}', "
f"scaler_kwargs={self.scaler_kwargs})"
)

@@ -268,11 +269,11 @@ def __repr__(self): # noqa: D105
f"scaler_kwargs={self.scaler_kwargs})"
)

def backward_loss(self, loss, model, optimizer) -> None:
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
self.scaler.scale(loss).backward()

def optimizer_step(self, loss, model, optimizer) -> None:
def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``optimizer.step()`` step."""
self.scaler.step(optimizer)
self.scaler.update()
14 changes: 8 additions & 6 deletions catalyst/engines/apex.py
@@ -8,7 +8,7 @@

from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
from catalyst.settings import SETTINGS
from catalyst.typing import RunnerModel, RunnerOptimizer
from catalyst.typing import Model, Optimizer, RunnerModel, RunnerOptimizer
from catalyst.utils.misc import get_fn_default_params

if SETTINGS.apex_required:
@@ -182,7 +182,7 @@ def __init__(self, device: str = "cuda", apex_kwargs: Dict[str, Any] = None):
self.apex_kwargs = apex_kwargs or {}

def __repr__(self) -> str: # noqa: D105
args_list = [f"device='{self.device}'", f"apex_kwargs={self.apex_kwargs}"]
args_list = [f"device='{self._device}'", f"apex_kwargs={self.apex_kwargs}"]
return f"{self.__class__.__name__}(" + ",".join(args_list) + ")"

def init_components(
@@ -211,7 +211,7 @@ def init_components(
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler

def backward_loss(self, loss, model, optimizer) -> None:
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
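The APEX engines wrap the backward pass in `apex.amp.scale_loss`, which yields a scaled loss to backpropagate. A hedged sketch of the surrounding training step (assumes NVIDIA Apex is installed and the model/optimizer were prepared with `amp.initialize`; not the engine code):

```python
import torch
from torch import nn
from apex import amp  # assumes NVIDIA Apex is installed

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 10, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backward on the scaled loss, as in backward_loss above
optimizer.step()
```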
@@ -338,7 +338,9 @@ def __init__(self, apex_kwargs: Dict[str, Any] = None):
self.device_count = torch.cuda.device_count()

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(device='{self.device}', apex_kwargs={self.apex_kwargs})"
return (
f"{self.__class__.__name__}(device='{self._device}', apex_kwargs={self.apex_kwargs})"
)

def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
@@ -482,7 +484,7 @@ def setup_process(self, rank: int = -1, world_size: int = 1):
dist.init_process_group(**self.process_group_kwargs)

torch.cuda.set_device(int(self._rank))
self.device = f"cuda:{int(self._rank)}"
self._device = f"cuda:{int(self._rank)}"

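Each distributed process now pins itself to the GPU matching its rank and stores the device string in the private `_device` attribute (the public `device` is the abstract property added in `catalyst/core/engine.py`). A generic sketch of that per-process setup (assumes a single node with one GPU per process; a hypothetical helper, not the Catalyst engine):

```python
import os

import torch
import torch.distributed as dist


def setup_process(rank: int, world_size: int) -> str:
    """Per-process DDP setup: one GPU per process on a single node."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # pin this process to its GPU
    return f"cuda:{rank}"        # device string kept as the engine's `_device`
```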
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
@@ -504,7 +506,7 @@ def init_components(
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler

def backward_loss(self, loss, model, optimizer) -> None:
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
