From dcc07dc4de8200a58191fae7042e3ca3392f76d8 Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Tue, 27 Apr 2021 02:03:36 +1000 Subject: [PATCH 1/3] Get lambda filters working --- .../src/hypothesis/internal/filtering.py | 146 +++++++++++++++++- .../tests/cover/test_filtered_strategy.py | 3 +- 2 files changed, 146 insertions(+), 3 deletions(-) diff --git a/hypothesis-python/src/hypothesis/internal/filtering.py b/hypothesis-python/src/hypothesis/internal/filtering.py index e2f8b02107..7e2ca4e149 100644 --- a/hypothesis-python/src/hypothesis/internal/filtering.py +++ b/hypothesis-python/src/hypothesis/internal/filtering.py @@ -27,6 +27,8 @@ See https://github.com/HypothesisWorks/hypothesis/issues/2701 for details. """ +import ast +import inspect import math import operator from decimal import Decimal @@ -35,6 +37,7 @@ from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar from hypothesis.internal.compat import ceil, floor +from hypothesis.internal.reflection import extract_lambda_source Ex = TypeVar("Ex") Predicate = Callable[[Ex], bool] @@ -49,7 +52,7 @@ class ConstructivePredicate(NamedTuple): -> {"min_value": 0"}, None integers().filter(lambda x: x >= 0 and x % 7) - -> {"min_value": 0"}, lambda x: x % 7 + -> {"min_value": 0}, lambda x: x % 7 At least in principle - for now we usually return the predicate unchanged if needed. @@ -66,6 +69,135 @@ def unchanged(cls, predicate): return cls({}, predicate) +ARG = object() + + +def convert(node, argname): + if isinstance(node, ast.Name): + if node.id != argname: + raise ValueError("Non-local variable") + return ARG + return ast.literal_eval(node) + + +def comp_to_kwargs(a, op, b, *, argname=None): + """ """ + if isinstance(a, ast.Name) == isinstance(b, ast.Name): + raise ValueError("Can't analyse this comparison") + a = convert(a, argname) + b = convert(b, argname) + assert (a is ARG) != (b is ARG) + + if isinstance(op, ast.Lt): + if a is ARG: + return {"max_value": b, "exclude_max": True} + return {"min_value": a, "exclude_min": True} + elif isinstance(op, ast.LtE): + if a is ARG: + return {"max_value": b} + return {"min_value": a} + elif isinstance(op, ast.Eq): + if a is ARG: + return {"min_value": b, "max_value": b} + return {"min_value": a, "max_value": a} + elif isinstance(op, ast.GtE): + if a is ARG: + return {"min_value": b} + return {"max_value": a} + elif isinstance(op, ast.Gt): + if a is ARG: + return {"min_value": b, "exclude_min": True} + return {"max_value": a, "exclude_max": True} + raise ValueError("Unhandled comparison operator") + + +def tidy(kwargs): + if not kwargs["exclude_min"]: + del kwargs["exclude_min"] + if kwargs["min_value"] == -math.inf: + del kwargs["min_value"] + if not kwargs["exclude_max"]: + del kwargs["exclude_max"] + if kwargs["max_value"] == math.inf: + del kwargs["max_value"] + return kwargs + + +def merge_kwargs(*rest): + base = { + "min_value": -math.inf, + "max_value": math.inf, + "exclude_min": False, + "exclude_max": False, + } + for kw in rest: + if "min_value" in kw: + if kw["min_value"] > base["min_value"]: + base["exclude_min"] = kw.get("exclude_min", False) + base["min_value"] = kw["min_value"] + elif kw["min_value"] == base["min_value"]: + base["exclude_min"] |= kw.get("exclude_min", False) + else: + base["exclude_min"] = False + if "max_value" in kw: + if kw["max_value"] < base["max_value"]: + base["exclude_max"] = kw.get("exclude_max", False) + base["max_value"] = kw["max_value"] + elif kw["max_value"] == base["max_value"]: + base["exclude_max"] |= kw.get("exclude_max", False) + else: + base["exclude_max"] = False + return tidy(base) + + +def numeric_bounds_from_ast(tree, *, argname=None): + """Take an AST; return a dict of bounds or None. + + >>> lambda x: x >= 0 + {"min_value": 0} + >>> lambda x: x < 10 + {"max_value": 10, "exclude_max": True} + >>> lambda x: x >= y + None + """ + while isinstance(tree, ast.Module) and len(tree.body) == 1: + tree = tree.body[0] + if isinstance(tree, ast.Expr): + tree = tree.value + + if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1: + assert argname is None + return numeric_bounds_from_ast(tree.body, argname=tree.args.args[0].arg) + + if isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1: + assert argname is None + if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return): + return None + return numeric_bounds_from_ast( + tree.body[0].value, argname=tree.args.args[0].arg + ) + + if isinstance(tree, ast.Compare): + ops = tree.ops + vals = tree.comparators + comparisons = [(tree.left, ops[0], vals[0])] + for i, (op, val) in enumerate(zip(ops[1:], vals[1:]), start=1): + comparisons.append((vals[i - 1], op, val)) + try: + bounds = [comp_to_kwargs(*x, argname=argname) for x in comparisons] + except ValueError: + return None + return merge_kwargs(*bounds) + + if isinstance(tree, ast.BoolOp) and isinstance(tree.op, ast.And): + bounds = [ + numeric_bounds_from_ast(node, argname=argname) for node in tree.values + ] + return merge_kwargs(*bounds) + + return None + + UNSATISFIABLE = ConstructivePredicate.unchanged(lambda _: False) @@ -99,7 +231,17 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: if predicate.func in options: return ConstructivePredicate(options[predicate.func], None) - # TODO: handle lambdas by AST analysis + try: + if predicate.__name__ == "": + source = extract_lambda_source(predicate) + else: + source = inspect.getsource(predicate) + kwargs = numeric_bounds_from_ast(ast.parse(source)) + except Exception: + pass + else: + if kwargs is not None: + return ConstructivePredicate(kwargs, None) return ConstructivePredicate.unchanged(predicate) diff --git a/hypothesis-python/tests/cover/test_filtered_strategy.py b/hypothesis-python/tests/cover/test_filtered_strategy.py index 97ea7bbf9d..428b2a2086 100644 --- a/hypothesis-python/tests/cover/test_filtered_strategy.py +++ b/hypothesis-python/tests/cover/test_filtered_strategy.py @@ -19,7 +19,8 @@ def test_filter_iterations_are_marked_as_discarded(): - x = st.integers(0, 255).filter(lambda x: x == 0) + variable_equal_to_zero = 0 # non-local references disables filter-rewriting + x = st.integers(0, 255).filter(lambda x: x == variable_equal_to_zero) data = ConjectureData.for_buffer([2, 1, 0]) From b39204e47ba428f7d98cac28020db81b3e4d944e Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Wed, 28 Apr 2021 00:42:45 +1000 Subject: [PATCH 2/3] Refactor, add tests --- hypothesis-python/RELEASE.rst | 11 ++ .../src/hypothesis/internal/filtering.py | 137 ++++++++++-------- .../tests/cover/test_filter_rewriting.py | 46 +++++- 3 files changed, 132 insertions(+), 62 deletions(-) create mode 100644 hypothesis-python/RELEASE.rst diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..0eef32929a --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,11 @@ +RELEASE_TYPE: minor + +This release automatically rewrites some simple filters, such as +``integers().filter(lambda x: x > 9)`` to the more efficient +``integers(min_value=10)``, based on the AST of the predicate. + +We continue to recommend using the efficient form directly wherever +possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``" +where you already have a simple predicate and translating manually +is really annoying. See :issue:`2701` for ideas about floats and +simple text strategies. diff --git a/hypothesis-python/src/hypothesis/internal/filtering.py b/hypothesis-python/src/hypothesis/internal/filtering.py index 7e2ca4e149..28856197ed 100644 --- a/hypothesis-python/src/hypothesis/internal/filtering.py +++ b/hypothesis-python/src/hypothesis/internal/filtering.py @@ -72,7 +72,7 @@ def unchanged(cls, predicate): ARG = object() -def convert(node, argname): +def convert(node: ast.AST, argname: str) -> object: if isinstance(node, ast.Name): if node.id != argname: raise ValueError("Non-local variable") @@ -80,13 +80,15 @@ def convert(node, argname): return ast.literal_eval(node) -def comp_to_kwargs(a, op, b, *, argname=None): - """ """ - if isinstance(a, ast.Name) == isinstance(b, ast.Name): +def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict: + a = convert(x, argname) + b = convert(y, argname) + num = (int, float) + if not (a is ARG and isinstance(b, num)) and not (isinstance(a, num) and b is ARG): + # It would be possible to work out if comparisons between two literals + # are always true or false, but it's too rare to be worth the complexity. + # (and we can't even do `arg == arg`, because what if it's NaN?) raise ValueError("Can't analyse this comparison") - a = convert(a, argname) - b = convert(b, argname) - assert (a is ARG) != (b is ARG) if isinstance(op, ast.Lt): if a is ARG: @@ -111,26 +113,18 @@ def comp_to_kwargs(a, op, b, *, argname=None): raise ValueError("Unhandled comparison operator") -def tidy(kwargs): - if not kwargs["exclude_min"]: - del kwargs["exclude_min"] - if kwargs["min_value"] == -math.inf: - del kwargs["min_value"] - if not kwargs["exclude_max"]: - del kwargs["exclude_max"] - if kwargs["max_value"] == math.inf: - del kwargs["max_value"] - return kwargs - - -def merge_kwargs(*rest): +def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate: + # This function is just kinda messy. Unfortunately the neatest way + # to do this is just to roll out each case and handle them in turn. base = { "min_value": -math.inf, "max_value": math.inf, "exclude_min": False, "exclude_max": False, } - for kw in rest: + predicate = None + for kw, p in con_predicates: + predicate = p or predicate if "min_value" in kw: if kw["min_value"] > base["min_value"]: base["exclude_min"] = kw.get("exclude_min", False) @@ -147,55 +141,56 @@ def merge_kwargs(*rest): base["exclude_max"] |= kw.get("exclude_max", False) else: base["exclude_max"] = False - return tidy(base) + if not base["exclude_min"]: + del base["exclude_min"] + if base["min_value"] == -math.inf: + del base["min_value"] + if not base["exclude_max"]: + del base["exclude_max"] + if base["max_value"] == math.inf: + del base["max_value"] + return ConstructivePredicate(base, predicate) -def numeric_bounds_from_ast(tree, *, argname=None): - """Take an AST; return a dict of bounds or None. + +def numeric_bounds_from_ast( + tree: ast.AST, argname: str, fallback: ConstructivePredicate +) -> ConstructivePredicate: + """Take an AST; return a ConstructivePredicate. >>> lambda x: x >= 0 - {"min_value": 0} + {"min_value": 0}, None >>> lambda x: x < 10 - {"max_value": 10, "exclude_max": True} + {"max_value": 10, "exclude_max": True}, None >>> lambda x: x >= y - None + {}, lambda x: x >= y + + See also https://greentreesnakes.readthedocs.io/en/latest/ """ - while isinstance(tree, ast.Module) and len(tree.body) == 1: - tree = tree.body[0] - if isinstance(tree, ast.Expr): + while isinstance(tree, ast.Expr): tree = tree.value - if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1: - assert argname is None - return numeric_bounds_from_ast(tree.body, argname=tree.args.args[0].arg) - - if isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1: - assert argname is None - if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return): - return None - return numeric_bounds_from_ast( - tree.body[0].value, argname=tree.args.args[0].arg - ) - if isinstance(tree, ast.Compare): ops = tree.ops vals = tree.comparators comparisons = [(tree.left, ops[0], vals[0])] for i, (op, val) in enumerate(zip(ops[1:], vals[1:]), start=1): comparisons.append((vals[i - 1], op, val)) - try: - bounds = [comp_to_kwargs(*x, argname=argname) for x in comparisons] - except ValueError: - return None - return merge_kwargs(*bounds) + bounds = [] + for comp in comparisons: + try: + kwargs = comp_to_kwargs(*comp, argname=argname) + bounds.append(ConstructivePredicate(kwargs, None)) + except ValueError: + bounds.append(fallback) + return merge_preds(*bounds) if isinstance(tree, ast.BoolOp) and isinstance(tree.op, ast.And): - bounds = [ - numeric_bounds_from_ast(node, argname=argname) for node in tree.values - ] - return merge_kwargs(*bounds) + return merge_preds( + *[numeric_bounds_from_ast(node, argname, fallback) for node in tree.values] + ) - return None + return fallback UNSATISFIABLE = ConstructivePredicate.unchanged(lambda _: False) @@ -208,6 +203,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: all the values are representable in the types that we're planning to generate so that the strategy validation doesn't complain. """ + unchanged = ConstructivePredicate.unchanged(predicate) if ( isinstance(predicate, partial) and len(predicate.args) == 1 @@ -219,7 +215,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: or not isinstance(arg, (int, float, Fraction, Decimal)) or math.isnan(arg) ): - return ConstructivePredicate.unchanged(predicate) + return unchanged options = { # We're talking about op(arg, x) - the reverse of our usual intuition! operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x @@ -231,19 +227,38 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: if predicate.func in options: return ConstructivePredicate(options[predicate.func], None) + # This section is a little complicated, but stepping through with comments should + # help to clarify it. We start by finding the source code for our predicate and + # parsing it to an abstract syntax tree; if this fails for any reason we bail out + # and fall back to standard rejection sampling (a running theme). try: if predicate.__name__ == "": source = extract_lambda_source(predicate) else: source = inspect.getsource(predicate) - kwargs = numeric_bounds_from_ast(ast.parse(source)) - except Exception: - pass - else: - if kwargs is not None: - return ConstructivePredicate(kwargs, None) - - return ConstructivePredicate.unchanged(predicate) + tree: ast.AST = ast.parse(source) + except Exception: # pragma: no cover + return unchanged + + # Dig down to the relevant subtree - our tree is probably a Module containing + # either a FunctionDef, or an Expr which in turn contains a lambda definition. + while isinstance(tree, ast.Module) and len(tree.body) == 1: + tree = tree.body[0] + while isinstance(tree, ast.Expr): + tree = tree.value + + if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1: + return numeric_bounds_from_ast(tree.body, tree.args.args[0].arg, unchanged) + elif isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1: + if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return): + # If the body of the function is anything but `return `, + # i.e. as simple as a lambda, we can't process it (yet). + return unchanged + argname = tree.args.args[0].arg + body = tree.body[0].value + assert isinstance(body, ast.AST) + return numeric_bounds_from_ast(body, argname, unchanged) + return unchanged def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: diff --git a/hypothesis-python/tests/cover/test_filter_rewriting.py b/hypothesis-python/tests/cover/test_filter_rewriting.py index 4183037225..6697a0ff58 100644 --- a/hypothesis-python/tests/cover/test_filter_rewriting.py +++ b/hypothesis-python/tests/cover/test_filter_rewriting.py @@ -56,6 +56,20 @@ (st.integers(), partial(operator.eq, 3), 3, 3), (st.integers(), partial(operator.ge, 3), None, 3), (st.integers(), partial(operator.gt, 3), None, 2), + # Simple lambdas + (st.integers(), lambda x: x < 3, None, 2), + (st.integers(), lambda x: x <= 3, None, 3), + (st.integers(), lambda x: x == 3, 3, 3), + (st.integers(), lambda x: x >= 3, 3, None), + (st.integers(), lambda x: x > 3, 4, None), + # Simple lambdas, reverse comparison + (st.integers(), lambda x: 3 > x, None, 2), + (st.integers(), lambda x: 3 >= x, None, 3), + (st.integers(), lambda x: 3 == x, 3, 3), + (st.integers(), lambda x: 3 <= x, 3, None), + (st.integers(), lambda x: 3 < x, 4, None), + # More complicated lambdas + (st.integers(), lambda x: 0 < x < 5, 1, 4), ], ) @given(data=st.data()) @@ -115,6 +129,9 @@ def mod2(x): return x % 2 +Y = 2 ** 20 + + @given( data=st.data(), predicates=st.permutations( @@ -124,6 +141,8 @@ def mod2(x): partial(operator.ge, 4), partial(operator.gt, 5), mod2, + lambda x: x > 2 or x % 7, + lambda x: 0 < x <= Y, ] ), ) @@ -142,4 +161,29 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates): unwrapped = s.wrapped_strategy assert isinstance(unwrapped, FilteredStrategy) assert isinstance(unwrapped.filtered_strategy, IntegersStrategy) - assert unwrapped.flat_conditions == (mod2,) + for pred in unwrapped.flat_conditions: + assert pred is mod2 or pred.__name__ == "" + + +@pytest.mark.parametrize( + "start, end, predicate", + [ + (1, 4, lambda x: 0 < x < 5 and x % 7), + (1, None, lambda x: 0 < x <= Y), + (None, None, lambda x: x == x), + (None, None, lambda x: 1 == 1), + (None, None, lambda x: 1 <= 2), + ], +) +@given(data=st.data()) +def test_rewriting_partially_understood_filters(data, start, end, predicate): + s = st.integers().filter(predicate).wrapped_strategy + + assert isinstance(s, FilteredStrategy) + assert isinstance(s.filtered_strategy, IntegersStrategy) + assert s.filtered_strategy.start == start + assert s.filtered_strategy.end == end + assert s.flat_conditions == (predicate,) + + value = data.draw(s) + assert predicate(value) From 9cb3bf4f341536cab2b32e6aa4c3469417e34850 Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Wed, 28 Apr 2021 02:41:15 +1000 Subject: [PATCH 3/3] Improve coverage --- .../src/hypothesis/internal/filtering.py | 11 ++-------- .../tests/cover/test_filter_rewriting.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/hypothesis-python/src/hypothesis/internal/filtering.py b/hypothesis-python/src/hypothesis/internal/filtering.py index 28856197ed..01520a7622 100644 --- a/hypothesis-python/src/hypothesis/internal/filtering.py +++ b/hypothesis-python/src/hypothesis/internal/filtering.py @@ -110,7 +110,7 @@ def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict if a is ARG: return {"min_value": b, "exclude_min": True} return {"max_value": a, "exclude_max": True} - raise ValueError("Unhandled comparison operator") + raise ValueError("Unhandled comparison operator") # e.g. ast.Ne def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate: @@ -131,16 +131,12 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate base["min_value"] = kw["min_value"] elif kw["min_value"] == base["min_value"]: base["exclude_min"] |= kw.get("exclude_min", False) - else: - base["exclude_min"] = False if "max_value" in kw: if kw["max_value"] < base["max_value"]: base["exclude_max"] = kw.get("exclude_max", False) base["max_value"] = kw["max_value"] elif kw["max_value"] == base["max_value"]: base["exclude_max"] |= kw.get("exclude_max", False) - else: - base["exclude_max"] = False if not base["exclude_min"]: del base["exclude_min"] @@ -167,9 +163,6 @@ def numeric_bounds_from_ast( See also https://greentreesnakes.readthedocs.io/en/latest/ """ - while isinstance(tree, ast.Expr): - tree = tree.value - if isinstance(tree, ast.Compare): ops = tree.ops vals = tree.comparators @@ -237,7 +230,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate: else: source = inspect.getsource(predicate) tree: ast.AST = ast.parse(source) - except Exception: # pragma: no cover + except Exception: return unchanged # Dig down to the relevant subtree - our tree is probably a Module containing diff --git a/hypothesis-python/tests/cover/test_filter_rewriting.py b/hypothesis-python/tests/cover/test_filter_rewriting.py index 6697a0ff58..4ca67a55ed 100644 --- a/hypothesis-python/tests/cover/test_filter_rewriting.py +++ b/hypothesis-python/tests/cover/test_filter_rewriting.py @@ -22,6 +22,7 @@ from hypothesis import given, strategies as st from hypothesis.errors import Unsatisfiable +from hypothesis.internal.reflection import get_pretty_function_description from hypothesis.strategies._internal.lazy import LazyStrategy from hypothesis.strategies._internal.numbers import IntegersStrategy from hypothesis.strategies._internal.strategies import FilteredStrategy @@ -70,7 +71,14 @@ (st.integers(), lambda x: 3 < x, 4, None), # More complicated lambdas (st.integers(), lambda x: 0 < x < 5, 1, 4), + (st.integers(), lambda x: 0 < x >= 1, 1, None), + (st.integers(), lambda x: 1 > x <= 0, None, 0), + (st.integers(), lambda x: x > 0 and x > 0, 1, None), + (st.integers(), lambda x: x < 1 and x < 1, None, 0), + (st.integers(), lambda x: x > 1 and x > 0, 2, None), + (st.integers(), lambda x: x < 1 and x < 2, None, 0), ], + ids=get_pretty_function_description, ) @given(data=st.data()) def test_filter_rewriting(data, strategy, predicate, start, end): @@ -165,14 +173,27 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates): assert pred is mod2 or pred.__name__ == "" +class NotAFunction: + def __call__(self, bar): + return True + + +lambda_without_source = eval("lambda x: x > 2", {}, {}) + + @pytest.mark.parametrize( "start, end, predicate", [ (1, 4, lambda x: 0 < x < 5 and x % 7), + (0, 9, lambda x: 0 <= x < 10 and x % 3), (1, None, lambda x: 0 < x <= Y), (None, None, lambda x: x == x), (None, None, lambda x: 1 == 1), (None, None, lambda x: 1 <= 2), + (None, None, lambda x: x != 0), + (None, None, NotAFunction()), + (None, None, lambda_without_source), + (None, None, lambda x, y=2: x >= 0), ], ) @given(data=st.data())