diff --git a/catalyst/callbacks/metrics/__init__.py b/catalyst/callbacks/metrics/__init__.py index d2f44261e3..2e94072bdc 100644 --- a/catalyst/callbacks/metrics/__init__.py +++ b/catalyst/callbacks/metrics/__init__.py @@ -30,4 +30,4 @@ TrevskyCallback, ) -from catalyst.callbacks.metrics.sklearn import SklearnCallback +from catalyst.callbacks.metrics.scikit_learn import SklearnCallback diff --git a/catalyst/callbacks/metrics/sklearn.py b/catalyst/callbacks/metrics/scikit_learn.py similarity index 57% rename from catalyst/callbacks/metrics/sklearn.py rename to catalyst/callbacks/metrics/scikit_learn.py index 9c3be20511..380e15ab28 100644 --- a/catalyst/callbacks/metrics/sklearn.py +++ b/catalyst/callbacks/metrics/scikit_learn.py @@ -1,5 +1,6 @@ -from typing import Callable, Dict, Mapping +from typing import Any, Callable, Dict, Mapping, Union +import sklearn import torch from catalyst.callbacks.metric import FunctionalBatchMetricCallback @@ -8,16 +9,26 @@ class SklearnCallback(FunctionalBatchMetricCallback): - """@TODO: Docs.""" + """ + + Args: + keys: + metric_fn: + metric_key: + log_on_batch: + """ def __init__( self, - keys: Mapping[str, str], - metric_fn: Callable, + keys: Mapping[str, Any], + metric_fn: Union[Callable, str], metric_key: str, log_on_batch: bool = True, ): """Init.""" + if isinstance(metric_fn, str): + metric_fn = sklearn.metrics.__dict__[metric_fn] + super().__init__( metric=FunctionalBatchMetric(metric_fn=metric_fn, metric_key=metric_key), input_key=keys, @@ -28,7 +39,10 @@ def __init__( def _get_key_value_inputs(self, runner: "IRunner") -> Dict[str, torch.Tensor]: """@TODO: Docs.""" kv_inputs = {} - for key in self._keys: - kv_inputs[key] = runner.batch[self._keys[key]].cpu().detach().numpy() + for key, value in self._keys.items(): + if value in runner.batch: + kv_inputs[key] = runner.batch[value].cpu().detach().numpy() + else: + kv_inputs[key] = self._keys[key] kv_inputs["batch_size"] = runner.batch_size return kv_inputs