Lambda wrapper callback #1153

Merged (38 commits; merged Apr 15, 2021)
Changes shown from 25 of the 38 commits.

Commits:
d2a24d9 lambda wrapper callback (elephantmipt, Mar 31, 2021)
ed38d9d lambda wrapper callback (elephantmipt, Mar 31, 2021)
5b26b8d CHANGELOG.md (elephantmipt, Mar 31, 2021)
2f03d29 fix exception (elephantmipt, Mar 31, 2021)
9ccd861 docs fix (elephantmipt, Mar 31, 2021)
6b292e0 docs fix 2 (elephantmipt, Mar 31, 2021)
f421a40 update (elephantmipt, Apr 1, 2021)
1384f86 refactor (elephantmipt, Apr 1, 2021)
a330663 fix (elephantmipt, Apr 1, 2021)
d3ee3cf Merge branch 'master' into lambda_wrapper_callback (elephantmipt, Apr 1, 2021)
a718c6e codestyle (elephantmipt, Apr 1, 2021)
ca98a12 Merge remote-tracking branch 'origin/lambda_wrapper_callback' into la… (elephantmipt, Apr 1, 2021)
b3be16f Merge branch 'master' into lambda_wrapper_callback (elephantmipt, Apr 4, 2021)
84ba558 refactoring (elephantmipt, Apr 4, 2021)
e03c2e9 docs fix (elephantmipt, Apr 4, 2021)
7e8bbe6 refactoring (elephantmipt, Apr 5, 2021)
f9fb3bc refactoring (elephantmipt, Apr 5, 2021)
1b133e4 replace batch_transform.py (elephantmipt, Apr 5, 2021)
aa73ab6 fix test (elephantmipt, Apr 5, 2021)
fa32a42 fix test (elephantmipt, Apr 5, 2021)
6993a1e fix exceptions (elephantmipt, Apr 5, 2021)
a77667e fix multi-line string (elephantmipt, Apr 5, 2021)
09a1ca0 codestyle (elephantmipt, Apr 5, 2021)
90b7fed fix (elephantmipt, Apr 5, 2021)
3b810a2 simplification (elephantmipt, Apr 14, 2021)
448d375 simplification (elephantmipt, Apr 14, 2021)
3d4a274 docs (elephantmipt, Apr 14, 2021)
9af4a3a Merge branch 'master' into lambda_wrapper_callback (elephantmipt, Apr 14, 2021)
517d157 fix dict keys (elephantmipt, Apr 14, 2021)
28afff3 naming (elephantmipt, Apr 14, 2021)
a6767e7 fix args in docs (elephantmipt, Apr 14, 2021)
0289c6f simplification (elephantmipt, Apr 14, 2021)
1b6ce3d args for SupervisedRunner in test (elephantmipt, Apr 14, 2021)
1c19966 args in README for RecSys (elephantmipt, Apr 14, 2021)
f9cc2e4 Merge branch 'master' into lambda_wrapper_callback (elephantmipt, Apr 15, 2021)
1691c9d fix (elephantmipt, Apr 15, 2021)
2d9df88 add kornia example (elephantmipt, Apr 15, 2021)
04ee2a7 fix docs (elephantmipt, Apr 15, 2021)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- LambdaWrapperCallback ([#1153](https://github.com/catalyst-team/catalyst/issues/1153))
- Nifti Reader (NiftiReader) ([#1151](https://github.com/catalyst-team/catalyst/pull/1151))

### Changed
16 changes: 7 additions & 9 deletions README.md
@@ -455,16 +455,8 @@ criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

class CustomRunner(dl.Runner):
def handle_batch(self, batch):
x, y = batch
logits = self.model(x)
self.batch = {
"features": x, "logits": logits, "scores": torch.sigmoid(logits), "targets": y
}

# model training
runner = CustomRunner()
runner = dl.SupervisedRunner(output_key="logits")
runner.train(
model=model,
criterion=criterion,
@@ -474,6 +466,12 @@ runner.train(
num_epochs=3,
verbose=True,
callbacks=[
dl.BatchTransformCallback(
lambda_fn=torch.sigmoid,
scope="on_batch_end",
input_key="logits",
output_key="scores"
),
dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
# uncomment for extra metrics:
# dl.AUCCallback(input_key="scores", target_key="targets"),
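
The README change above swaps the hand-written CustomRunner for SupervisedRunner plus BatchTransformCallback. Conceptually, the callback reproduces the sigmoid step the runner used to do by hand; a minimal sketch of that equivalence (illustrative only, not the library's internals):

import torch

# What dl.BatchTransformCallback(lambda_fn=torch.sigmoid, scope="on_batch_end",
# input_key="logits", output_key="scores") boils down to for one batch dict:
def add_scores(batch: dict) -> dict:
    batch["scores"] = torch.sigmoid(batch["logits"])
    return batch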
325 changes: 176 additions & 149 deletions catalyst/callbacks/batch_transform.py
@@ -1,171 +1,198 @@
from typing import Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Tuple, Union
from functools import partial

from torch import nn

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.registry import REGISTRY

if TYPE_CHECKING:
from catalyst.core.runner import IRunner
from catalyst.core import Callback, CallbackOrder, IRunner


class BatchTransformCallback(Callback):
"""Callback to perform data augmentations on GPU using kornia library.
"""
Preprocess your batch with a specified function.

Args:
transform: define augmentations to apply on a batch

If a sequence of transforms passed, then each element
should be either ``kornia.augmentation.AugmentationBase2D``,
``kornia.augmentation.AugmentationBase3D``, or ``nn.Module``
compatible with kornia interface.

If a sequence of params (``dict``) passed, then each
element of the sequence must contain ``'transform'`` key with
an augmentation name as a value. Please note that in this case
to use custom augmentation you should add it to the
`REGISTRY` registry first.
input_key (Union[str, int]): key in batch dict mapping to transform, e.g. `'image'`
output_key: key to use to store the result
of the transform, defaults to `input_key` if not provided

Look at `Kornia: an Open Source Differentiable Computer Vision
Library for PyTorch`_ for details.
lambda_fn (Callable): Function to apply.
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
scope (str): when to apply the function, either ``"on_batch_start"`` or ``"on_batch_end"``.
input_key (Union[List[str], str, int], optional): Keys in the batch dict whose
values are passed to the function. Defaults to ``None``.
output_key (Union[List[str], str, int], optional): Keys under which to store the
function output. If ``None``, the output is written back in place under
``input_key``. Defaults to ``None``.

Raises:
TypeError: If ``input_key``/``output_key`` is not a ``str``, ``int``, or a list,
or if ``scope`` is not ``"on_batch_start"`` or ``"on_batch_end"``.

Examples:
.. code-block:: python

import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

# sample data
num_users, num_features, num_items = int(1e4), int(1e1), 10
X = torch.rand(num_users, num_features)
y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)

# pytorch loaders
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_items)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

# model training
runner = dl.SupervisedRunner(output_key="logits")
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
num_epochs=3,
verbose=True,
callbacks=[
dl.BatchTransformCallback(
lambda_fn=torch.sigmoid, scope="on_batch_end",
input_key="logits", output_key="scores"
),
dl.CriterionCallback(
input_key="logits", target_key="targets", metric_key="loss"
),
# uncomment for extra metrics:
# dl.AUCCallback(
# input_key="scores", target_key="targets"
# ),
# dl.HitrateCallback(
# input_key="scores", target_key="targets", topk_args=(1, 3, 5)
# ),
# dl.MRRCallback(
# input_key="scores", target_key="targets", topk_args=(1, 3, 5)
# ),
# dl.MAPCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
# dl.NDCGCallback(
# input_key="scores", target_key="targets", topk_args=(1, 3, 5)
# ),
dl.OptimizerCallback(metric_key="loss"),
dl.SchedulerCallback(),
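# note: "map01" used by CheckpointCallback below assumes dl.MAPCallback above is uncommented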
dl.CheckpointCallback(
logdir="./logs", loader_key="valid", metric_key="map01", minimize=False
),
]
)

Usage example for notebook API:

.. code-block:: python

import os

from kornia import augmentation
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.contrib.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.callbacks.kornia_transform import (
BatchTransformCallback
)
from catalyst import metrics


class CustomRunner(dl.Runner):
def predict_batch(self, batch):
# model inference step
return self.model(
batch[0].to(self.device).view(batch[0].size(0), -1)
)

def handle_batch(self, batch):
# model train/valid step
x, y = batch
y_hat = self.model(x.view(x.size(0), -1))

loss = F.cross_entropy(y_hat, y)
accuracy01, *_ = metrics.accuracy(y_hat, y)
self.batch_metrics.update(
{"loss": loss, "accuracy01": accuracy01}
)

if self.is_train_loader:
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()

model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

loaders = {
"train": DataLoader(
MNIST(os.getcwd(), train=True, transform=ToTensor()),
batch_size=32,
),
"valid": DataLoader(
MNIST(os.getcwd(), train=False, transform=ToTensor()),
batch_size=32,
),
}
transforms = [
augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
]

runner = CustomRunner()

# model training
runner.train(
model=model,
optimizer=optimizer,
loaders=loaders,
logdir="./logs",
num_epochs=5,
verbose=True,
callbacks=[BatchTransformCallback(transforms, input_key=0)],
)

To apply augmentations only during specific loader e.g. only during
training :class:`catalyst.core.callbacks.control_flow.ControlFlowCallback`
callback can be used. For config API it can look like this:

.. code-block:: yaml

callbacks_params:
...
train_transforms:
_wrapper:
name: ControlFlowCallback
loaders: train
name: BatchTransformCallback
transform:
- transform: kornia.RandomAffine
degrees: [-15, 20]
scale: [0.75, 1.25]
return_transform: true
- transform: kornia.ColorJitter
brightness: 0.1
contrast: 0.1
saturation: 0.1
return_transform: false
input_key: image
...

.. _`Kornia: an Open Source Differentiable Computer Vision Library
for PyTorch`: https://arxiv.org/pdf/1910.02190.pdf
"""

def __init__(
self,
transform: Sequence[Union[dict, nn.Module]],
input_key: Union[str, int] = "image",
output_key: Optional[Union[str, int]] = None,
) -> None:
"""Init."""
super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all)
lambda_fn: Callable,
scope: str,
input_key: Union[List[str], str, int] = None,
output_key: Union[List[str], str, int] = None,
):
"""
Preprocess your batch with a specified function.

Args:
lambda_fn (Callable): Function to apply to the batch.
scope (str): when to apply the function, either ``"on_batch_start"`` or ``"on_batch_end"``.
input_key (Union[List[str], str, int], optional): Keys in the batch dict whose
values are passed to the function. Defaults to ``None``.
output_key (Union[List[str], str, int], optional): Keys under which to store the
function output. If ``None``, the output is written back in place under
``input_key``. Defaults to ``None``.

Raises:
TypeError: If ``input_key``/``output_key`` is not a ``str``, ``int``, or a list,
or if ``scope`` is not ``"on_batch_start"`` or ``"on_batch_end"``.
"""
super().__init__(order=CallbackOrder.Internal)
if input_key is not None:
if not isinstance(input_key, (list, str, int)):
raise TypeError("input key should be str or list of str.")
elif isinstance(input_key, (str, int)):
input_key = [input_key]
self.input_handler = self._handle_input_tuple
else:
self.input_handler = self._handle_input_dict

output_key = output_key or input_key
if output_key is not None:
if not isinstance(output_key, (list, str, int)):
raise TypeError("output key should be str or list of str.")
if isinstance(output_key, (str, int)):
self.output_handler = self._handle_output_value
output_key = [output_key]
else:
self.output_handler = self._handle_output_tuple
else:
self.output_handler = self._handle_output_dict

if isinstance(scope, str) and scope in ["on_batch_end", "on_batch_start"]:
self.scope = scope
else:
raise TypeError('Expected scope to be one of ["on_batch_start", "on_batch_end"]')
self.input_key = input_key
self.output_key = output_key or self.input_key
self.output_key = output_key
self.lambda_fn = lambda_fn

@staticmethod
def _handle_input_tuple(batch, input_key):
return [batch[key] for key in input_key]

@staticmethod
def _handle_input_dict(batch, _input_key):
return (batch,)  # wrap so that lambda_fn(*fn_input) receives the whole batch dict

@staticmethod
def _handle_output_tuple(
batch: Dict[str, Any], function_output: Tuple[Any], output_keys: List[str]
) -> Dict[str, Any]:
for out_idx, output_key in enumerate(output_keys):
batch[output_key] = function_output[out_idx]
return batch

@staticmethod
def _handle_output_dict(
batch: Dict[str, Any], function_output: Dict[str, Any], _output_keys: List[str]
) -> Dict[str, Any]:
for output_key, output_value in function_output.items():
batch[output_key] = output_value
return batch

@staticmethod
def _handle_output_value(
batch: Dict[str, Any], function_output: Any, output_keys: List[str],
):
batch[output_keys[0]] = function_output
return batch

def _handle_batch(self, runner):
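# Flow: input_handler gathers the values named by input_key (or passes the
# whole batch dict), lambda_fn transforms them, and output_handler writes
# the result back into runner.batch under output_key.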
fn_input = self.input_handler(runner.batch, self.input_key)
fn_output = self.lambda_fn(*fn_input)

runner.batch = self.output_handler(
batch=runner.batch, function_output=fn_output, output_keys=self.output_key
)

transforms: Sequence[nn.Module] = [
item if isinstance(item, nn.Module) else REGISTRY.get_from_params(**item)
for item in transform
]
assert all(
isinstance(t, nn.Module) for t in transforms
), "`nn.Module` should be a base class for transforms"
def on_batch_start(self, runner: "IRunner") -> None:
"""
On batch start action.

self.transform = nn.Sequential(*transforms)
Args:
runner: runner for the experiment.
"""
if self.scope == "on_batch_start":
self._handle_batch(runner)

def on_batch_start(self, runner: "IRunner") -> None:
"""Apply transforms.
def on_batch_end(self, runner) -> None:
"""
On batch end action.

Args:
runner: current runner
runner: runner for the experiment.
"""
input_batch = runner.batch[self.input_key]
output_batch = self.transform(input_batch)
runner.batch[self.output_key] = output_batch
if self.scope == "on_batch_end":
self._handle_batch(runner)


__all__ = ["BatchTransformCallback"]
4 changes: 3 additions & 1 deletion catalyst/callbacks/tests/test_transform_kornia.py
@@ -81,7 +81,9 @@ def test_transform_kornia():
verbose=False,
load_best_on_end=True,
check=True,
callbacks=[BatchTransformCallback(transrorms, input_key=0)],
callbacks=[
BatchTransformCallback(lambda_fn=transrorms, scope="on_batch_start", input_key=0)
],
)

# model inference
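
As the docstring notes, restricting the transform to a single loader (e.g. train-time-only augmentations) is done by wrapping the callback in ControlFlowCallback. A sketch under that assumption (kwargs per the catalyst API at the time of this PR; treat as illustrative):

import torch
from catalyst import dl

# apply the sigmoid transform only while the "train" loader is active
train_only = dl.ControlFlowCallback(
    base_callback=dl.BatchTransformCallback(
        lambda_fn=torch.sigmoid, scope="on_batch_end",
        input_key="logits", output_key="scores",
    ),
    loaders="train",
)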