Option to disable automatic synchronization #1107

Merged
merged 10 commits into from Jun 30, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added global option `sync_on_compute` to disable automatic synchronization when `compute` is called ([#1107](https://github.dev/Lightning-AI/metrics/pull/1107))


-
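For the `sync_on_compute` option added in the changelog entry above, a minimal usage sketch (not part of this diff; assumes `torchmetrics.SumMetric` and an already-initialized `torch.distributed` process group):

```python
# Minimal sketch: with sync_on_compute=False, compute() returns the local
# rank's accumulated value instead of the all-gathered result.
import torch
from torchmetrics import SumMetric

metric = SumMetric(sync_on_compute=False)  # new keyword argument from this PR
metric.update(torch.tensor(1.0))
local_value = metric.compute()  # no cross-process synchronization here
```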
20 changes: 14 additions & 6 deletions src/torchmetrics/metric.py
@@ -65,9 +65,11 @@ class Metric(Module, ABC):

- compute_on_cpu: If metric state should be stored on CPU during computations. Only works
for list states.
- dist_sync_on_step: If metric state should synchronize on ``forward()``
- process_group: The process group on which the synchronization is called
- dist_sync_fn: function that performs the allgather option on the metric state
- dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: function that performs the allgather operation on the metric state. Default is a
custom implementation that calls ``torch.distributed.all_gather`` internally.
- sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``.
"""

__jit_ignored_attributes__ = ["device"]
@@ -108,14 +110,20 @@ def __init__(
f"Expected keyword argument `dist_sync_fn` to be an callable function but got {self.dist_sync_fn}"
)

self.sync_on_compute = kwargs.pop("sync_on_compute", True)
if not isinstance(self.sync_on_compute, bool):
raise ValueError(
f"Expected keyword argument `sync_on_compute` to be an `bool` but got {self.sync_on_compute}"
)

# initialize
self._update_signature = inspect.signature(self.update)
self.update: Callable = self._wrap_update(self.update) # type: ignore
self.compute: Callable = self._wrap_compute(self.compute) # type: ignore
self._computed = None
self._forward_cache = None
self._update_count = 0
self._to_sync = True
self._to_sync = self.sync_on_compute
self._should_unsync = True
self._enable_grad = False

@@ -272,7 +280,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
# restore context
self._is_synced = False
self._should_unsync = True
self._to_sync = True
self._to_sync = self.sync_on_compute
self._computed = None
self._enable_grad = False
self.compute_on_cpu = _temp_compute_on_cpu
@@ -309,7 +317,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
# restore context
self._is_synced = False
self._should_unsync = True
self._to_sync = True
self._to_sync = self.sync_on_compute
self._computed = None
self._enable_grad = False
self.compute_on_cpu = _temp_compute_on_cpu
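A hedged sketch of the idea behind the changes above (this is not the library's actual `_wrap_compute` code): the `_to_sync` flag, now seeded from `sync_on_compute`, decides whether metric states are all-gathered before `compute` runs.

```python
# Illustrative only: gate synchronization on a flag such as _to_sync.
from contextlib import nullcontext

def call_compute(metric, compute_fn):
    # Enter the sync context only when synchronization on compute is enabled;
    # otherwise compute runs on the purely local state.
    ctx = metric.sync_context() if metric._to_sync else nullcontext()
    with ctx:
        return compute_fn()
```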
32 changes: 31 additions & 1 deletion tests/unittests/bases/test_ddp.py
@@ -23,7 +23,7 @@
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from unittests.helpers import seed_all
from unittests.helpers.testers import DummyMetric, DummyMetricSum, setup_ddp
from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum, setup_ddp

seed_all(42)

@@ -239,3 +239,33 @@ def test_state_dict_is_synced(tmpdir):
"""This test asserts that metrics are synced while creating the state dict but restored after to continue
accumulation."""
torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)


def _test_sync_on_compute_tensor_state(rank, worldsize, sync_on_compute):
setup_ddp(rank, worldsize)
dummy = DummyMetricSum(sync_on_compute=sync_on_compute)
dummy.update(tensor(rank + 1))
val = dummy.compute()
if sync_on_compute:
assert val == 3
else:
assert val == rank + 1


def _test_sync_on_compute_list_state(rank, worldsize, sync_on_compute):
setup_ddp(rank, worldsize)
dummy = DummyListMetric(sync_on_compute=sync_on_compute)
dummy.x.append(tensor(rank + 1))
val = dummy.compute()
if sync_on_compute:
assert val == [1, 2]
else:
assert val == [rank + 1]


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("sync_on_compute", [True, False])
@pytest.mark.parametrize("test_func", [_test_sync_on_compute_list_state, _test_sync_on_compute_tensor_state])
def test_sync_on_compute(sync_on_compute, test_func):
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(test_func, args=(2, sync_on_compute), nprocs=2)
4 changes: 2 additions & 2 deletions tests/unittests/helpers/testers.py
@@ -589,8 +589,8 @@ class DummyListMetric(Metric):
name = "DummyList"
full_state_update: Optional[bool] = True

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("x", [], dist_reduce_fx="cat")

def update(self):
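The `**kwargs` passthrough above is what makes base-class options such as `sync_on_compute` reachable from a subclass. A hedged sketch of the same pattern in a hypothetical user-defined metric (`MyListMetric` is illustrative, not part of the library):

```python
import torch
from torchmetrics import Metric

class MyListMetric(Metric):
    full_state_update = True

    def __init__(self, **kwargs):
        # Forwarding **kwargs to Metric lets callers set options such as
        # sync_on_compute, compute_on_cpu or dist_sync_on_step.
        super().__init__(**kwargs)
        self.add_state("x", default=[], dist_reduce_fx="cat")

    def update(self, value: torch.Tensor) -> None:
        self.x.append(value)

    def compute(self):
        return self.x

metric = MyListMetric(sync_on_compute=False)
metric.update(torch.tensor(1.0))
print(metric.compute())  # local state only; no all-gather on compute
```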