Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] simplified runner proposal #984

Merged
merged 11 commits into from Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 5 additions & 11 deletions catalyst/callbacks/checkpoint.py
Expand Up @@ -25,7 +25,7 @@ def _pack_runner(runner: "IRunner"):
scheduler=runner.scheduler,
epoch_metrics=dict(runner.epoch_metrics),
valid_metrics=dict(runner.valid_metrics),
stage_name=runner.stage_name,
stage_name=runner.stage,
epoch=runner.epoch,
loader_name=runner.loader_name,
loader_step=runner.loader_batch_step,
Expand Down Expand Up @@ -65,8 +65,8 @@ def _load_checkpoint(
print(f"=> Loading checkpoint {filename}")
checkpoint = load_checkpoint(filename)

if not runner.stage_name.startswith("infer") and load_full:
runner.stage_name = checkpoint["stage_name"]
if not runner.stage.startswith("infer") and load_full:
runner.stage = checkpoint["stage_name"]
runner.epoch = checkpoint["epoch"]
runner.global_epoch = checkpoint["global_epoch"]
# @TODO: should we also load,
Expand Down Expand Up @@ -620,10 +620,7 @@ def on_epoch_end(self, runner: "IRunner") -> None:
Args:
runner: current runner
"""
if (
runner.stage_name.startswith("infer")
or runner.is_distributed_worker
):
if runner.stage.startswith("infer") or runner.is_distributed_worker:
return

if self.save_n_best > 0:
Expand All @@ -644,10 +641,7 @@ def on_stage_end(self, runner: "IRunner") -> None:
Args:
runner: current runner
"""
if (
runner.stage_name.startswith("infer")
or runner.is_distributed_worker
):
if runner.stage.startswith("infer") or runner.is_distributed_worker:
return
log_message = "Top best models:\n"
# store latest state
Expand Down
2 changes: 1 addition & 1 deletion catalyst/callbacks/control_flow.py
Expand Up @@ -378,7 +378,7 @@ def on_loader_start(self, runner: "IRunner") -> None:
Args:
runner: current runner
"""
stage = runner.stage_name
stage = runner.stage
loader = runner.loader_name
epoch = runner.global_epoch if self.use_global_epochs else runner.epoch

Expand Down
5 changes: 3 additions & 2 deletions catalyst/callbacks/criterion.py
@@ -1,6 +1,7 @@
from typing import Dict, List, TYPE_CHECKING, Union

from catalyst.callbacks.metric import IBatchMetricCallback
from catalyst.utils.misc import get_attr

if TYPE_CHECKING:
from catalyst.core.runner import IRunner
Expand Down Expand Up @@ -55,8 +56,8 @@ def on_stage_start(self, runner: "IRunner"):
Args:
runner: current runner
"""
criterion = runner.get_attr(
key="criterion", inner_key=self.criterion_key
criterion = get_attr(
runner, key="criterion", inner_key=self.criterion_key
)
assert criterion is not None
self._criterion = criterion
Expand Down
2 changes: 1 addition & 1 deletion catalyst/callbacks/early_stop.py
Expand Up @@ -120,7 +120,7 @@ def on_epoch_end(self, runner: "IRunner") -> None:
Args:
runner: current runner
"""
if runner.stage_name.startswith("infer"):
if runner.stage.startswith("infer"):
return

score = runner.valid_metrics[self.metric]
Expand Down
10 changes: 5 additions & 5 deletions catalyst/callbacks/optimizer.py
Expand Up @@ -7,7 +7,7 @@
from catalyst import registry
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.typing import Optimizer
from catalyst.utils.misc import maybe_recursive_call
from catalyst.utils.misc import get_attr, maybe_recursive_call
from catalyst.utils.torch import get_optimizer_momentum

if TYPE_CHECKING:
Expand Down Expand Up @@ -149,8 +149,8 @@ def on_stage_start(self, runner: "IRunner") -> None:
Args:
runner(IRunner): current runner
"""
self._optimizer = runner.get_attr(
key="optimizer", inner_key=self.optimizer_key
self._optimizer = get_attr(
runner, key="optimizer", inner_key=self.optimizer_key
)
# device based optimization step
if runner.device.type == "xla":
Expand Down Expand Up @@ -326,8 +326,8 @@ def on_stage_start(self, runner: "IRunner") -> None:
"""
from torch.cuda.amp import GradScaler

self._optimizer = runner.get_attr(
key="optimizer", inner_key=self.optimizer_key
self._optimizer = get_attr(
runner, key="optimizer", inner_key=self.optimizer_key
)
self.scaler = GradScaler()
assert self._optimizer is not None
Expand Down
9 changes: 5 additions & 4 deletions catalyst/callbacks/scheduler.py
Expand Up @@ -5,6 +5,7 @@

from catalyst.contrib.nn.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.utils.misc import get_attr
from catalyst.utils.torch import get_optimizer_momentum

if TYPE_CHECKING:
Expand Down Expand Up @@ -173,8 +174,8 @@ def on_stage_start(self, runner: "IRunner") -> None:
"""
self.reduced_metric = self.reduced_metric or runner.main_metric

scheduler = runner.get_attr(
key="scheduler", inner_key=self.scheduler_key
scheduler = get_attr(
runner, key="scheduler", inner_key=self.scheduler_key
)
assert scheduler is not None
self._scheduler = scheduler
Expand Down Expand Up @@ -297,8 +298,8 @@ def on_stage_start(self, runner: "IRunner") -> None:
Args:
runner: current runner
"""
optimizer = runner.get_attr(
key="optimizer", inner_key=self.optimizer_key
optimizer = get_attr(
runner, key="optimizer", inner_key=self.optimizer_key
)
assert optimizer is not None
self._optimizer = optimizer
Expand Down
2 changes: 1 addition & 1 deletion catalyst/callbacks/validation.py
Expand Up @@ -33,7 +33,7 @@ def on_epoch_end(self, runner: "IRunner") -> None:
Args:
runner: current runner
"""
if runner.stage_name.startswith("infer"):
if runner.stage.startswith("infer"):
return

runner.valid_metrics = {
Expand Down
4 changes: 2 additions & 2 deletions catalyst/contrib/callbacks/telegram_logger.py
Expand Up @@ -76,7 +76,7 @@ def _send_text(self, text: str):
def on_stage_start(self, runner: "IRunner"):
"""Notify about starting a new stage."""
if self.log_on_stage_start:
text = f"{runner.stage_name} stage was started"
text = f"{runner.stage} stage was started"

self._send_text(text)

Expand Down Expand Up @@ -115,7 +115,7 @@ def on_loader_end(self, runner: "IRunner"):
def on_stage_end(self, runner: "IRunner"):
"""Notify about finishing a stage."""
if self.log_on_stage_end:
text = f"{runner.stage_name} stage was finished"
text = f"{runner.stage} stage was finished"

self._send_text(text)

Expand Down
164 changes: 91 additions & 73 deletions catalyst/core/callback.py
Expand Up @@ -5,6 +5,96 @@
from catalyst.core.runner import IRunner


class ICallback:
    """Interface that declares the callback event handlers.

    Every handler receives the runner instance and does nothing by default;
    subclasses override only the events they care about. Handlers are
    declared in ``*_start`` / ``*_end`` pairs for the experiment, stage,
    epoch, loader and batch events, plus ``on_exception``.
    """

    def on_experiment_start(self, runner: "IRunner") -> None:
        """Event handler for experiment start.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_stage_start(self, runner: "IRunner") -> None:
        """Event handler for stage start.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_epoch_start(self, runner: "IRunner") -> None:
        """Event handler for epoch start.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_loader_start(self, runner: "IRunner") -> None:
        """Event handler for loader start.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_batch_start(self, runner: "IRunner") -> None:
        """Event handler for batch start.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_batch_end(self, runner: "IRunner") -> None:
        """Event handler for batch end.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_loader_end(self, runner: "IRunner") -> None:
        """Event handler for loader end.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Event handler for epoch end.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_stage_end(self, runner: "IRunner") -> None:
        """Event handler for stage end.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_experiment_end(self, runner: "IRunner") -> None:
        """Event handler for experiment end.

        Args:
            runner: IRunner instance.
        """
        pass

    def on_exception(self, runner: "IRunner") -> None:
        """Event handler for exception case.

        Args:
            runner: IRunner instance.
        """
        pass




class CallbackNode(IntFlag):
"""Callback node usage flag during distributed training.

Expand Down Expand Up @@ -78,7 +168,7 @@ class CallbackScope(IntFlag):
Experiment = experiment = 1 # noqa: WPS115


class Callback:
class Callback(ICallback):
"""
An abstraction that lets you customize your experiment run logic.
To give users maximum flexibility and extensibility Catalyst supports
Expand Down Expand Up @@ -136,78 +226,6 @@ def __init__(
self.order = order
self.scope = scope

def on_stage_start(self, runner: "IRunner"):
"""Event handler for stage start.

Args:
runner: IRunner instance.
"""
pass

def on_stage_end(self, runner: "IRunner"):
"""Event handler for stage end.

Args:
runner: IRunner instance.
"""
pass

def on_epoch_start(self, runner: "IRunner"):
"""Event handler for epoch start.

Args:
runner: IRunner instance.
"""
pass

def on_epoch_end(self, runner: "IRunner"):
"""Event handler for epoch end.

Args:
runner: IRunner instance.
"""
pass

def on_loader_start(self, runner: "IRunner"):
"""Event handler for loader start.

Args:
runner: IRunner instance.
"""
pass

def on_loader_end(self, runner: "IRunner"):
"""Event handler for loader end.

Args:
runner: IRunner instance.
"""
pass

def on_batch_start(self, runner: "IRunner"):
"""Event handler for batch start.

Args:
runner: IRunner instance.
"""
pass

def on_batch_end(self, runner: "IRunner"):
"""Event handler for batch end.

Args:
runner: IRunner instance.
"""
pass

def on_exception(self, runner: "IRunner"):
"""Event handler for exception case.

Args:
runner: IRunner instance.
"""
pass


class CallbackWrapper(Callback):
"""Enable/disable callback execution."""
Expand Down