diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..0eef32929a --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,11 @@ +RELEASE_TYPE: minor + +This release automatically rewrites some simple filters, such as +``integers().filter(lambda x: x > 9)`` to the more efficient +``integers(min_value=10)``, based on the AST of the predicate. + +We continue to recommend using the efficient form directly wherever +possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``" +where you already have a simple predicate and translating manually +is really annoying. See :issue:`2701` for ideas about floats and +simple text strategies. diff --git a/hypothesis-python/src/hypothesis/internal/filtering.py b/hypothesis-python/src/hypothesis/internal/filtering.py index e2f8b02107..01520a7622 100644 --- a/hypothesis-python/src/hypothesis/internal/filtering.py +++ b/hypothesis-python/src/hypothesis/internal/filtering.py @@ -27,6 +27,8 @@ See https://github.com/HypothesisWorks/hypothesis/issues/2701 for details. """ +import ast +import inspect import math import operator from decimal import Decimal @@ -35,6 +37,7 @@ from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar from hypothesis.internal.compat import ceil, floor +from hypothesis.internal.reflection import extract_lambda_source Ex = TypeVar("Ex") Predicate = Callable[[Ex], bool] @@ -49,7 +52,7 @@ class ConstructivePredicate(NamedTuple): -> {"min_value": 0"}, None integers().filter(lambda x: x >= 0 and x % 7) - -> {"min_value": 0"}, lambda x: x % 7 + -> {"min_value": 0}, lambda x: x % 7 At least in principle - for now we usually return the predicate unchanged if needed. 
# Sentinel returned by `convert` for "the predicate's own argument".
ARG = object()


def convert(node: ast.AST, argname: str) -> object:
    """Translate one comparison operand into ``ARG`` or a Python literal.

    Returns ``ARG`` when ``node`` is a bare reference to ``argname``, and
    otherwise the value produced by ``ast.literal_eval(node)``.

    Raises ``ValueError`` if the operand is any other name, or cannot be
    evaluated as a literal.
    """
    if isinstance(node, ast.Name):
        if node.id != argname:
            raise ValueError("Non-local variable")
        return ARG
    try:
        return ast.literal_eval(node)
    except TypeError as err:
        # literal_eval raises TypeError rather than ValueError for some operands
        # (e.g. a set or dict display with an unhashable element).  Callers in
        # this module only handle ValueError, so normalise to that.
        raise ValueError("Can't evaluate this operand as a literal") from err


def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict:
    """Convert the single comparison ``x op y`` into bounds keyword arguments.

    Exactly one operand must be the argument named ``argname`` and the other an
    int or float literal.  The result uses the keys ``min_value``, ``max_value``,
    ``exclude_min`` and ``exclude_max``.

    Raises ``ValueError`` if the operands or the operator can't be analysed.
    """
    a = convert(x, argname)
    b = convert(y, argname)
    num = (int, float)
    if not (a is ARG and isinstance(b, num)) and not (isinstance(a, num) and b is ARG):
        # It would be possible to work out if comparisons between two literals
        # are always true or false, but it's too rare to be worth the complexity.
        # (and we can't even do `arg == arg`, because what if it's NaN?)
        raise ValueError("Can't analyse this comparison")

    # In each branch, `a is ARG` means the comparison reads `arg OP literal`;
    # otherwise it reads `literal OP arg` and the bound is mirrored.
    if isinstance(op, ast.Lt):
        if a is ARG:
            return {"max_value": b, "exclude_max": True}
        return {"min_value": a, "exclude_min": True}
    elif isinstance(op, ast.LtE):
        if a is ARG:
            return {"max_value": b}
        return {"min_value": a}
    elif isinstance(op, ast.Eq):
        if a is ARG:
            return {"min_value": b, "max_value": b}
        return {"min_value": a, "max_value": a}
    elif isinstance(op, ast.GtE):
        if a is ARG:
            return {"min_value": b}
        return {"max_value": a}
    elif isinstance(op, ast.Gt):
        if a is ARG:
            return {"min_value": b, "exclude_min": True}
        return {"max_value": a, "exclude_max": True}
    raise ValueError("Unhandled comparison operator")  # e.g. ast.Ne


