diff --git a/AUTHORS b/AUTHORS index 61c3901beaa..6a88d2bbb71 100644 --- a/AUTHORS +++ b/AUTHORS @@ -361,5 +361,6 @@ Yoav Caspi Yuval Shimon Zac Hatfield-Dodds Zachary Kneupper +Zachary OBrien Zoltán Máté Zsolt Cserna diff --git a/changelog/9917.bugfix.rst b/changelog/9917.bugfix.rst new file mode 100644 index 00000000000..ac0616d2ee7 --- /dev/null +++ b/changelog/9917.bugfix.rst @@ -0,0 +1 @@ +Fixed string representation for :func:`pytest.approx` when used to compare tuples. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index ba3b48783ce..426b93ac2d4 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -133,9 +133,11 @@ def _check_type(self) -> None: # raise if there are any non-numeric elements in the sequence. -def _recursive_list_map(f, x): - if isinstance(x, list): - return [_recursive_list_map(f, xi) for xi in x] +def _recursive_sequence_map(f, x): + """Recursively map a function over a sequence of arbitary depth""" + if isinstance(x, (list, tuple)): + seq_type = type(x) + return seq_type(_recursive_sequence_map(f, xi) for xi in x) else: return f(x) @@ -144,7 +146,9 @@ class ApproxNumpy(ApproxBase): """Perform approximate comparisons where the expected value is numpy array.""" def __repr__(self) -> str: - list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist()) + list_scalars = _recursive_sequence_map( + self._approx_scalar, self.expected.tolist() + ) return f"approx({list_scalars!r})" def _repr_compare(self, other_side: "ndarray") -> List[str]: @@ -164,7 +168,7 @@ def get_value_from_nested_list( return value np_array_shape = self.expected.shape - approx_side_as_list = _recursive_list_map( + approx_side_as_seq = _recursive_sequence_map( self._approx_scalar, self.expected.tolist() ) @@ -179,7 +183,7 @@ def get_value_from_nested_list( max_rel_diff = -math.inf different_ids = [] for index in itertools.product(*(range(i) for i in np_array_shape)): - approx_value = get_value_from_nested_list(approx_side_as_list, index) + approx_value = get_value_from_nested_list(approx_side_as_seq, index) other_value = get_value_from_nested_list(other_side, index) if approx_value != other_value: abs_diff = abs(approx_value.expected - other_value) @@ -194,7 +198,7 @@ def get_value_from_nested_list( ( str(index), str(get_value_from_nested_list(other_side, index)), - str(get_value_from_nested_list(approx_side_as_list, index)), + str(get_value_from_nested_list(approx_side_as_seq, index)), ) for index in different_ids ] @@ -326,7 +330,7 @@ def _repr_compare(self, other_side: Sequence[float]) -> List[str]: f"Lengths: {len(self.expected)} and {len(other_side)}", ] - approx_side_as_map = _recursive_list_map(self._approx_scalar, self.expected) + approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected) number_of_elements = len(approx_side_as_map) max_abs_diff = -math.inf diff --git a/testing/python/approx.py b/testing/python/approx.py index 7b4fbad156e..6acb466ffb1 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -2,12 +2,14 @@ from contextlib import contextmanager from decimal import Decimal from fractions import Fraction +from math import sqrt from operator import eq from operator import ne from typing import Optional import pytest from _pytest.pytester import Pytester +from _pytest.python_api import _recursive_sequence_map from pytest import approx inf, nan = float("inf"), float("nan") @@ -133,6 +135,18 @@ def test_error_messages_native_dtypes(self, assert_approx_raises_regex): ], ) + assert_approx_raises_regex( + (1, 2.2, 4), + (1, 3.2, 4), + [ + r" comparison failed. Mismatched elements: 1 / 3:", + rf" Max absolute difference: {SOME_FLOAT}", + rf" Max relative difference: {SOME_FLOAT}", + r" Index \| Obtained\s+\| Expected ", + rf" 1 \| {SOME_FLOAT} \| {SOME_FLOAT} ± {SOME_FLOAT}", + ], + ) + # Specific test for comparison with 0.0 (relative diff will be 'inf') assert_approx_raises_regex( [0.0], @@ -878,3 +892,31 @@ def test_allow_ordered_sequences_only(self) -> None: """pytest.approx() should raise an error on unordered sequences (#9692).""" with pytest.raises(TypeError, match="only supports ordered sequences"): assert {1, 2, 3} == approx({1, 2, 3}) + + +class TestRecursiveSequenceMap: + def test_map_over_scalar(self): + assert _recursive_sequence_map(sqrt, 16) == 4 + + def test_map_over_empty_list(self): + assert _recursive_sequence_map(sqrt, []) == [] + + def test_map_over_list(self): + assert _recursive_sequence_map(sqrt, [4, 16, 25, 676]) == [2, 4, 5, 26] + + def test_map_over_tuple(self): + assert _recursive_sequence_map(sqrt, (4, 16, 25, 676)) == (2, 4, 5, 26) + + def test_map_over_nested_lists(self): + assert _recursive_sequence_map(sqrt, [4, [25, 64], [[49]]]) == [ + 2, + [5, 8], + [[7]], + ] + + def test_map_over_mixed_sequence(self): + assert _recursive_sequence_map(sqrt, [4, (25, 64), [(49)]]) == [ + 2, + (5, 8), + [(7)], + ]