Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

forward upgrade utils #1033

Merged
merged 26 commits into from May 24, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8d741aa
func
SkafteNicki May 16, 2022
9a01d4e
add warning
SkafteNicki May 16, 2022
d7f7686
add check function
SkafteNicki May 16, 2022
33e02f4
fix tests
SkafteNicki May 16, 2022
377b0d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2022
1798d89
Merge branch 'master' into refactor/forward_upgrade_utils
SkafteNicki May 16, 2022
968fae7
Update tests/test_utilities.py
SkafteNicki May 16, 2022
3d176c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2022
c84c1b5
Apply suggestions from code review
SkafteNicki May 16, 2022
fbddd37
add test for recursive check
SkafteNicki May 16, 2022
ef4381a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2022
47bb462
Merge branch 'master' into refactor/forward_upgrade_utils
SkafteNicki May 17, 2022
facb9b3
short
Borda May 17, 2022
f453bf6
add doctest
SkafteNicki May 19, 2022
df89521
fix mypy
SkafteNicki May 19, 2022
d0364d7
update
SkafteNicki May 19, 2022
8927047
again
SkafteNicki May 19, 2022
d04075e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2022
7e00969
fix
SkafteNicki May 19, 2022
3c09f33
Merge branch 'master' into refactor/forward_upgrade_utils
Borda May 20, 2022
53b6bae
Apply suggestions from code review
Borda May 23, 2022
8abee19
Merge branch 'master' into refactor/forward_upgrade_utils
Borda May 23, 2022
097e340
Merge branch 'master' into refactor/forward_upgrade_utils
mergify[bot] May 24, 2022
9baedae
Merge branch 'master' into refactor/forward_upgrade_utils
mergify[bot] May 24, 2022
b27aaac
fix tests
SkafteNicki May 24, 2022
460aa1f
fix
SkafteNicki May 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/bases/test_metric.py
Expand Up @@ -409,3 +409,15 @@ def get_memory_usage():
metric.update(x.sum())
memory = get_memory_usage()
assert base_memory_level >= memory, "memory increased above base level"


@pytest.mark.parametrize("metric_class", [DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum])
def test_warning_on_not_set_full_state_update(metric_class):
class UnsetProperty(metric_class):
full_state_update = None

with pytest.warns(
UserWarning,
match=r"Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has.*",
):
UnsetProperty()
2 changes: 2 additions & 0 deletions tests/helpers/testers.py
Expand Up @@ -569,6 +569,7 @@ def run_differentiability_test(

class DummyMetric(Metric):
name = "Dummy"
full_state_update: Optional[bool] = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -583,6 +584,7 @@ def compute(self):

class DummyListMetric(Metric):
name = "DummyList"
full_state_update: Optional[bool] = True

def __init__(self):
super().__init__()
Expand Down
30 changes: 29 additions & 1 deletion tests/test_utilities.py
Expand Up @@ -15,7 +15,9 @@
import torch
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics import MeanSquaredError, PearsonCorrCoef
from torchmetrics.utilities import check_forward_no_full_state, rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.checks import _allclose_recursive
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce

Expand Down Expand Up @@ -126,3 +128,29 @@ def test_bincount():
# check for correctness
assert torch.allclose(res1, res2)
assert torch.allclose(res1, res3)


@pytest.mark.parametrize("metric_class, expected", [(MeanSquaredError, True), (PearsonCorrCoef, False)])
def test_check_full_state_update_fn(metric_class, expected):
"""Test that the check function works as it should."""
out = check_forward_no_full_state(
metric_class=metric_class,
input_args=dict(preds=torch.randn(100), target=torch.randn(100)),
)
assert out == expected


@pytest.mark.parametrize(
"input, expected",
[
((torch.ones(2), torch.ones(2)), True),
((torch.rand(2), torch.rand(2)), False),
(([torch.ones(2) for _ in range(2)], [torch.ones(2) for _ in range(2)]), True),
(([torch.rand(2) for _ in range(2)], [torch.rand(2) for _ in range(2)]), False),
(({f"{i}": torch.ones(2) for i in range(2)}, {f"{i}": torch.ones(2) for i in range(2)}), True),
(({f"{i}": torch.rand(2) for i in range(2)}, {f"{i}": torch.rand(2) for i in range(2)}), False),
],
)
def test_recursive_allclose(input, expected):
res = _allclose_recursive(*input)
assert res == expected
18 changes: 16 additions & 2 deletions torchmetrics/metric.py
Expand Up @@ -73,7 +73,7 @@ class Metric(Module, ABC):
__jit_unused_properties__ = ["is_differentiable"]
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None
full_state_update: bool = True
full_state_update: Optional[bool] = None

def __init__(
self,
Expand Down Expand Up @@ -127,6 +127,20 @@ def __init__(
self._is_synced = False
self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

if self.full_state_update is None:
rank_zero_warn(
f"""Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has
not been set for this class ({self.__class__.__name__}). The property determines if ``update`` by
default needs access to the full metric state. If this is not the case, significant speedups can be
achieved and we recommend setting this to ``False``.
We provide an checking function
``from torchmetrics.utilities import check_forward_no_full_state``
that can be used to check if the ``full_state_update=True`` (old and potential slower behaviour,
default for now) or if ``full_state_update=False`` can be used safely.
""",
Borda marked this conversation as resolved.
Show resolved Hide resolved
UserWarning,
)

@property
def _update_called(self) -> bool:
# Needed for lightning integration
Expand Down Expand Up @@ -216,7 +230,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"HINT: Did you forget to call ``unsync`` ?."
)

if self.full_state_update or self.dist_sync_on_step:
if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
self._forward_cache = self._forward_full_state_update(*args, **kwargs)
else:
self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/utilities/__init__.py
@@ -1,3 +1,4 @@
from torchmetrics.utilities.checks import check_forward_no_full_state # noqa: F401
from torchmetrics.utilities.data import apply_to_collection # noqa: F401
from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401
from torchmetrics.utilities.prints import _future_warning, rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401
119 changes: 118 additions & 1 deletion torchmetrics/utilities/checks.py
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from time import perf_counter
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -604,3 +605,119 @@ def _check_retrieval_target_and_prediction_types(
preds = preds.float()

return preds.flatten(), target.flatten()


def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool:
"""Utility function for recursively asserting that two results are within a certain tolerance."""
# single output compare
if isinstance(res1, Tensor):
return torch.allclose(res1, res2, atol=atol)
elif isinstance(res1, str):
return res1 == res2
elif isinstance(res1, Sequence):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2))
elif isinstance(res1, Mapping):
return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys())
else:
return res1 == res2
Borda marked this conversation as resolved.
Show resolved Hide resolved


