diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..b8c410915d --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,10 @@ +RELEASE_TYPE: patch + +This release automatically rewrites some simple filters, such as +``floats().filter(lambda x: x >= 10)`` to the more efficient +``floats(min_value=10)``, based on the AST of the predicate. + +We continue to recommend using the efficient form directly wherever +possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``" +where you already have a simple predicate and translating manually +is really annoying. See :issue:`2701` for details. diff --git a/hypothesis-python/src/hypothesis/internal/filtering.py b/hypothesis-python/src/hypothesis/internal/filtering.py index a50bdf1b12..ff58897006 100644 --- a/hypothesis-python/src/hypothesis/internal/filtering.py +++ b/hypothesis-python/src/hypothesis/internal/filtering.py @@ -32,6 +32,7 @@ from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar from hypothesis.internal.compat import ceil, floor +from hypothesis.internal.floats import next_down, next_up from hypothesis.internal.reflection import extract_lambda_source Ex = TypeVar("Ex") @@ -274,3 +275,26 @@ def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}} return ConstructivePredicate(kwargs, predicate) + + +def get_float_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: + kwargs, predicate = get_numeric_predicate_bounds(predicate) # type: ignore + + if "min_value" in kwargs: + min_value = kwargs["min_value"] + kwargs["min_value"] = float(kwargs["min_value"]) + if min_value < kwargs["min_value"] or ( + min_value == kwargs["min_value"] and kwargs.get("exclude_min", False) + ): + kwargs["min_value"] = next_up(kwargs["min_value"]) + + if "max_value" in kwargs: + max_value = kwargs["max_value"] + kwargs["max_value"] = float(kwargs["max_value"]) + if max_value > kwargs["max_value"] or ( + max_value == kwargs["max_value"] and kwargs.get("exclude_max", False) + ): + kwargs["max_value"] = next_down(kwargs["max_value"]) + + kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}} + return ConstructivePredicate(kwargs, predicate) diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py b/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py index 0fe6fb2aa6..08cef5ceba 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py @@ -18,7 +18,10 @@ from hypothesis.errors import InvalidArgument from hypothesis.internal.conjecture import floats as flt, utils as d from hypothesis.internal.conjecture.utils import calc_label_from_name -from hypothesis.internal.filtering import get_integer_predicate_bounds +from hypothesis.internal.filtering import ( + get_float_predicate_bounds, + get_integer_predicate_bounds, +) from hypothesis.internal.floats import ( float_of, int_to_float, @@ -195,9 +198,10 @@ class FloatStrategy(SearchStrategy): def __init__( self, - min_value: float = -math.inf, - max_value: float = math.inf, - allow_nan: bool = True, + *, + min_value: float, + max_value: float, + allow_nan: bool, # The smallest nonzero number we can represent is usually a subnormal, but may # be the smallest normal if we're running in unsafe denormals-are-zero mode. # While that's usually an explicit error, we do need to handle the case where @@ -292,6 +296,40 @@ def do_draw(self, data): data.stop_example() # (FLOAT_STRATEGY_DO_DRAW_LABEL) return result + def filter(self, condition): + kwargs, pred = get_float_predicate_bounds(condition) + if not kwargs: + return super().filter(pred) + min_bound = max(kwargs.get("min_value", -math.inf), self.min_value) + max_bound = min(kwargs.get("max_value", math.inf), self.max_value) + + # Adjustments for allow_subnormal=False, if any need to be made + if -self.smallest_nonzero_magnitude < min_bound < 0: + min_bound = -0.0 + elif 0 < min_bound < self.smallest_nonzero_magnitude: + min_bound = self.smallest_nonzero_magnitude + if -self.smallest_nonzero_magnitude < max_bound < 0: + max_bound = -self.smallest_nonzero_magnitude + elif 0 < max_bound < self.smallest_nonzero_magnitude: + max_bound = 0.0 + + if min_bound > max_bound: + return nothing() + if ( + min_bound > self.min_value + or self.max_value > max_bound + or (self.allow_nan and (-math.inf < min_bound or max_bound < math.inf)) + ): + self = type(self)( + min_value=min_bound, + max_value=max_bound, + allow_nan=False, + smallest_nonzero_magnitude=self.smallest_nonzero_magnitude, + ) + if pred is None: + return self + return super().filter(pred) + @cacheable @defines_strategy(force_reusable_values=True) @@ -508,14 +546,17 @@ def floats( min_value = float("-inf") if max_value is None: max_value = float("inf") + if not allow_infinity: + min_value = max(min_value, next_up(float("-inf"))) + max_value = min(max_value, next_down(float("inf"))) assert isinstance(min_value, float) assert isinstance(max_value, float) smallest_nonzero_magnitude = ( SMALLEST_SUBNORMAL if allow_subnormal else smallest_normal ) result: SearchStrategy = FloatStrategy( - min_value, - max_value, + min_value=min_value, + max_value=max_value, allow_nan=allow_nan, smallest_nonzero_magnitude=smallest_nonzero_magnitude, ) @@ -529,6 +570,4 @@ def downcast(x): reject() result = result.map(downcast) - if not allow_infinity: - result = result.filter(lambda x: not math.isinf(x)) return result diff --git a/hypothesis-python/tests/cover/test_filter_rewriting.py b/hypothesis-python/tests/cover/test_filter_rewriting.py index dc4dd62a95..b9d3af8203 100644 --- a/hypothesis-python/tests/cover/test_filter_rewriting.py +++ b/hypothesis-python/tests/cover/test_filter_rewriting.py @@ -11,15 +11,18 @@ import decimal import math import operator +from fractions import Fraction from functools import partial +from sys import float_info import pytest from hypothesis import given, strategies as st from hypothesis.errors import HypothesisWarning, Unsatisfiable +from hypothesis.internal.floats import next_down, next_up from hypothesis.internal.reflection import get_pretty_function_description from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies -from hypothesis.strategies._internal.numbers import IntegersStrategy +from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy from hypothesis.strategies._internal.strategies import FilteredStrategy from tests.common.utils import fails_with @@ -87,20 +90,81 @@ def test_filter_rewriting(data, strategy, predicate, start, end): @pytest.mark.parametrize( - "s", + "strategy, predicate, min_value, max_value", [ - st.integers(1, 5).filter(partial(operator.lt, 6)), - st.integers(1, 5).filter(partial(operator.eq, 3.5)), - st.integers(1, 5).filter(partial(operator.eq, "can't compare to strings")), - st.integers(1, 5).filter(partial(operator.ge, 0)), - st.integers(1, 5).filter(partial(operator.lt, math.inf)), - st.integers(1, 5).filter(partial(operator.gt, -math.inf)), + # Floats with integer bounds + (st.floats(1, 5), partial(operator.lt, 3), next_up(3.0), 5), # 3 < x + (st.floats(1, 5), partial(operator.le, 3), 3, 5), # lambda x: 3 <= x + (st.floats(1, 5), partial(operator.eq, 3), 3, 3), # lambda x: 3 == x + (st.floats(1, 5), partial(operator.ge, 3), 1, 3), # lambda x: 3 >= x + (st.floats(1, 5), partial(operator.gt, 3), 1, next_down(3.0)), # 3 > x + # Floats with non-integer bounds + (st.floats(1, 5), partial(operator.lt, 3.5), next_up(3.5), 5), + (st.floats(1, 5), partial(operator.le, 3.5), 3.5, 5), + (st.floats(1, 5), partial(operator.ge, 3.5), 1, 3.5), + (st.floats(1, 5), partial(operator.gt, 3.5), 1, next_down(3.5)), + (st.floats(1, 5), partial(operator.lt, -math.inf), 1, 5), + (st.floats(1, 5), partial(operator.gt, math.inf), 1, 5), + # Floats with only one bound + (st.floats(min_value=1), partial(operator.lt, 3), next_up(3.0), math.inf), + (st.floats(min_value=1), partial(operator.le, 3), 3, math.inf), + (st.floats(max_value=5), partial(operator.ge, 3), -math.inf, 3), + (st.floats(max_value=5), partial(operator.gt, 3), -math.inf, next_down(3.0)), + # Unbounded floats + (st.floats(), partial(operator.lt, 3), next_up(3.0), math.inf), + (st.floats(), partial(operator.le, 3), 3, math.inf), + (st.floats(), partial(operator.eq, 3), 3, 3), + (st.floats(), partial(operator.ge, 3), -math.inf, 3), + (st.floats(), partial(operator.gt, 3), -math.inf, next_down(3.0)), + # Simple lambdas + (st.floats(), lambda x: x < 3, -math.inf, next_down(3.0)), + (st.floats(), lambda x: x <= 3, -math.inf, 3), + (st.floats(), lambda x: x == 3, 3, 3), + (st.floats(), lambda x: x >= 3, 3, math.inf), + (st.floats(), lambda x: x > 3, next_up(3.0), math.inf), + # Simple lambdas, reverse comparison + (st.floats(), lambda x: 3 > x, -math.inf, next_down(3.0)), + (st.floats(), lambda x: 3 >= x, -math.inf, 3), + (st.floats(), lambda x: 3 == x, 3, 3), + (st.floats(), lambda x: 3 <= x, 3, math.inf), + (st.floats(), lambda x: 3 < x, next_up(3.0), math.inf), + # More complicated lambdas + (st.floats(), lambda x: 0 < x < 5, next_up(0.0), next_down(5.0)), + (st.floats(), lambda x: 0 < x >= 1, 1, math.inf), + (st.floats(), lambda x: 1 > x <= 0, -math.inf, 0), + (st.floats(), lambda x: x > 0 and x > 0, next_up(0.0), math.inf), + (st.floats(), lambda x: x < 1 and x < 1, -math.inf, next_down(1.0)), + (st.floats(), lambda x: x > 1 and x > 0, next_up(1.0), math.inf), + (st.floats(), lambda x: x < 1 and x < 2, -math.inf, next_down(1.0)), ], + ids=get_pretty_function_description, ) -@fails_with(Unsatisfiable) @given(data=st.data()) -def test_rewrite_unsatisfiable_filter(data, s): - data.draw(s) +def test_filter_rewriting_floats(data, strategy, predicate, min_value, max_value): + s = strategy.filter(predicate) + assert isinstance(s, LazyStrategy) + assert isinstance(s.wrapped_strategy, FloatStrategy) + assert s.wrapped_strategy.min_value == min_value + assert s.wrapped_strategy.max_value == max_value + value = data.draw(s) + assert predicate(value) + + +@pytest.mark.parametrize( + "pred", + [ + partial(operator.lt, 6), + partial(operator.eq, Fraction(10, 3)), + partial(operator.eq, "can't compare to strings"), + partial(operator.ge, 0), + partial(operator.lt, math.inf), + partial(operator.gt, -math.inf), + ], +) +@pytest.mark.parametrize("s", [st.integers(1, 5), st.floats(1, 5)]) +@fails_with(Unsatisfiable) +def test_rewrite_unsatisfiable_filter(s, pred): + s.filter(pred).example() @given(st.integers(0, 2).filter(partial(operator.ne, 1))) @@ -115,14 +179,8 @@ def test_rewriting_does_not_compare_decimal_snan(): s.example() -@pytest.mark.parametrize( - "strategy, lo, hi", - [ - (st.integers(0, 1), -1, 2), - ], - ids=repr, -) -def test_applying_noop_filter_returns_self(strategy, lo, hi): +@pytest.mark.parametrize("strategy", [st.integers(0, 1), st.floats(0, 1)], ids=repr) +def test_applying_noop_filter_returns_self(strategy): s = strategy.wrapped_strategy s2 = s.filter(partial(operator.le, -1)).filter(partial(operator.ge, 2)) assert s is s2 @@ -135,6 +193,7 @@ def mod2(x): Y = 2**20 +@pytest.mark.parametrize("s", [st.integers(1, 5), st.floats(1, 5)]) @given( data=st.data(), predicates=st.permutations( @@ -149,9 +208,8 @@ def mod2(x): ] ), ) -def test_rewrite_filter_chains_with_some_unhandled(data, predicates): +def test_rewrite_filter_chains_with_some_unhandled(data, predicates, s): # Set up our strategy - s = st.integers(1, 5) for p in predicates: s = s.filter(p) @@ -163,7 +221,7 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates): # No matter the order of the filters, we get the same resulting structure unwrapped = s.wrapped_strategy assert isinstance(unwrapped, FilteredStrategy) - assert isinstance(unwrapped.filtered_strategy, IntegersStrategy) + assert isinstance(unwrapped.filtered_strategy, (IntegersStrategy, FloatStrategy)) for pred in unwrapped.flat_conditions: assert pred is mod2 or pred.__name__ == "" @@ -246,3 +304,17 @@ def test_bumps_min_size_and_filters_for_content_str_methods(method): fs = s.filter(method) assert fs.filtered_strategy.min_size == 1 assert fs.flat_conditions == (method,) + + +@pytest.mark.parametrize( + "op, attr, value, expected", + [ + (operator.lt, "min_value", -float_info.min / 2, 0), + (operator.lt, "min_value", float_info.min / 2, float_info.min), + (operator.gt, "max_value", float_info.min / 2, 0), + (operator.gt, "max_value", -float_info.min / 2, -float_info.min), + ], +) +def test_filter_floats_can_skip_subnormals(op, attr, value, expected): + base = st.floats(allow_subnormal=False).filter(partial(op, value)) + assert getattr(base.wrapped_strategy, attr) == expected