Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lambda wrapper callback #1153

Merged
merged 38 commits into from Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d2a24d9
lambda wrapper callback
elephantmipt Mar 31, 2021
ed38d9d
lambda wrapper callback
elephantmipt Mar 31, 2021
5b26b8d
CHANGELOG.md
elephantmipt Mar 31, 2021
2f03d29
fix exception
elephantmipt Mar 31, 2021
9ccd861
docs fix
elephantmipt Mar 31, 2021
6b292e0
docs fix 2
elephantmipt Mar 31, 2021
f421a40
update
elephantmipt Apr 1, 2021
1384f86
refactor
elephantmipt Apr 1, 2021
a330663
fix
elephantmipt Apr 1, 2021
d3ee3cf
Merge branch 'master' into lambda_wrapper_callback
elephantmipt Apr 1, 2021
a718c6e
codestyle
elephantmipt Apr 1, 2021
ca98a12
Merge remote-tracking branch 'origin/lambda_wrapper_callback' into la…
elephantmipt Apr 1, 2021
b3be16f
Merge branch 'master' into lambda_wrapper_callback
elephantmipt Apr 4, 2021
84ba558
refactoring
elephantmipt Apr 4, 2021
e03c2e9
docs fix
elephantmipt Apr 4, 2021
7e8bbe6
refactoring
elephantmipt Apr 5, 2021
f9fb3bc
refactoring
elephantmipt Apr 5, 2021
1b133e4
replace batch_transform.py
elephantmipt Apr 5, 2021
aa73ab6
fix test
elephantmipt Apr 5, 2021
fa32a42
fix test
elephantmipt Apr 5, 2021
6993a1e
fix exceptions
elephantmipt Apr 5, 2021
a77667e
fix multi-line string
elephantmipt Apr 5, 2021
09a1ca0
codestyle
elephantmipt Apr 5, 2021
90b7fed
fix
elephantmipt Apr 5, 2021
3b810a2
simplification
elephantmipt Apr 14, 2021
448d375
simplification
elephantmipt Apr 14, 2021
3d4a274
docs
elephantmipt Apr 14, 2021
9af4a3a
Merge branch 'master' into lambda_wrapper_callback
elephantmipt Apr 14, 2021
517d157
fix dict keys
elephantmipt Apr 14, 2021
28afff3
naming
elephantmipt Apr 14, 2021
a6767e7
fix args in docs
elephantmipt Apr 14, 2021
0289c6f
simplification
elephantmipt Apr 14, 2021
1b6ce3d
args for SupervisedRunner in test
elephantmipt Apr 14, 2021
1c19966
args in README for RecSys
elephantmipt Apr 14, 2021
f9cc2e4
Merge branch 'master' into lambda_wrapper_callback
elephantmipt Apr 15, 2021
1691c9d
fix
elephantmipt Apr 15, 2021
2d9df88
add kornia example
elephantmipt Apr 15, 2021
04ee2a7
fix docs
elephantmipt Apr 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- LambdaWrapperCallback ([#1153](https://github.com/catalyst-team/catalyst/issues/1153))

### Changed

Expand Down
1 change: 1 addition & 0 deletions catalyst/callbacks/__init__.py
Expand Up @@ -17,6 +17,7 @@
from catalyst.callbacks.checkpoint import ICheckpointCallback, CheckpointCallback
from catalyst.callbacks.control_flow import ControlFlowCallback
from catalyst.callbacks.criterion import ICriterionCallback, CriterionCallback
from catalyst.callbacks.lambda_preprocess import LambdaPreprocessCallback
from catalyst.callbacks.metric import (
BatchMetricCallback,
IMetricCallback,
Expand Down
146 changes: 146 additions & 0 deletions catalyst/callbacks/lambda_preprocess.py
@@ -0,0 +1,146 @@
from typing import Callable, List, Union

from catalyst.core import Callback, CallbackOrder


class LambdaPreprocessCallback(Callback):
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
"""
Preprocess your batch with specified function.

Args:
lambda_fn (Callable): Function to apply.
keys_to_apply (Union[List[str], str], optional): Keys in batch dict to apply function.
Defaults to ["s_hidden_states", "t_hidden_states"].

Raises:
TypeError: When keys_to_apply is not str or list.

Examples:
.. code-block:: python

import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

# sample data
num_users, num_features, num_items = int(1e4), int(1e1), 10
X = torch.rand(num_users, num_features)
y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)

# pytorch loaders
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_items)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

# model training
runner = SupervisedRunner()
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
num_epochs=3,
verbose=True,
callbacks=[
dl.LambdaPreprocessCallback(keys_to_apply="logits", output_keys="scores", lambda_fn=torch.sigmoid)
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
# uncomment for extra metrics:
# dl.AUCCallback(input_key="scores", target_key="targets"),
# dl.HitrateCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
# dl.MRRCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
# dl.MAPCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
# dl.NDCGCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
dl.OptimizerCallback(metric_key="loss"),
dl.SchedulerCallback(),
dl.CheckpointCallback(
logdir="./logs", loader_key="valid", metric_key="map01", minimize=False
),
]
)

"""

def __init__(
self,
lambda_fn: Callable,
keys_to_apply: Union[List[str], str] = "logits",
output_keys: Union[List[str], str] = None,
):
"""Wraps input for your callback with specified function.
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved

Args:
lambda_fn (Callable): Function to apply.
keys_to_apply (Union[List[str], str], optional): Keys in batch dict to apply function.
Defaults to ["s_hidden_states", "t_hidden_states"].
output_keys (Union[List[str], str], optional): Keys for output.
If None then will apply function inplace to ``keys_to_apply``.
Defaults to None.

Raises:
TypeError: When keys_to_apply is not str or list.
"""
super().__init__(order=CallbackOrder.Internal)
if not isinstance(keys_to_apply, (list, str)):
raise TypeError("keys to apply should be str or list of str.")
if output_keys is not None:
if not isinstance(output_keys, (list, str)):
raise TypeError("output keys should be str or list of str.")
self.keys_to_apply = keys_to_apply
self.output_keys = output_keys
self.lambda_fn = lambda_fn

def on_batch_end(self, runner) -> None:
"""
On batch end action.

Args:
runner: runner for the experiment.

Raises:
TypeError: If lambda_fn output has a wrong type.

"""
batch = runner.batch

if isinstance(self.keys_to_apply, list):
fn_inp = [batch[key] for key in self.keys_to_apply]
fn_output = self.lambda_fn(*fn_inp)
if isinstance(fn_output, tuple):
if self.output_keys is not None:
if not isinstance(self.output_keys, list):
raise TypeError(
"Unexpected output from function. "
"For output key type string expected one element, got tuple."
)
iter_keys = self.output_keys
else:
iter_keys = self.keys_to_apply
for idx, key in enumerate(iter_keys):
batch[key] = fn_output[idx]
elif isinstance(fn_output, dict):
for outp_k, outp_v in fn_output.items():
batch[outp_k] = outp_v
else:
if self.output_keys is not None:
if not isinstance(self.output_keys, str):
raise TypeError(
"Unexpected output from function. "
"For output key type List[str] expected tuple, got one element."
)
output_key = self.output_keys
else:
output_key = self.keys_to_apply
batch[output_key] = fn_output
elif isinstance(self.keys_to_apply, str):
batch[self.keys_to_apply] = self.lambda_fn(self.keys_to_apply)
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
runner.batch = batch


__all__ = ["LambdaPreprocessCallback"]
7 changes: 7 additions & 0 deletions docs/api/callbacks.rst
Expand Up @@ -50,6 +50,13 @@ CriterionCallback
:exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end
:show-inheritance:

LambdaWrapperCallback
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.callbacks.lambda_preprocess.LambdaPreprocessCallback
:members:
:exclude-members: __init__, on_experiment_start, on_stage_start, on_epoch_start, on_loader_start, on_batch_start, on_batch_end, on_loader_end, on_epoch_end, on_stage_end, on_experiment_end
:show-inheritance:

Metric – BatchMetricCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.callbacks.metric.BatchMetricCallback
Expand Down