Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Registry to batch transform callback #1209

Merged
merged 25 commits into from May 31, 2021
Merged
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fc7dc0c
add registry
elephantmipt May 10, 2021
adcbaf5
fix docs
elephantmipt May 10, 2021
c30ad5f
fix docs and codestyle. Add example
elephantmipt May 10, 2021
63cba6c
fix line length :)
elephantmipt May 10, 2021
d97da3a
add registry.py
elephantmipt May 12, 2021
05e96fa
codestyle
elephantmipt May 12, 2021
48ca1d9
fix
elephantmipt May 12, 2021
909c6dc
add partial and kwargs
elephantmipt May 30, 2021
38007e7
Fix data: imports (#1211)
Scitator May 16, 2021
0af8667
Change layerwise to layerwise_params (#1210)
MrNightSky May 16, 2021
b017975
[WIP] faq docs (#1202)
ditwoo May 17, 2021
90c6f63
Update CHANGELOG.md (#1216)
MrNightSky May 18, 2021
3f2e77b
fix: to_numpy wrapper for `AdditiveValueMetric` added (#1214)
bagxi May 18, 2021
da557aa
tests update + readme (#1215)
Scitator May 18, 2021
be20b15
Fix compute method (#1206)
Dokholyan May 22, 2021
8577120
Update ddt.rst (#1217)
riohib May 22, 2021
eb99422
docker update (#1218)
Scitator May 24, 2021
a734897
docs update (#1219)
Scitator May 25, 2021
0428658
Updated NeptuneLogger docstrings (#1223)
May 28, 2021
b655ff5
Merge branch 'master' into registry_to_batch_transform
elephantmipt May 30, 2021
e8a3036
remove unnecessary thing
elephantmipt May 30, 2021
4b12447
Update catalyst/callbacks/batch_transform.py
Scitator May 30, 2021
7d80e8f
Apply suggestions from code review
Scitator May 30, 2021
991329f
Apply suggestions from code review
Scitator May 31, 2021
45d6600
Merge branch 'master' into registry_to_batch_transform
Scitator May 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 19 additions & 4 deletions catalyst/callbacks/batch_transform.py
@@ -1,6 +1,7 @@
from typing import Callable, List, Union

from catalyst.core import Callback, CallbackOrder, IRunner
from catalyst.registry import REGISTRY


def _tuple_wrapper(transform: Callable):
Expand All @@ -17,7 +18,7 @@ class BatchTransformCallback(Callback):
Preprocess your batch with specified function.

Args:
transform (Callable): Function to apply.
transform (Callable, str): Function to apply. If string will get function from registry.
scope (str): ``"on_batch_end"`` (post-processing model output) or
``"on_batch_start"`` (pre-processing model input).
input_key (Union[List[str], str], optional): Keys in batch dict to apply function.
Expand Down Expand Up @@ -165,12 +166,22 @@ def __len__(self):
)
],
)
.. code-block:: yaml

...
callbacks:
transform:
_target_: BatchTransformCallback
transform: catalyst.ToTensor
scope: on_batch_start
input_key: features


"""

def __init__(
self,
transform: Callable,
transform: Union[Callable, str],
scope: str,
input_key: Union[List[str], str] = None,
output_key: Union[List[str], str] = None,
Expand All @@ -179,7 +190,7 @@ def __init__(
Preprocess your batch with specified function.

Args:
transform (Callable): Function to apply.
transform (Callable, str): Function to apply. If string will get function from registry.
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
elephantmipt marked this conversation as resolved.
Show resolved Hide resolved
scope (str): ``"on_batch_end"`` (post-processing model output) or
``"on_batch_start"`` (pre-processing model input).
input_key (Union[List[str], str], optional): Keys in batch dict to apply function.
Expand All @@ -206,7 +217,9 @@ def __init__(
output_key = output_key or input_key
if output_key is not None:
if input_key is None:
raise TypeError("You should define input_key in " "case if output_key is not None")
raise TypeError(
"You should define input_key in " "case if output_key is not None"
)
if not isinstance(output_key, (list, str)):
raise TypeError("output key should be str or a list of str.")
if isinstance(output_key, str):
Expand All @@ -219,6 +232,8 @@ def __init__(
raise TypeError('Expected scope to be on of the ["on_batch_end", "on_batch_start"]')
self.input_key = input_key
self.output_key = output_key
if isinstance(transform, str):
transform = REGISTRY.get(transform)
self.transform = transform

def _handle_value(self, runner):
Expand Down