Skip to content

Commit

Permalink
Refactor, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed May 6, 2021
1 parent dcc07dc commit b39204e
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 62 deletions.
11 changes: 11 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
RELEASE_TYPE: minor

This release automatically rewrites some simple filters, such as
``integers().filter(lambda x: x > 9)`` to the more efficient
``integers(min_value=10)``, based on the AST of the predicate.

We continue to recommend using the efficient form directly wherever
possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``"
where you already have a simple predicate and translating manually
is really annoying. See :issue:`2701` for ideas about floats and
simple text strategies.
137 changes: 76 additions & 61 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,23 @@ def unchanged(cls, predicate):
ARG = object()


def convert(node, argname):
def convert(node: ast.AST, argname: str) -> object:
if isinstance(node, ast.Name):
if node.id != argname:
raise ValueError("Non-local variable")
return ARG
return ast.literal_eval(node)


def comp_to_kwargs(a, op, b, *, argname=None):
""" """
if isinstance(a, ast.Name) == isinstance(b, ast.Name):
def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict:
a = convert(x, argname)
b = convert(y, argname)
num = (int, float)
if not (a is ARG and isinstance(b, num)) and not (isinstance(a, num) and b is ARG):
# It would be possible to work out if comparisons between two literals
# are always true or false, but it's too rare to be worth the complexity.
# (and we can't even do `arg == arg`, because what if it's NaN?)
raise ValueError("Can't analyse this comparison")
a = convert(a, argname)
b = convert(b, argname)
assert (a is ARG) != (b is ARG)

if isinstance(op, ast.Lt):
if a is ARG:
Expand All @@ -111,26 +113,18 @@ def comp_to_kwargs(a, op, b, *, argname=None):
raise ValueError("Unhandled comparison operator")


def tidy(kwargs):
if not kwargs["exclude_min"]:
del kwargs["exclude_min"]
if kwargs["min_value"] == -math.inf:
del kwargs["min_value"]
if not kwargs["exclude_max"]:
del kwargs["exclude_max"]
if kwargs["max_value"] == math.inf:
del kwargs["max_value"]
return kwargs


def merge_kwargs(*rest):
def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate:
# This function is just kinda messy. Unfortunately the neatest way
# to do this is just to roll out each case and handle them in turn.
base = {
"min_value": -math.inf,
"max_value": math.inf,
"exclude_min": False,
"exclude_max": False,
}
for kw in rest:
predicate = None
for kw, p in con_predicates:
predicate = p or predicate
if "min_value" in kw:
if kw["min_value"] > base["min_value"]:
base["exclude_min"] = kw.get("exclude_min", False)
Expand All @@ -147,55 +141,56 @@ def merge_kwargs(*rest):
base["exclude_max"] |= kw.get("exclude_max", False)
else:
base["exclude_max"] = False
return tidy(base)

if not base["exclude_min"]:
del base["exclude_min"]
if base["min_value"] == -math.inf:
del base["min_value"]
if not base["exclude_max"]:
del base["exclude_max"]
if base["max_value"] == math.inf:
del base["max_value"]
return ConstructivePredicate(base, predicate)

