diff --git a/catalyst/callbacks/batch_transform.py b/catalyst/callbacks/batch_transform.py index 08d1a34f08..28d532c238 100644 --- a/catalyst/callbacks/batch_transform.py +++ b/catalyst/callbacks/batch_transform.py @@ -1,6 +1,8 @@ -from typing import Callable, List, Union +from typing import Callable, List, Union, Dict, Any +from functools import partial from catalyst.core import Callback, CallbackOrder, IRunner +from catalyst.registry import REGISTRY def _tuple_wrapper(transform: Callable): @@ -17,7 +19,7 @@ class BatchTransformCallback(Callback): Preprocess your batch with specified function. Args: - transform (Callable): Function to apply. + transform (Callable, str): Function to apply. If str, it is looked up in the registry. scope (str): ``"on_batch_end"`` (post-processing model output) or ``"on_batch_start"`` (pre-processing model input). input_key (Union[List[str], str], optional): Keys in batch dict to apply function. @@ -25,6 +27,7 @@ class BatchTransformCallback(Callback): output_key (Union[List[str], str], optional): Keys for output. If None then will apply function inplace to ``keys_to_apply``. Defaults to ``None``. + transform_kwargs (Dict[str, Any]): Keyword arguments passed to ``transform``. Raises: TypeError: When keys is not str or a list. @@ -64,8 +67,8 @@ class BatchTransformCallback(Callback): num_epochs=3, verbose=True, callbacks=[ - dl.LambdaPreprocessCallback( - input_key="logits", output_key="scores", transform=torch.sigmoid + dl.BatchTransformCallback( + input_key="logits", output_key="scores", transform="F.sigmoid", scope="on_batch_end", ), dl.CriterionCallback( input_key="logits", target_key="targets", metric_key="loss" @@ -165,21 +168,33 @@ def __len__(self): ) ], ) + .. code-block:: yaml + + ...
+ callbacks: + transform: + _target_: BatchTransformCallback + transform: catalyst.ToTensor + scope: on_batch_start + input_key: features + """ def __init__( self, - transform: Callable, + transform: Union[Callable, str], scope: str, input_key: Union[List[str], str] = None, output_key: Union[List[str], str] = None, + transform_kwargs: Dict[str, Any] = None, ): """ Preprocess your batch with specified function. Args: - transform (Callable): Function to apply. + transform (Callable, str): Function to apply. + If str, it is looked up in the registry. scope (str): ``"on_batch_end"`` (post-processing model output) or ``"on_batch_start"`` (pre-processing model input). input_key (Union[List[str], str], optional): Keys in batch dict to apply function. @@ -187,13 +202,16 @@ def __init__( output_key (Union[List[str], str], optional): Keys for output. If None then will apply function inplace to ``keys_to_apply``. Defaults to ``None``. - + transform_kwargs (Dict[str, Any]): Keyword arguments passed to ``transform``. Defaults to ``None``. Raises: TypeError: When keys is not str or a list. When ``scope`` is not in ``["on_batch_end", "on_batch_start"]``.
""" super().__init__(order=CallbackOrder.Internal) - + if isinstance(transform, str): + transform = REGISTRY.get(transform) + if transform_kwargs is not None: + transform = partial(transform, **transform_kwargs) if input_key is not None: if not isinstance(input_key, (list, str)): raise TypeError("input key should be str or a list of str.") diff --git a/catalyst/registry.py b/catalyst/registry.py index df258f8f36..caa05b7e8b 100644 --- a/catalyst/registry.py +++ b/catalyst/registry.py @@ -144,4 +144,13 @@ def _loggers_loader(r: registry.Registry): REGISTRY.late_add(_loggers_loader) +def _torch_functional_loader(r: registry.Registry): + import torch.nn.functional as F + + r.add_from_module(F, ["F."]) + + +REGISTRY.late_add(_torch_functional_loader) + + __all__ = ["REGISTRY"] diff --git a/tests/pipelines/test_multilabel_classification.py b/tests/pipelines/test_multilabel_classification.py index 74ab673f58..5579e5e7d7 100644 --- a/tests/pipelines/test_multilabel_classification.py +++ b/tests/pipelines/test_multilabel_classification.py @@ -34,10 +34,7 @@ def train_experiment(device, engine=None): ) callbacks = [ dl.BatchTransformCallback( - transform=torch.sigmoid, - scope="on_batch_end", - input_key="logits", - output_key="scores", + transform="F.sigmoid", scope="on_batch_end", input_key="logits", output_key="scores", ), dl.MultilabelAccuracyCallback(input_key="scores", target_key="targets", threshold=0.5), dl.MultilabelPrecisionRecallF1SupportCallback(