Skip to content

Commit

Permalink
Fix representation of tuples in approx
Browse files Browse the repository at this point in the history
Closes #9917
  • Loading branch information
ZachOBrien committed Jun 14, 2022
1 parent bb94e83 commit 96412d1
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -361,5 +361,6 @@ Yoav Caspi
Yuval Shimon
Zac Hatfield-Dodds
Zachary Kneupper
Zachary OBrien
Zoltán Máté
Zsolt Cserna
1 change: 1 addition & 0 deletions changelog/9917.bugfix.rst
@@ -0,0 +1 @@
Fixed string representation for :func:`pytest.approx` when used to compare tuples.
20 changes: 12 additions & 8 deletions src/_pytest/python_api.py
Expand Up @@ -133,9 +133,11 @@ def _check_type(self) -> None:
# raise if there are any non-numeric elements in the sequence.


def _recursive_list_map(f, x):
if isinstance(x, list):
return [_recursive_list_map(f, xi) for xi in x]
def _recursive_sequence_map(f, x):
"""Recursively map a function over a sequence of arbitary depth"""
if isinstance(x, (list, tuple)):
seq_type = type(x)
return seq_type(_recursive_sequence_map(f, xi) for xi in x)
else:
return f(x)

Expand All @@ -144,7 +146,9 @@ class ApproxNumpy(ApproxBase):
"""Perform approximate comparisons where the expected value is numpy array."""

def __repr__(self) -> str:
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
list_scalars = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
return f"approx({list_scalars!r})"

def _repr_compare(self, other_side: "ndarray") -> List[str]:
Expand All @@ -164,7 +168,7 @@ def get_value_from_nested_list(
return value

np_array_shape = self.expected.shape
approx_side_as_list = _recursive_list_map(
approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)

Expand All @@ -179,7 +183,7 @@ def get_value_from_nested_list(
max_rel_diff = -math.inf
different_ids = []
for index in itertools.product(*(range(i) for i in np_array_shape)):
approx_value = get_value_from_nested_list(approx_side_as_list, index)
approx_value = get_value_from_nested_list(approx_side_as_seq, index)
other_value = get_value_from_nested_list(other_side, index)
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
Expand All @@ -194,7 +198,7 @@ def get_value_from_nested_list(
(
str(index),
str(get_value_from_nested_list(other_side, index)),
str(get_value_from_nested_list(approx_side_as_list, index)),
str(get_value_from_nested_list(approx_side_as_seq, index)),
)
for index in different_ids
]
Expand Down Expand Up @@ -326,7 +330,7 @@ def _repr_compare(self, other_side: Sequence[float]) -> List[str]:
f"Lengths: {len(self.expected)} and {len(other_side)}",
]

approx_side_as_map = _recursive_list_map(self._approx_scalar, self.expected)
approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected)

number_of_elements = len(approx_side_as_map)
max_abs_diff = -math.inf
Expand Down
42 changes: 42 additions & 0 deletions testing/python/approx.py
Expand Up @@ -2,12 +2,14 @@
from contextlib import contextmanager
from decimal import Decimal
from fractions import Fraction
from math import sqrt
from operator import eq
from operator import ne
from typing import Optional

import pytest
from _pytest.pytester import Pytester
from _pytest.python_api import _recursive_sequence_map
from pytest import approx

inf, nan = float("inf"), float("nan")
Expand Down Expand Up @@ -133,6 +135,18 @@ def test_error_messages_native_dtypes(self, assert_approx_raises_regex):
],
)

assert_approx_raises_regex(
(1, 2.2, 4),
(1, 3.2, 4),
[
r" comparison failed. Mismatched elements: 1 / 3:",
rf" Max absolute difference: {SOME_FLOAT}",
rf" Max relative difference: {SOME_FLOAT}",
r" Index \| Obtained\s+\| Expected ",
rf" 1 \| {SOME_FLOAT} \| {SOME_FLOAT} ± {SOME_FLOAT}",
],
)

# Specific test for comparison with 0.0 (relative diff will be 'inf')
assert_approx_raises_regex(
[0.0],
Expand Down Expand Up @@ -878,3 +892,31 @@ def test_allow_ordered_sequences_only(self) -> None:
"""pytest.approx() should raise an error on unordered sequences (#9692)."""
with pytest.raises(TypeError, match="only supports ordered sequences"):
assert {1, 2, 3} == approx({1, 2, 3})


class TestRecursiveSequenceMap:
def test_map_over_scalar(self):
assert _recursive_sequence_map(sqrt, 16) == 4

def test_map_over_empty_list(self):
assert _recursive_sequence_map(sqrt, []) == []

def test_map_over_list(self):
assert _recursive_sequence_map(sqrt, [4, 16, 25, 676]) == [2, 4, 5, 26]

def test_map_over_tuple(self):
assert _recursive_sequence_map(sqrt, (4, 16, 25, 676)) == (2, 4, 5, 26)

def test_map_over_nested_lists(self):
assert _recursive_sequence_map(sqrt, [4, [25, 64], [[49]]]) == [
2,
[5, 8],
[[7]],
]

def test_map_over_mixed_sequence(self):
assert _recursive_sequence_map(sqrt, [4, (25, 64), [(49)]]) == [
2,
(5, 8),
[(7)],
]

0 comments on commit 96412d1

Please sign in to comment.