diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 55342eeb2d6..5b2c866f9f2 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -409,3 +409,15 @@ def get_memory_usage():
         metric.update(x.sum())
         memory = get_memory_usage()
         assert base_memory_level >= memory, "memory increased above base level"
+
+
+@pytest.mark.parametrize("metric_class", [DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum])
+def test_warning_on_not_set_full_state_update(metric_class):
+    class UnsetProperty(metric_class):
+        full_state_update = None
+
+    with pytest.warns(
+        UserWarning,
+        match="Torchmetrics v0.9 introduced a new class property called.*",
+    ):
+        UnsetProperty()
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index a75d5174b2a..7bc0c3cb2e5 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -569,6 +569,7 @@ def run_differentiability_test(

 class DummyMetric(Metric):
     name = "Dummy"
+    full_state_update: Optional[bool] = True

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -583,6 +584,7 @@ def compute(self):

 class DummyListMetric(Metric):
     name = "DummyList"
+    full_state_update: Optional[bool] = True

     def __init__(self):
         super().__init__()
diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index 40dd37fe1e3..9cafb5b9d03 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -15,7 +15,9 @@
 import torch
 from torch import tensor

-from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
+from torchmetrics import MeanSquaredError, PearsonCorrCoef
+from torchmetrics.utilities import check_forward_no_full_state, rank_zero_debug, rank_zero_info, rank_zero_warn
+from torchmetrics.utilities.checks import _allclose_recursive
 from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
 from torchmetrics.utilities.distributed import class_reduce, reduce
@@ -126,3 +128,29 @@ def test_bincount():
     # check for correctness
     assert torch.allclose(res1, res2)
     assert torch.allclose(res1, res3)
+
+
+@pytest.mark.parametrize("metric_class, expected", [(MeanSquaredError, True), (PearsonCorrCoef, False)])
+def test_check_full_state_update_fn(metric_class, expected):
+    """Test that the check function works as it should."""
+    out = check_forward_no_full_state(
+        metric_class=metric_class,
+        input_args=dict(preds=torch.randn(100), target=torch.randn(100)),
+    )
+    assert out == expected
+
+
+@pytest.mark.parametrize(
+    "input, expected",
+    [
+        ((torch.ones(2), torch.ones(2)), True),
+        ((torch.rand(2), torch.rand(2)), False),
+        (([torch.ones(2) for _ in range(2)], [torch.ones(2) for _ in range(2)]), True),
+        (([torch.rand(2) for _ in range(2)], [torch.rand(2) for _ in range(2)]), False),
+        (({f"{i}": torch.ones(2) for i in range(2)}, {f"{i}": torch.ones(2) for i in range(2)}), True),
+        (({f"{i}": torch.rand(2) for i in range(2)}, {f"{i}": torch.rand(2) for i in range(2)}), False),
+    ],
+)
+def test_recursive_allclose(input, expected):
+    res = _allclose_recursive(*input)
+    assert res == expected
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index a6987f98d4f..ec857708e97 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -73,7 +73,7 @@ class Metric(Module, ABC):
     __jit_unused_properties__ = ["is_differentiable"]
     is_differentiable: Optional[bool] = None
     higher_is_better: Optional[bool] = None
-    full_state_update: bool = True
+    full_state_update: Optional[bool] = None

     def __init__(
         self,
@@ -127,6 +127,20 @@ def __init__(
         self._is_synced = False
         self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

+        if self.full_state_update is None:
+            rank_zero_warn(
+                f"""Torchmetrics v0.9 introduced a new class property called `full_state_update` that has
+                not been set for this class ({self.__class__.__name__}). The property determines if `update` by
+                default needs access to the full metric state. If this is not the case, significant speedups can be
+                achieved and we recommend setting this to `False`.
+                We provide a checking function
+                `from torchmetrics.utilities import check_forward_no_full_state`
+                that can be used to check if `full_state_update=True` (old and potentially slower behaviour,
+                default for now) or if `full_state_update=False` can be used safely.
+                """,
+                UserWarning,
+            )
+
     @property
     def _update_called(self) -> bool:
         # Needed for lightning integration
@@ -216,7 +230,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 "HINT: Did you forget to call ``unsync`` ?."
             )

-        if self.full_state_update or self.dist_sync_on_step:
+        if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
             self._forward_cache = self._forward_full_state_update(*args, **kwargs)
         else:
             self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py
index 5049a9d6901..3ec9e96b675 100644
--- a/torchmetrics/utilities/__init__.py
+++ b/torchmetrics/utilities/__init__.py
@@ -1,3 +1,4 @@
+from torchmetrics.utilities.checks import check_forward_no_full_state  # noqa: F401
 from torchmetrics.utilities.data import apply_to_collection  # noqa: F401
 from torchmetrics.utilities.distributed import class_reduce, reduce  # noqa: F401
 from torchmetrics.utilities.prints import _future_warning, rank_zero_debug, rank_zero_info, rank_zero_warn  # noqa: F401
diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py
index f688a873c80..16cb4a267c0 100644
--- a/torchmetrics/utilities/checks.py
+++ b/torchmetrics/utilities/checks.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple
+from time import perf_counter
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, no_type_check

 import torch
 from torch import Tensor
@@ -604,3 +605,119 @@ def _check_retrieval_target_and_prediction_types(
         preds = preds.float()

     return preds.flatten(), target.flatten()
+
+
+def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool:
+    """Utility function for recursively checking that two results are within a certain tolerance."""
+    # single output compare
+    if isinstance(res1, Tensor):
+        return torch.allclose(res1, res2, atol=atol)
+    elif isinstance(res1, str):
+        return res1 == res2
+    elif isinstance(res1, Sequence):
+        return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2))
+    elif isinstance(res1, Mapping):
+        return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys())
+    return res1 == res2
+
+
+@no_type_check
+def check_forward_no_full_state(
+    metric_class,
+    init_args: Dict[str, Any] = {},
+    input_args: Dict[str, Any] = {},
+    num_update_to_compare: Sequence[int] = [10, 100, 1000],
+    reps: int = 5,
+) -> bool:
+    """Utility function for checking if the new ``full_state_update`` property can safely be set to ``False``,
+    which for most metrics will result in a speedup when using ``forward``.
+
+    Args:
+        metric_class: metric class object that should be checked
+        init_args: dict containing arguments for initializing the metric class
+        input_args: dict containing arguments to pass to ``forward``
+        num_update_to_compare: if we successfully detect that the flag is safe to set to ``False``,
+            we will run some speedup test. This arg should be a list of integers for how many
+            steps to compare over.
+        reps: number of repetitions of speedup test
+
+    Example (states in ``update`` are independent, safe to set ``full_state_update=False``):
+        >>> from torchmetrics import ConfusionMatrix
+        >>> check_forward_no_full_state(
+        ...     ConfusionMatrix,
+        ...     init_args = {'num_classes': 3},
+        ...     input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
+        ... )  # doctest: +ELLIPSIS
+        Full state for 10 steps took: ...
+        Partial state for 10 steps took: ...
+        Full state for 100 steps took: ...
+        Partial state for 100 steps took: ...
+        Full state for 1000 steps took: ...
+        Partial state for 1000 steps took: ...
+        True
+
+    Example (states in ``update`` are dependent, meaning that ``full_state_update=True``):
+        >>> from torchmetrics import ConfusionMatrix
+        >>> class MyMetric(ConfusionMatrix):
+        ...     def update(self, preds, target):
+        ...         super().update(preds, target)
+        ...         # by construction make future states dependent on prior states
+        ...         if self.confmat.sum() > 20:
+        ...             self.reset()
+        >>> check_forward_no_full_state(
+        ...     MyMetric,
+        ...     init_args = {'num_classes': 3},
+        ...     input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
+        ... )
+        False
+    """
+
+    class FullState(metric_class):
+        full_state_update = True
+
+    class PartState(metric_class):
+        full_state_update = False
+
+    fullstate = FullState(**init_args)
+    partstate = PartState(**init_args)
+
+    equal = True
+    for _ in range(num_update_to_compare[0]):
+        out1 = fullstate(**input_args)
+        try:  # if it fails, the code most likely needs access to the full state
+            out2 = partstate(**input_args)
+        except RuntimeError:
+            equal = False
+            break
+        equal = equal & _allclose_recursive(out1, out2)
+
+    res1 = fullstate.compute()
+    try:  # if it fails, the code most likely needs access to the full state
+        res2 = partstate.compute()
+        equal = equal & _allclose_recursive(res1, res2)
+    except RuntimeError:
+        equal = False
+
+    if not equal:  # we can stop early because the results did not match
+        return False
+
+    # Do timings
+    res = torch.zeros(2, len(num_update_to_compare), reps)
+    for i, metric in enumerate([fullstate, partstate]):
+        for j, t in enumerate(num_update_to_compare):
+            for r in range(reps):
+                start = perf_counter()
+                for _ in range(t):
+                    _ = metric(**input_args)
+                end = perf_counter()
+                res[i, j, r] = end - start
+                metric.reset()
+
+    mean = torch.mean(res, -1)
+    std = torch.std(res, -1)
+
+    for t in range(len(num_update_to_compare)):
+        print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]:0.3f}+-{std[0, t]:0.3f}")
+        print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}")
+
+    return (mean[1, -1] < mean[0, -1]).item()  # if faster on average, we recommend upgrading
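
For reviewers, a minimal self-contained sketch of how the utility added in this diff might be used to decide the flag for a custom metric. `MySumMetric` is a hypothetical example class invented for illustration; the only API assumed from this PR is `check_forward_no_full_state`, everything else (`Metric`, `add_state`) is pre-existing torchmetrics.

```python
import torch
from torchmetrics import Metric
from torchmetrics.utilities import check_forward_no_full_state


class MySumMetric(Metric):
    # Hypothetical example metric, not part of this PR. Leaving the property
    # as None triggers the new UserWarning on init and makes `forward` fall
    # back to the (slower) full-state path.
    full_state_update = None

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # The state only accumulates and never branches on previously seen
        # values, so the partial-state `forward` should give identical results.
        self.total += (preds - target).abs().sum()

    def compute(self) -> torch.Tensor:
        return self.total


safe = check_forward_no_full_state(
    MySumMetric,
    input_args={"preds": torch.randn(10), "target": torch.randn(10)},
)
print(f"full_state_update can be set to False: {safe}")
# If `safe` is True, set `MySumMetric.full_state_update = False` on the class
# to silence the warning and get the speedup in `forward`.
```

If the outputs of the full-state and partial-state paths match, the function prints the timing comparison and returns whether the partial-state path was faster on average, which is exactly the signal a user needs to set the property.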