diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..a9827e33b3 --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,11 @@ +RELEASE_TYPE: patch + +This release lays the groundwork for automatic rewriting of simple filters, +for example converting ``integers().filter(lambda x: x > 9)`` to +``integers(min_value=10)``. + +Note that this is **not supported yet**, and we will continue to recommend +writing the efficient form directly wherever possible - predicate rewriting +is provided mainly for the benefit of downstream libraries which would +otherwise have to implement it for themselves (e.g. :pypi:`pandera` and +:pypi:`icontract-hypothesis`). See :issue:`2701` for details. diff --git a/hypothesis-python/src/hypothesis/internal/filtering.py b/hypothesis-python/src/hypothesis/internal/filtering.py new file mode 100644 index 0000000000..afe3d7bde1 --- /dev/null +++ b/hypothesis-python/src/hypothesis/internal/filtering.py @@ -0,0 +1,108 @@ +# This file is part of Hypothesis, which may be found at +# https://github.com/HypothesisWorks/hypothesis/ +# +# Most of this work is copyright (C) 2013-2021 David R. MacIver +# (david@drmaciver.com), but it contains contributions by others. See +# CONTRIBUTING.rst for a full list of people who may hold copyright, and +# consult the git log if you need to determine who owns an individual +# contribution. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. +# +# END HEADER + +"""Tools for understanding predicates, to satisfy them by construction. + +For example:: + + integers().filter(lamda x: x >= 0) -> integers(min_value=0) + +This is intractable in general, but reasonably easy for simple cases involving +numeric bounds, strings with length or regex constraints, and collection lengths - +and those are precisely the most common cases. When they arise in e.g. Pandas +dataframes, it's also pretty painful to do the constructive version by hand in +a library; so we prefer to share all the implementation effort here. +See https://github.com/HypothesisWorks/hypothesis/issues/2701 for details. +""" + +import operator +from decimal import Decimal +from fractions import Fraction +from functools import partial +from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar + +from hypothesis.internal.compat import ceil, floor + +Ex = TypeVar("Ex") +Predicate = Callable[[Ex], bool] + +ConstructivePredicate = Tuple[Mapping[str, Any], Optional[Predicate]] +"""Return kwargs to the appropriate strategy, and the predicate if needed. + +For example:: + + integers().filter(lambda x: x >= 0) + -> {"min_value": 0"}, None + + integers().filter(lambda x: x >= 0 and x % 7) + -> {"min_value": 0"}, lambda x: x % 7 + +At least in principle - for now we usually return the predicate unchanged +if needed. + +We have a separate get-predicate frontend for each "group" of strategies; e.g. +for each numeric type, for strings, for bytes, for collection sizes, etc. +""" + + +def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: + """Shared logic for understanding numeric bounds. + + We then specialise this in the other functions below, to ensure that e.g. + all the values are representable in the types that we're planning to generate + so that the strategy validation doesn't complain. + """ + if ( + type(predicate) is partial + and len(predicate.args) == 1 + and not predicate.keywords + ): + arg = predicate.args[0] + if (isinstance(arg, Decimal) and Decimal.is_snan(arg)) or not isinstance( + arg, (int, float, Fraction, Decimal) + ): + return {}, predicate + options = { + # We're talking about op(arg, x) - the reverse of our usual intuition! + operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x + operator.le: {"min_value": arg}, # lambda x: arg <= x + operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x + operator.ge: {"max_value": arg}, # lambda x: arg >= x + operator.gt: {"max_value": arg, "exclude_max": True}, # lambda x: arg > x + } + if predicate.func in options: + return options[predicate.func], None + + # TODO: handle lambdas by AST analysis + + return {}, predicate + + +def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: + kwargs, predicate = get_numeric_predicate_bounds(predicate) + + if "min_value" in kwargs: + if kwargs["min_value"] != int(kwargs["min_value"]): + kwargs["min_value"] = ceil(kwargs["min_value"]) + elif kwargs.get("exclude_min", False): + kwargs["min_value"] = int(kwargs["min_value"]) + 1 + if "max_value" in kwargs: + if kwargs["max_value"] != int(kwargs["max_value"]): + kwargs["max_value"] = floor(kwargs["max_value"]) + elif kwargs.get("exclude_max", False): + kwargs["max_value"] = int(kwargs["max_value"]) - 1 + + kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}} + return kwargs, predicate diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/lazy.py b/hypothesis-python/src/hypothesis/strategies/_internal/lazy.py index d828df88ce..45bc75e735 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/lazy.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/lazy.py @@ -20,6 +20,7 @@ arg_string, convert_keyword_arguments, convert_positional_arguments, + get_pretty_function_description, ) from hypothesis.strategies._internal.strategies import SearchStrategy @@ -63,6 +64,10 @@ def unwrap_strategies(s): assert unwrap_depth >= 0 +def _repr_filter(condition): + return f".filter({get_pretty_function_description(condition)})" + + class LazyStrategy(SearchStrategy): """A strategy which is defined purely by conversion to and from another strategy. @@ -70,13 +75,14 @@ class LazyStrategy(SearchStrategy): Its parameter and distribution come from that other strategy. """ - def __init__(self, function, args, kwargs, *, force_repr=None): + def __init__(self, function, args, kwargs, filters=(), *, force_repr=None): SearchStrategy.__init__(self) self.__wrapped_strategy = None self.__representation = force_repr self.function = function self.__args = args self.__kwargs = kwargs + self.__filters = filters @property def supports_find(self): @@ -110,8 +116,19 @@ def wrapped_strategy(self): self.__wrapped_strategy = self.function( *unwrapped_args, **unwrapped_kwargs ) + for f in self.__filters: + self.__wrapped_strategy = self.__wrapped_strategy.filter(f) return self.__wrapped_strategy + def filter(self, condition): + return LazyStrategy( + self.function, + self.__args, + self.__kwargs, + self.__filters + (condition,), + force_repr=f"{self!r}{_repr_filter(condition)}", + ) + def do_validate(self): w = self.wrapped_strategy assert isinstance(w, SearchStrategy), f"{self!r} returned non-strategy {w!r}" @@ -140,9 +157,10 @@ def __repr__(self): for k, v in defaults.items(): if k in kwargs_for_repr and kwargs_for_repr[k] is v: del kwargs_for_repr[k] - self.__representation = "{}({})".format( + self.__representation = "{}({}){}".format( self.function.__name__, arg_string(self.function, _args, kwargs_for_repr, reorder=False), + "".join(map(_repr_filter, self.__filters)), ) return self.__representation diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py b/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py index fe9928af87..a13116ad9b 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/numbers.py @@ -18,6 +18,7 @@ from hypothesis.control import assume, reject from hypothesis.internal.conjecture import floats as flt, utils as d from hypothesis.internal.conjecture.utils import calc_label_from_name +from hypothesis.internal.filtering import get_integer_predicate_bounds from hypothesis.internal.floats import float_of from hypothesis.strategies._internal.strategies import SearchStrategy @@ -51,11 +52,25 @@ def __init__(self, start, end): self.end = end def __repr__(self): - return f"BoundedIntStrategy({self.start}, {self.end})" + return f"integers({self.start}, {self.end})" def do_draw(self, data): return d.integer_range(data, self.start, self.end) + def filter(self, condition): + kwargs, pred = get_integer_predicate_bounds(condition) + start = max(self.start, kwargs.get("min_value", self.start)) + end = min(self.end, kwargs.get("max_value", self.end)) + if start > self.start or end < self.end: + if start > end: + from hypothesis.strategies._internal.core import nothing + + return nothing() + self = type(self)(start, end) + if pred is None: + return self + return super().filter(pred) + NASTY_FLOATS = sorted( [ diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py b/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py index cdbd45f9c2..6d100c6381 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py @@ -717,6 +717,7 @@ def __init__(self, strategy, conditions): assert not isinstance(self.filtered_strategy, FilteredStrategy) self.__condition = None + self.__validated = False def calc_is_empty(self, recur): return recur(self.filtered_strategy) @@ -737,6 +738,26 @@ def __repr__(self): def do_validate(self): self.filtered_strategy.validate() + if not self.__validated: + fresh = self.filtered_strategy + for cond in self.flat_conditions: + fresh = fresh.filter(cond) + if isinstance(fresh, FilteredStrategy): + FilteredStrategy.__init__( + self, fresh.filtered_strategy, fresh.flat_conditions + ) + else: + FilteredStrategy.__init__(self, fresh, (lambda _: True,)) + self.__validated = True + + def filter(self, condition): + # Allow strategy rewriting to 'see through' an unhandled predicate. + out = self.filtered_strategy.filter(condition) + if isinstance(out, FilteredStrategy): + return FilteredStrategy( + out.filtered_strategy, self.flat_conditions + out.flat_conditions + ) + return FilteredStrategy(out, self.flat_conditions) @property def condition(self): diff --git a/hypothesis-python/tests/cover/test_direct_strategies.py b/hypothesis-python/tests/cover/test_direct_strategies.py index ab64b7750d..75cffaeaa1 100644 --- a/hypothesis-python/tests/cover/test_direct_strategies.py +++ b/hypothesis-python/tests/cover/test_direct_strategies.py @@ -484,7 +484,7 @@ def test_chained_filter(x): def test_chained_filter_tracks_all_conditions(): s = ds.integers().filter(bool).filter(lambda x: x % 3) - assert len(s.flat_conditions) == 2 + assert len(s.wrapped_strategy.flat_conditions) == 2 @pytest.mark.parametrize("version", [4, 6]) diff --git a/hypothesis-python/tests/cover/test_filter_rewriting.py b/hypothesis-python/tests/cover/test_filter_rewriting.py new file mode 100644 index 0000000000..bc9215ef31 --- /dev/null +++ b/hypothesis-python/tests/cover/test_filter_rewriting.py @@ -0,0 +1,111 @@ +# This file is part of Hypothesis, which may be found at +# https://github.com/HypothesisWorks/hypothesis/ +# +# Most of this work is copyright (C) 2013-2021 David R. MacIver +# (david@drmaciver.com), but it contains contributions by others. See +# CONTRIBUTING.rst for a full list of people who may hold copyright, and +# consult the git log if you need to determine who owns an individual +# contribution. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. +# +# END HEADER + +import operator +from functools import partial +import decimal + +import pytest + +from hypothesis import given, strategies as st +from hypothesis.errors import Unsatisfiable +from hypothesis.strategies._internal.lazy import LazyStrategy +from hypothesis.strategies._internal.numbers import BoundedIntStrategy +from hypothesis.strategies._internal.strategies import FilteredStrategy + +from tests.common.utils import fails_with + + +@pytest.mark.parametrize( + "strategy, predicate, start, end", + [ + # Integers with integer bounds + (st.integers(1, 5), partial(operator.lt, 3), 4, 5), # lambda x: 3 < x + (st.integers(1, 5), partial(operator.le, 3), 3, 5), # lambda x: 3 <= x + (st.integers(1, 5), partial(operator.eq, 3), 3, 3), # lambda x: 3 == x + (st.integers(1, 5), partial(operator.ge, 3), 1, 3), # lambda x: 3 >= x + (st.integers(1, 5), partial(operator.gt, 3), 1, 2), # lambda x: 3 > x + # Integers with non-integer bounds + (st.integers(1, 5), partial(operator.lt, 3.5), 4, 5), + (st.integers(1, 5), partial(operator.le, 3.5), 4, 5), + (st.integers(1, 5), partial(operator.ge, 3.5), 1, 3), + (st.integers(1, 5), partial(operator.gt, 3.5), 1, 3), + ], +) +@given(data=st.data()) +def test_filter_rewriting(data, strategy, predicate, start, end): + s = strategy.filter(predicate) + assert isinstance(s, LazyStrategy) + assert isinstance(s.wrapped_strategy, BoundedIntStrategy) + assert s.wrapped_strategy.start == start + assert s.wrapped_strategy.end == end + value = data.draw(s) + assert predicate(value) + + +@pytest.mark.parametrize( + "s", + [ + st.integers(1, 5).filter(partial(operator.lt, 6)), + st.integers(1, 5).filter(partial(operator.eq, 3.5)), + st.integers(1, 5).filter(partial(operator.eq, "can't compare to strings")), + st.integers(1, 5).filter(partial(operator.ge, 0)), + ], +) +@fails_with(Unsatisfiable) +@given(data=st.data()) +def test_rewrite_unsatisfiable_filter(data, s): + data.draw(s) + + +def test_rewriting_does_not_compare_decimal_snan(): + s = st.integers(1, 5).filter(partial(operator.eq, decimal.Decimal("snan"))) + s.wrapped_strategy + with pytest.raises(decimal.InvalidOperation): + s.example() + + +def mod2(x): + return x % 2 + + +@given( + data=st.data(), + predicates=st.permutations( + [ + partial(operator.lt, 1), + partial(operator.le, 2), + partial(operator.ge, 4), + partial(operator.gt, 5), + mod2, + ] + ), +) +def test_rewrite_filter_chains_with_some_unhandled(data, predicates): + # Set up our strategy + s = st.integers(1, 5) + for p in predicates: + s = s.filter(p) + + # Whatever value we draw is in fact valid for these strategies + value = data.draw(s) + for p in predicates: + assert p(value), f"p={p!r}, value={value}" + + # No matter the order of the filters, we get the same resulting structure + unwrapped = s.wrapped_strategy + assert isinstance(unwrapped, FilteredStrategy) + assert isinstance(unwrapped.filtered_strategy, BoundedIntStrategy) + assert unwrapped.flat_conditions == (mod2,)