diff --git a/changelog/8132.bugfix.rst b/changelog/8132.bugfix.rst new file mode 100644 index 00000000000..811e14302d0 --- /dev/null +++ b/changelog/8132.bugfix.rst @@ -0,0 +1,6 @@ +Fixed regression in ``approx``: in 6.2.0 ``approx`` no longer raises +``TypeError`` when dealing with non-numeric types, falling back to normal comparison, +however the check was done using ``isinstance`` which left out types which implemented +the necessary methods for ``approx`` to work, such as tensorflow's ``DeviceArray``. + +The code has been changed to check for the necessary methods to accommodate those cases. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index bae2076892b..8daf3bf1445 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -241,12 +241,17 @@ def __eq__(self, actual) -> bool: if actual == self.expected: return True - # If either type is non-numeric, fall back to strict equality. - # NB: we need Complex, rather than just Number, to ensure that __abs__, - # __sub__, and __float__ are defined. + # Check types are non-numeric using duck-typing; if they are not numeric types, + # we consider them unequal because the short-circuit above failed. + required_attrs = [ + "__abs__", + "__float__", + "__rsub__", + "__sub__", + ] if not ( - isinstance(self.expected, (Complex, Decimal)) - and isinstance(actual, (Complex, Decimal)) + all(hasattr(self.expected, attr) for attr in required_attrs) + and all(hasattr(actual, attr) for attr in required_attrs) ): return False diff --git a/testing/python/approx.py b/testing/python/approx.py index 91c1f3f85de..2a1ce11e812 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -6,6 +6,8 @@ from operator import ne from typing import Optional +import attr + import pytest from _pytest.pytester import Pytester from pytest import approx @@ -582,3 +584,37 @@ def __len__(self): expected = MySizedIterable() assert [1, 2, 3, 4] == approx(expected) + + def test_duck_typing(self): + """ + Check that approx() works for objects which implemented the required + numeric methods (#8132). + """ + + @attr.s(auto_attribs=True) + class Container: + value: float + + def __abs__(self) -> float: + return abs(self.value) + + def __sub__(self, other): + if isinstance(other, Container): + return Container(self.value - other.value) + elif isinstance(other, (float, int)): + return self.value - other + return NotImplemented + + def __rsub__(self, other): + if isinstance(other, Container): + return other.value - self.value + elif isinstance(other, (float, int)): + return other - self.value + return NotImplemented + + def __float__(self) -> float: + return self.value + + assert Container(1.0) == approx(1 + 1e-7, rel=5e-7) + assert Container(1.0) != approx(1 + 1e-7, rel=1e-8) + assert Container(1.0) == approx(1.0)