def numeric_bounds_from_ast(tree, *, argname=None):
"""Take an AST; return a dict of bounds or None.

def numeric_bounds_from_ast(
tree: ast.AST, argname: str, fallback: ConstructivePredicate
) -> ConstructivePredicate:
"""Take an AST; return a ConstructivePredicate.
>>> lambda x: x >= 0
{"min_value": 0}
{"min_value": 0}, None
>>> lambda x: x < 10
{"max_value": 10, "exclude_max": True}
{"max_value": 10, "exclude_max": True}, None
>>> lambda x: x >= y
None
{}, lambda x: x >= y
See also https://greentreesnakes.readthedocs.io/en/latest/
"""
while isinstance(tree, ast.Module) and len(tree.body) == 1:
tree = tree.body[0]
if isinstance(tree, ast.Expr):
while isinstance(tree, ast.Expr):
tree = tree.value

if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1:
assert argname is None
return numeric_bounds_from_ast(tree.body, argname=tree.args.args[0].arg)

if isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1:
assert argname is None
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return):
return None
return numeric_bounds_from_ast(
tree.body[0].value, argname=tree.args.args[0].arg
)

if isinstance(tree, ast.Compare):
ops = tree.ops
vals = tree.comparators
comparisons = [(tree.left, ops[0], vals[0])]
for i, (op, val) in enumerate(zip(ops[1:], vals[1:]), start=1):
comparisons.append((vals[i - 1], op, val))
try:
bounds = [comp_to_kwargs(*x, argname=argname) for x in comparisons]
except ValueError:
return None
return merge_kwargs(*bounds)
bounds = []
for comp in comparisons:
try:
kwargs = comp_to_kwargs(*comp, argname=argname)
bounds.append(ConstructivePredicate(kwargs, None))
except ValueError:
bounds.append(fallback)
return merge_preds(*bounds)

if isinstance(tree, ast.BoolOp) and isinstance(tree.op, ast.And):
bounds = [
numeric_bounds_from_ast(node, argname=argname) for node in tree.values
]
return merge_kwargs(*bounds)
return merge_preds(
*[numeric_bounds_from_ast(node, argname, fallback) for node in tree.values]
)

return None
return fallback


UNSATISFIABLE = ConstructivePredicate.unchanged(lambda _: False)
Expand All @@ -208,6 +203,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
all the values are representable in the types that we're planning to generate
so that the strategy validation doesn't complain.
"""
unchanged = ConstructivePredicate.unchanged(predicate)
if (
isinstance(predicate, partial)
and len(predicate.args) == 1
Expand All @@ -219,7 +215,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
or not isinstance(arg, (int, float, Fraction, Decimal))
or math.isnan(arg)
):
return ConstructivePredicate.unchanged(predicate)
return unchanged
options = {
# We're talking about op(arg, x) - the reverse of our usual intuition!
operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x
Expand All @@ -231,19 +227,38 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
if predicate.func in options:
return ConstructivePredicate(options[predicate.func], None)

# This section is a little complicated, but stepping through with comments should
# help to clarify it. We start by finding the source code for our predicate and
# parsing it to an abstract syntax tree; if this fails for any reason we bail out
# and fall back to standard rejection sampling (a running theme).
try:
if predicate.__name__ == "<lambda>":
source = extract_lambda_source(predicate)
else:
source = inspect.getsource(predicate)
kwargs = numeric_bounds_from_ast(ast.parse(source))
except Exception:
pass
else:
if kwargs is not None:
return ConstructivePredicate(kwargs, None)

return ConstructivePredicate.unchanged(predicate)
tree: ast.AST = ast.parse(source)
except Exception: # pragma: no cover
return unchanged

# Dig down to the relevant subtree - our tree is probably a Module containing
# either a FunctionDef, or an Expr which in turn contains a lambda definition.
while isinstance(tree, ast.Module) and len(tree.body) == 1:
tree = tree.body[0]
while isinstance(tree, ast.Expr):
tree = tree.value

if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1:
return numeric_bounds_from_ast(tree.body, tree.args.args[0].arg, unchanged)
elif isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1:
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return):
# If the body of the function is anything but `return <expr>`,
# i.e. as simple as a lambda, we can't process it (yet).
return unchanged
argname = tree.args.args[0].arg
body = tree.body[0].value
assert isinstance(body, ast.AST)
return numeric_bounds_from_ast(body, argname, unchanged)
return unchanged


def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
Expand Down
46 changes: 45 additions & 1 deletion hypothesis-python/tests/cover/test_filter_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@
(st.integers(), partial(operator.eq, 3), 3, 3),
(st.integers(), partial(operator.ge, 3), None, 3),
(st.integers(), partial(operator.gt, 3), None, 2),
# Simple lambdas
(st.integers(), lambda x: x < 3, None, 2),
(st.integers(), lambda x: x <= 3, None, 3),
(st.integers(), lambda x: x == 3, 3, 3),
(st.integers(), lambda x: x >= 3, 3, None),
(st.integers(), lambda x: x > 3, 4, None),
# Simple lambdas, reverse comparison
(st.integers(), lambda x: 3 > x, None, 2),
(st.integers(), lambda x: 3 >= x, None, 3),
(st.integers(), lambda x: 3 == x, 3, 3),
(st.integers(), lambda x: 3 <= x, 3, None),
(st.integers(), lambda x: 3 < x, 4, None),
# More complicated lambdas
(st.integers(), lambda x: 0 < x < 5, 1, 4),
],
)
@given(data=st.data())
Expand Down Expand Up @@ -115,6 +129,9 @@ def mod2(x):
return x % 2


Y = 2 ** 20


@given(
data=st.data(),
predicates=st.permutations(
Expand All @@ -124,6 +141,8 @@ def mod2(x):
partial(operator.ge, 4),
partial(operator.gt, 5),
mod2,
lambda x: x > 2 or x % 7,
lambda x: 0 < x <= Y,
]
),
)
Expand All @@ -142,4 +161,29 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
unwrapped = s.wrapped_strategy
assert isinstance(unwrapped, FilteredStrategy)
assert isinstance(unwrapped.filtered_strategy, IntegersStrategy)
assert unwrapped.flat_conditions == (mod2,)
for pred in unwrapped.flat_conditions:
assert pred is mod2 or pred.__name__ == "<lambda>"


@pytest.mark.parametrize(
"start, end, predicate",
[
(1, 4, lambda x: 0 < x < 5 and x % 7),
(1, None, lambda x: 0 < x <= Y),
(None, None, lambda x: x == x),
(None, None, lambda x: 1 == 1),
(None, None, lambda x: 1 <= 2),
],
)
@given(data=st.data())
def test_rewriting_partially_understood_filters(data, start, end, predicate):
s = st.integers().filter(predicate).wrapped_strategy

assert isinstance(s, FilteredStrategy)
assert isinstance(s.filtered_strategy, IntegersStrategy)
assert s.filtered_strategy.start == start
assert s.filtered_strategy.end == end
assert s.flat_conditions == (predicate,)

value = data.draw(s)
assert predicate(value)

0 comments on commit b39204e

Please sign in to comment.