def merge_preds(*con_predicates: "ConstructivePredicate") -> "ConstructivePredicate":
    """Intersect the bounds of several ``(kwargs, predicate)`` pairs.

    The tightest lower and upper bounds win; when two bounds are equal the
    result is exclusive if either of them was.  The returned predicate is the
    last non-None predicate among the inputs, or None if there was none.
    """
    # This function is just kinda messy.  Unfortunately the neatest way
    # to do this is just to roll out each case and handle them in turn.
    base = {
        "min_value": -math.inf,
        "max_value": math.inf,
        "exclude_min": False,
        "exclude_max": False,
    }
    predicate = None
    for kw, p in con_predicates:
        predicate = p or predicate
        if "min_value" in kw:
            if kw["min_value"] > base["min_value"]:
                base["exclude_min"] = kw.get("exclude_min", False)
                base["min_value"] = kw["min_value"]
            elif kw["min_value"] == base["min_value"]:
                base["exclude_min"] |= kw.get("exclude_min", False)
        if "max_value" in kw:
            if kw["max_value"] < base["max_value"]:
                base["exclude_max"] = kw.get("exclude_max", False)
                base["max_value"] = kw["max_value"]
            elif kw["max_value"] == base["max_value"]:
                base["exclude_max"] |= kw.get("exclude_max", False)

    # Drop the placeholder defaults.  An infinite bound is only removed when it
    # is inclusive: an exclusive one (e.g. from `x < 1e999`) is kept together
    # with its flag, so the result never holds an `exclude_*` key without the
    # matching `*_value`.
    if not base["exclude_min"]:
        del base["exclude_min"]
        if base["min_value"] == -math.inf:
            del base["min_value"]
    if not base["exclude_max"]:
        del base["exclude_max"]
        if base["max_value"] == math.inf:
            del base["max_value"]
    return ConstructivePredicate(base, predicate)


def numeric_bounds_from_ast(
    tree: ast.AST, argname: str, fallback: "ConstructivePredicate"
) -> "ConstructivePredicate":
    """Take an AST; return a ConstructivePredicate.

    Examples, written as ``predicate -> kwargs, remaining predicate``:

        lambda x: x >= 0  ->  {"min_value": 0}, None
        lambda x: x < 10  ->  {"max_value": 10, "exclude_max": True}, None
        lambda x: x >= y  ->  {}, lambda x: x >= y

    ``fallback`` is returned unchanged for any tree which is neither a
    comparison nor an ``and`` of sub-expressions, and is merged in for each
    individual comparison that can't be analysed.

    See also https://greentreesnakes.readthedocs.io/en/latest/
    """
    if isinstance(tree, ast.Compare):
        # A chain such as `a < b <= c` means `a < b and b <= c`, so pair each
        # operator with the operands on either side of it.
        operands = [tree.left, *tree.comparators]
        bounds = []
        for left, op, right in zip(operands, tree.ops, operands[1:]):
            try:
                kwargs = comp_to_kwargs(left, op, right, argname=argname)
            except ValueError:
                bounds.append(fallback)
            else:
                bounds.append(ConstructivePredicate(kwargs, None))
        return merge_preds(*bounds)

    if isinstance(tree, ast.BoolOp) and isinstance(tree.op, ast.And):
        return merge_preds(
            *[numeric_bounds_from_ast(node, argname, fallback) for node in tree.values]
        )

    return fallback
operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x @@ -99,9 +220,38 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: if predicate.func in options: return ConstructivePredicate(options[predicate.func], None) - # TODO: handle lambdas by AST analysis + # This section is a little complicated, but stepping through with comments should + # help to clarify it. We start by finding the source code for our predicate and + # parsing it to an abstract syntax tree; if this fails for any reason we bail out + # and fall back to standard rejection sampling (a running theme). + try: + if predicate.__name__ == "<lambda>": + source = extract_lambda_source(predicate) + else: + source = inspect.getsource(predicate) + tree: ast.AST = ast.parse(source) + except Exception: + return unchanged + + # Dig down to the relevant subtree - our tree is probably a Module containing + # either a FunctionDef, or an Expr which in turn contains a lambda definition. + while isinstance(tree, ast.Module) and len(tree.body) == 1: + tree = tree.body[0] + while isinstance(tree, ast.Expr): + tree = tree.value - return ConstructivePredicate.unchanged(predicate) + if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1: + return numeric_bounds_from_ast(tree.body, tree.args.args[0].arg, unchanged) + elif isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1: + if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return): + # If the body of the function is anything but `return <expr>`, + # i.e. as simple as a lambda, we can't process it (yet). 
+ return unchanged + argname = tree.args.args[0].arg + body = tree.body[0].value + assert isinstance(body, ast.AST) + return numeric_bounds_from_ast(body, argname, unchanged) + return unchanged def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: diff --git a/hypothesis-python/tests/cover/test_filter_rewriting.py b/hypothesis-python/tests/cover/test_filter_rewriting.py index 4183037225..4ca67a55ed 100644 --- a/hypothesis-python/tests/cover/test_filter_rewriting.py +++ b/hypothesis-python/tests/cover/test_filter_rewriting.py @@ -22,6 +22,7 @@ from hypothesis import given, strategies as st from hypothesis.errors import Unsatisfiable +from hypothesis.internal.reflection import get_pretty_function_description from hypothesis.strategies._internal.lazy import LazyStrategy from hypothesis.strategies._internal.numbers import IntegersStrategy from hypothesis.strategies._internal.strategies import FilteredStrategy @@ -56,7 +57,28 @@ (st.integers(), partial(operator.eq, 3), 3, 3), (st.integers(), partial(operator.ge, 3), None, 3), (st.integers(), partial(operator.gt, 3), None, 2), + # Simple lambdas + (st.integers(), lambda x: x < 3, None, 2), + (st.integers(), lambda x: x <= 3, None, 3), + (st.integers(), lambda x: x == 3, 3, 3), + (st.integers(), lambda x: x >= 3, 3, None), + (st.integers(), lambda x: x > 3, 4, None), + # Simple lambdas, reverse comparison + (st.integers(), lambda x: 3 > x, None, 2), + (st.integers(), lambda x: 3 >= x, None, 3), + (st.integers(), lambda x: 3 == x, 3, 3), + (st.integers(), lambda x: 3 <= x, 3, None), + (st.integers(), lambda x: 3 < x, 4, None), + # More complicated lambdas + (st.integers(), lambda x: 0 < x < 5, 1, 4), + (st.integers(), lambda x: 0 < x >= 1, 1, None), + (st.integers(), lambda x: 1 > x <= 0, None, 0), + (st.integers(), lambda x: x > 0 and x > 0, 1, None), + (st.integers(), lambda x: x < 1 and x < 1, None, 0), + (st.integers(), lambda x: x > 1 and x > 0, 2, None), + (st.integers(), lambda x: x < 1 
and x < 2, None, 0), ], + ids=get_pretty_function_description, ) @given(data=st.data()) def test_filter_rewriting(data, strategy, predicate, start, end): @@ -115,6 +137,9 @@ def mod2(x): return x % 2 +Y = 2 ** 20 + + @given( data=st.data(), predicates=st.permutations( @@ -124,6 +149,8 @@ def mod2(x): partial(operator.ge, 4), partial(operator.gt, 5), mod2, + lambda x: x > 2 or x % 7, + lambda x: 0 < x <= Y, ] ), ) @@ -142,4 +169,42 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates): unwrapped = s.wrapped_strategy assert isinstance(unwrapped, FilteredStrategy) assert isinstance(unwrapped.filtered_strategy, IntegersStrategy) - assert unwrapped.flat_conditions == (mod2,) + for pred in unwrapped.flat_conditions: + assert pred is mod2 or pred.__name__ == "<lambda>" + + +class NotAFunction: + def __call__(self, bar): + return True + + +lambda_without_source = eval("lambda x: x > 2", {}, {}) + + +@pytest.mark.parametrize( + "start, end, predicate", + [ + (1, 4, lambda x: 0 < x < 5 and x % 7), + (0, 9, lambda x: 0 <= x < 10 and x % 3), + (1, None, lambda x: 0 < x <= Y), + (None, None, lambda x: x == x), + (None, None, lambda x: 1 == 1), + (None, None, lambda x: 1 <= 2), + (None, None, lambda x: x != 0), + (None, None, NotAFunction()), + (None, None, lambda_without_source), + (None, None, lambda x, y=2: x >= 0), + ], +) +@given(data=st.data()) +def test_rewriting_partially_understood_filters(data, start, end, predicate): + s = st.integers().filter(predicate).wrapped_strategy + + assert isinstance(s, FilteredStrategy) + assert isinstance(s.filtered_strategy, IntegersStrategy) + assert s.filtered_strategy.start == start + assert s.filtered_strategy.end == end + assert s.flat_conditions == (predicate,) + + value = data.draw(s) + assert predicate(value) diff --git a/hypothesis-python/tests/cover/test_filtered_strategy.py b/hypothesis-python/tests/cover/test_filtered_strategy.py index 97ea7bbf9d..428b2a2086 100644 --- 
a/hypothesis-python/tests/cover/test_filtered_strategy.py +++ b/hypothesis-python/tests/cover/test_filtered_strategy.py @@ -19,7 +19,8 @@ def test_filter_iterations_are_marked_as_discarded(): - x = st.integers(0, 255).filter(lambda x: x == 0) + variable_equal_to_zero = 0 # non-local references disable filter-rewriting + x = st.integers(0, 255).filter(lambda x: x == variable_equal_to_zero) data = ConjectureData.for_buffer([2, 1, 0])