def check_forward_no_full_state(
metric_class, # type: ignore
init_args: Dict[str, Any] = {},
input_args: Dict[str, Any] = {},
num_update_to_compare: Sequence[int] = [10, 100, 1000],
reps: int = 5,
) -> bool:
"""Utility function for checking if the new ``full_state_update`` property can safely be set to ``False`` which
will for most metrics results in a speedup when using ``forward``.

Args:
metric_class: metric class object that should be checked
init_args: dict containing arguments for initializing the metric class
input_args: dict containing arguments to pass to ``forward``
num_update_to_compare: if we successfully detech that the flag is safe to set to ``False``
we will run some speedup test. This arg should be a list of integers for how many
steps to compare over.
reps: number of repetitions of speedup test

Example (states in ``update`` are independent, save to set ``full_state_update=False``)
>>> from torchmetrics import ConfusionMatrix
>>> check_forward_no_full_state(
... ConfusionMatrix,
... init_args = {'num_classes': 3},
... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
... ) # doctest: +ELLIPSIS
Full state for 10 steps took: ...
Partial state for 10 steps took: ...
Full state for 100 steps took: ...
Partial state for 100 steps took: ...
Full state for 1000 steps took: ...
Partial state for 1000 steps took: ...
True

Example (states in ``update`` are dependend meaning that ``full_state_update=True``):
>>> from torchmetrics import ConfusionMatrix
>>> class MyMetric(ConfusionMatrix):
... def update(self, preds, target):
... super().update(preds, target)
... # by construction make future states dependent on prior states
... if self.confmat.sum() > 20:
... self.reset()
>>> check_forward_no_full_state(
... MyMetric,
... init_args = {'num_classes': 3},
... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
... )
False
"""

class FullState(metric_class):
full_state_update = True

class PartState(metric_class):
full_state_update = False

fullstate = FullState(**init_args)
partstate = PartState(**init_args)

equal = True
for _ in range(num_update_to_compare[0]):
out1 = fullstate(**input_args)
try: # if it fails, the code most likely need access to the full state
out2 = partstate(**input_args)
except RuntimeError:
equal = False
break
equal = equal & _allclose_recursive(out1, out2)

res1 = fullstate.compute()
try: # if it fails, the code most likely need access to the full state
res2 = partstate.compute()
except RuntimeError:
equal = False
equal = equal & _allclose_recursive(res1, res2)

if not equal: # we can stop early because the results did not match
return False

# Do timings
res = torch.zeros(2, len(num_update_to_compare), reps)
for i, metric in enumerate([fullstate, partstate]):
for j, t in enumerate(num_update_to_compare):
for r in range(reps):
start = perf_counter()
for _ in range(t):
_ = metric(**input_args)
end = perf_counter()
res[i, j, r] = end - start
metric.reset()

mean = torch.mean(res, -1)
std = torch.std(res, -1)

for t in range(len(num_update_to_compare)):
print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]:0.3f}")
print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}")

return (mean[1, -1] < mean[0, -1]).item() # if faster on average, we recommend upgrading