diff --git a/CHANGELOG.md b/CHANGELOG.md
index 93fc6fdf72a..5ae89562ec1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
--
+- Added global option `sync_on_compute` to disable automatic synchronization when `compute` is called ([#1107](https://github.dev/Lightning-AI/metrics/pull/1107))
 
 
 -
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 2fa4eacb218..60fbaf86a2f 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -65,9 +65,11 @@ class Metric(Module, ABC):
 
             - compute_on_cpu: If metric state should be stored on CPU during computations. Only works
                 for list states.
-            - dist_sync_on_step: If metric state should synchronize on ``forward()``
-            - process_group: The process group on which the synchronization is called
-            - dist_sync_fn: function that performs the allgather option on the metric state
+            - dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
+            - process_group: The process group on which the synchronization is called. Default is the world.
+            - dist_sync_fn: function that performs the allgather operation on the metric state. Default is a
+                custom implementation that calls ``torch.distributed.all_gather`` internally.
+            - sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
     """

     __jit_ignored_attributes__ = ["device"]
@@ -108,6 +110,12 @@ def __init__(
                 f"Expected keyword argument `dist_sync_fn` to be an callable function but got {self.dist_sync_fn}"
             )
 
+        self.sync_on_compute = kwargs.pop("sync_on_compute", True)
+        if not isinstance(self.sync_on_compute, bool):
+            raise ValueError(
+                f"Expected keyword argument `sync_on_compute` to be a `bool` but got {self.sync_on_compute}"
+            )
+
         # initialize
         self._update_signature = inspect.signature(self.update)
         self.update: Callable = self._wrap_update(self.update)  # type: ignore
@@ -115,7 +123,7 @@ def __init__(
         self._computed = None
         self._forward_cache = None
         self._update_count = 0
-        self._to_sync = True
+        self._to_sync = self.sync_on_compute
         self._should_unsync = True
         self._enable_grad = False
 
@@ -272,7 +280,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
         # restore context
         self._is_synced = False
         self._should_unsync = True
-        self._to_sync = True
+        self._to_sync = self.sync_on_compute
         self._computed = None
         self._enable_grad = False
         self.compute_on_cpu = _temp_compute_on_cpu
@@ -309,7 +317,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
         # restore context
         self._is_synced = False
         self._should_unsync = True
-        self._to_sync = True
+        self._to_sync = self.sync_on_compute
         self._computed = None
         self._enable_grad = False
         self.compute_on_cpu = _temp_compute_on_cpu
diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py
index 1b83f251983..a4838a1cffd 100644
--- a/tests/unittests/bases/test_ddp.py
+++ b/tests/unittests/bases/test_ddp.py
@@ -23,7 +23,7 @@
 from torchmetrics.utilities.distributed import gather_all_tensors
 from torchmetrics.utilities.exceptions import TorchMetricsUserError
 from unittests.helpers import seed_all
-from unittests.helpers.testers import DummyMetric, DummyMetricSum, setup_ddp
+from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum, setup_ddp
 
 seed_all(42)
 
@@ -111,6 +111,8 @@ def _test_non_contiguous_tensors(rank, worldsize):
     setup_ddp(rank, worldsize)
 
     class DummyCatMetric(Metric):
+        full_state_update = True
+
         def __init__(self):
             super().__init__()
             self.add_state("x", default=[], dist_reduce_fx=None)
@@ -136,6 +138,8 @@ def _test_state_dict_is_synced(rank, worldsize, tmpdir):
     setup_ddp(rank, worldsize)
 
     class DummyCatMetric(Metric):
+        full_state_update = True
+
         def __init__(self):
             super().__init__()
             self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
@@ -239,3 +243,33 @@ def test_state_dict_is_synced(tmpdir):
     """This test asserts that metrics are synced while creating the state dict but restored after to continue
     accumulation."""
     torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)
+
+
+def _test_sync_on_compute_tensor_state(rank, worldsize, sync_on_compute):
+    setup_ddp(rank, worldsize)
+    dummy = DummyMetricSum(sync_on_compute=sync_on_compute)
+    dummy.update(tensor(rank + 1))
+    val = dummy.compute()
+    if sync_on_compute:
+        assert val == 3
+    else:
+        assert val == rank + 1
+
+
+def _test_sync_on_compute_list_state(rank, worldsize, sync_on_compute):
+    setup_ddp(rank, worldsize)
+    dummy = DummyListMetric(sync_on_compute=sync_on_compute)
+    dummy.update(tensor(rank + 1))
+    val = dummy.compute()
+    if sync_on_compute:
+        assert torch.allclose(val, tensor([1, 2]))
+    else:
+        assert val == [tensor(rank + 1)]
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+@pytest.mark.parametrize("sync_on_compute", [True, False])
+@pytest.mark.parametrize("test_func", [_test_sync_on_compute_list_state, _test_sync_on_compute_tensor_state])
+def test_sync_on_compute(sync_on_compute, test_func):
+    """Test that synchronization of states can be enabled and disabled for compute."""
+    torch.multiprocessing.spawn(test_func, args=(2, sync_on_compute), nprocs=2)
diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py
index d37e8d4d78a..92008ef5a4f 100644
--- a/tests/unittests/helpers/testers.py
+++ b/tests/unittests/helpers/testers.py
@@ -589,15 +589,15 @@ class DummyListMetric(Metric):
     name = "DummyList"
     full_state_update: Optional[bool] = True
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.add_state("x", [], dist_reduce_fx="cat")
 
-    def update(self):
-        pass
+    def update(self, x=torch.tensor(1)):
+        self.x.append(x)
 
     def compute(self):
-        pass
+        return self.x
 
 
 class DummyMetricSum(DummyMetric):
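For context, a minimal usage sketch of the new `sync_on_compute` option (illustrative only, not part of the diff: `MeanSquaredError` is just an example metric, and the two results only differ when an actual DDP process group is initialized and ranks see different data):

```python
import torch
from torchmetrics import MeanSquaredError

# Default behaviour: metric states are all-gathered across processes when compute() is called.
synced_metric = MeanSquaredError()

# New flag: skip the all-gather so each rank computes over its local state only.
local_metric = MeanSquaredError(sync_on_compute=False)

preds, target = torch.randn(10), torch.randn(10)
synced_metric.update(preds, target)
local_metric.update(preds, target)

print(synced_metric.compute())  # synchronized across ranks when running under DDP
print(local_metric.compute())   # per-rank result only
```

Outside of a distributed run both calls return the same value; the distributed behaviour is exercised by the new `test_sync_on_compute` test in `tests/unittests/bases/test_ddp.py`.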