Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter-rewriting for length filters #3824

Merged
merged 18 commits into from
Jan 10, 2024
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

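This patch adds filter-rewriting for length filters, such as ``.filter(lambda x: len(x) >= 3)``, on some collection strategies, so that the size bounds are applied when generating values rather than only by rejection (:issue:`3791`).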

Thanks to Reagan Lee for implementing this feature!
51 changes: 39 additions & 12 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def convert(node: ast.AST, argname: str) -> object:
if node.id != argname:
raise ValueError("Non-local variable")
return ARG
if isinstance(node, ast.Call):
if (
isinstance(node.func, ast.Name)
and node.func.id == "len"
and len(node.args) == 1
):
# error unless comparison is to the len *of the lambda arg*
return convert(node.args[0], argname)
return ast.literal_eval(node)


Expand All @@ -86,26 +94,28 @@ def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict
# (and we can't even do `arg == arg`, because what if it's NaN?)
raise ValueError("Can't analyse this comparison")

of_len = {"len": True} if isinstance(x, ast.Call) or isinstance(y, ast.Call) else {}

if isinstance(op, ast.Lt):
if a is ARG:
return {"max_value": b, "exclude_max": True}
return {"min_value": a, "exclude_min": True}
return {"max_value": b, "exclude_max": True, **of_len}
return {"min_value": a, "exclude_min": True, **of_len}
elif isinstance(op, ast.LtE):
if a is ARG:
return {"max_value": b}
return {"min_value": a}
return {"max_value": b, **of_len}
return {"min_value": a, **of_len}
elif isinstance(op, ast.Eq):
if a is ARG:
return {"min_value": b, "max_value": b}
return {"min_value": a, "max_value": a}
return {"min_value": b, "max_value": b, **of_len}
return {"min_value": a, "max_value": a, **of_len}
elif isinstance(op, ast.GtE):
if a is ARG:
return {"min_value": b}
return {"max_value": a}
return {"min_value": b, **of_len}
return {"max_value": a, **of_len}
elif isinstance(op, ast.Gt):
if a is ARG:
return {"min_value": b, "exclude_min": True}
return {"max_value": a, "exclude_max": True}
return {"min_value": b, "exclude_min": True, **of_len}
return {"max_value": a, "exclude_max": True, **of_len}
raise ValueError("Unhandled comparison operator") # e.g. ast.Ne


Expand All @@ -120,6 +130,9 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
}
predicate = None
for kw, p in con_predicates:
assert (
not p or not predicate or p is predicate
), "Can't merge two partially-constructive preds"
predicate = p or predicate
if "min_value" in kw:
if kw["min_value"] > base["min_value"]:
Expand All @@ -134,6 +147,11 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
elif kw["max_value"] == base["max_value"]:
base["exclude_max"] |= kw.get("exclude_max", False)

has_len = {"len" in kw for kw, _ in con_predicates}
assert len(has_len) == 1, "can't mix numeric with length constraints"
if has_len == {True}:
base["len"] = True

if not base["exclude_min"]:
del base["exclude_min"]
if base["min_value"] == -math.inf:
Expand All @@ -154,6 +172,8 @@ def numeric_bounds_from_ast(
{"min_value": 0}, None
>>> lambda x: x < 10
{"max_value": 10, "exclude_max": True}, None
>>> lambda x: len(x) >= 5
{"min_value": 5, "len": True}, None
>>> lambda x: x >= y
{}, lambda x: x >= y

Expand All @@ -169,7 +189,10 @@ def numeric_bounds_from_ast(
for comp in comparisons:
try:
kwargs = comp_to_kwargs(*comp, argname=argname)
bounds.append(ConstructivePredicate(kwargs, None))
# Because `len` could be redefined in the enclosing scope, we *always*
# have to apply the condition as a filter, in addition to rewriting.
pred = fallback.predicate if "len" in kwargs else None
bounds.append(ConstructivePredicate(kwargs, pred))
except ValueError:
bounds.append(fallback)
return merge_preds(*bounds)
Expand Down Expand Up @@ -209,6 +232,9 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x
operator.ge: {"max_value": arg}, # lambda x: arg >= x
operator.gt: {"max_value": arg, "exclude_max": True}, # lambda x: arg > x
# Special-case our default predicates for length bounds
min_len: {"min_value": arg, "len": True},
max_len: {"max_value": arg, "len": True},
}
if predicate.func in options:
return ConstructivePredicate(options[predicate.func], None)
Expand Down Expand Up @@ -270,7 +296,8 @@ def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
elif kwargs.get("exclude_max", False):
kwargs["max_value"] = int(kwargs["max_value"]) - 1

kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
kw_categories = {"min_value", "max_value", "len"}
kwargs = {k: v for k, v in kwargs.items() if k in kw_categories}
return ConstructivePredicate(kwargs, predicate)


Expand Down
4 changes: 3 additions & 1 deletion hypothesis-python/src/hypothesis/internal/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import sys
import textwrap
import types
from functools import wraps
from functools import partial, wraps
from io import StringIO
from keyword import iskeyword
from tokenize import COMMENT, detect_encoding, generate_tokens, untokenize
Expand Down Expand Up @@ -432,6 +432,8 @@ def extract_lambda_source(f):


def get_pretty_function_description(f):
if isinstance(f, partial):
return pretty(f)
if not hasattr(f, "__name__"):
return repr(f)
name = f.__name__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from hypothesis.internal.conjecture import utils as cu
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy
from hypothesis.internal.conjecture.utils import combine_labels
from hypothesis.internal.filtering import get_integer_predicate_bounds
from hypothesis.internal.reflection import is_identity_function
from hypothesis.strategies._internal.strategies import (
T3,
Expand Down Expand Up @@ -199,7 +200,22 @@ def filter(self, condition):
new = copy.copy(self)
new.min_size = 1
return new
return super().filter(condition)

kwargs, pred = get_integer_predicate_bounds(condition)
if kwargs.get("len") and ("min_value" in kwargs or "max_value" in kwargs):
new = copy.copy(self)
new.min_size = max(self.min_size, kwargs.get("min_value", self.min_size))
new.max_size = min(self.max_size, kwargs.get("max_value", self.max_size))
# Recompute average size; this is cheaper than making it into a property.
new.average_size = min(
max(new.min_size * 2, new.min_size + 5),
0.5 * (new.min_size + new.max_size),
)
if pred is None:
return new
return SearchStrategy.filter(new, condition)

return SearchStrategy.filter(self, condition)


class UniqueListStrategy(ListStrategy):
Expand Down
153 changes: 153 additions & 0 deletions hypothesis-python/tests/cover/test_filter_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

from hypothesis import given, strategies as st
from hypothesis.errors import HypothesisWarning, Unsatisfiable
from hypothesis.internal.filtering import max_len, min_len
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies._internal.core import data
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy
from hypothesis.strategies._internal.strings import TextStrategy

from tests.common.utils import fails_with

Expand Down Expand Up @@ -383,3 +385,154 @@ def test_isidentifer_filter_unsatisfiable(al):
def test_filter_floats_can_skip_subnormals(op, attr, value, expected):
base = st.floats(allow_subnormal=False).filter(partial(op, value))
assert getattr(base.wrapped_strategy, attr) == expected


@pytest.mark.parametrize(
    "strategy, predicate, start, end",
    [
        # text with integer bounds
        (st.text(min_size=1, max_size=5), partial(min_len, 3), 3, 5),
        (st.text(min_size=1, max_size=5), partial(max_len, 3), 1, 3),
        # text with only one bound
        (st.text(min_size=1), partial(min_len, 3), 3, math.inf),
        (st.text(min_size=1), partial(max_len, 3), 1, 3),
        (st.text(max_size=5), partial(min_len, 3), 3, 5),
        (st.text(max_size=5), partial(max_len, 3), 0, 3),
        # Unbounded text
        (st.text(), partial(min_len, 3), 3, math.inf),
        (st.text(), partial(max_len, 3), 0, 3),
    ],
    ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_partial_len(data, strategy, predicate, start, end):
    """Check that ``partial(min_len/max_len, n)`` filters on text are rewritten.

    After filtering, the unwrapped strategy must still be a plain
    ``TextStrategy`` whose ``min_size``/``max_size`` equal ``start``/``end``,
    and any drawn value must satisfy the predicate.
    """
    filtered = strategy.filter(predicate)
    assert isinstance(filtered, LazyStrategy)

    # Bind first, then check the type, so the name does not depend on an
    # assignment expression buried inside an assert.
    rewritten = unwrap_strategies(filtered)
    assert isinstance(rewritten, TextStrategy)
    assert rewritten.min_size == start
    assert rewritten.max_size == end

    drawn = data.draw(filtered)
    assert predicate(drawn)


@given(data=st.data())
def test_can_rewrite_multiple_length_filters_if_not_lambdas(data):
    """Check that stacked ``partial`` length filters are both rewritten.

    This is a key capability for efficient rewriting using the
    `annotated-types` package, although unfortunately we can't do it for
    lambdas.
    """
    base = st.text(min_size=1, max_size=5)
    with_lower_bound = base.filter(partial(min_len, 2))
    bounded = with_lower_bound.filter(partial(max_len, 4))
    assert isinstance(bounded, LazyStrategy)

    rewritten = unwrap_strategies(bounded)
    assert isinstance(rewritten, TextStrategy)
    assert rewritten.min_size == 2
    assert rewritten.max_size == 4

    drawn = data.draw(bounded)
    assert 2 <= len(drawn) <= 4


def _check_len_lambda_rewrite(data, strategy, predicate, start, end):
    """Shared assertions for length-filter rewriting with lambda predicates.

    Filters ``strategy`` by ``predicate`` and checks that:

    * the unwrapped result is a ``FilteredStrategy`` wrapping a strategy of the
      same type as the unwrapped input, i.e. the lambda is still applied as a
      filter on top of the rewritten strategy;
    * every remaining condition is a lambda;
    * the wrapped strategy's ``min_size``/``max_size`` equal ``start``/``end``;
    * a value drawn from the filtered strategy satisfies ``predicate``.
    """
    filtered = strategy.filter(predicate)
    unwrapped = unwrap_strategies(filtered)
    assert isinstance(unwrapped, FilteredStrategy)

    rewritten = unwrapped.filtered_strategy
    assert isinstance(rewritten, type(unwrap_strategies(strategy)))
    for pred in unwrapped.flat_conditions:
        assert pred.__name__ == "<lambda>"

    assert rewritten.min_size == start
    assert rewritten.max_size == end
    value = data.draw(filtered)
    assert predicate(value)


@pytest.mark.parametrize(
    "predicate, start, end",
    [
        # Simple lambdas
        (lambda x: len(x) < 3, 0, 2),
        (lambda x: len(x) <= 3, 0, 3),
        (lambda x: len(x) == 3, 3, 3),
        (lambda x: len(x) >= 3, 3, math.inf),
        (lambda x: len(x) > 3, 4, math.inf),
        # Simple lambdas, reverse comparison
        (lambda x: 3 > len(x), 0, 2),
        (lambda x: 3 >= len(x), 0, 3),
        (lambda x: 3 == len(x), 3, 3),
        (lambda x: 3 <= len(x), 3, math.inf),
        (lambda x: 3 < len(x), 4, math.inf),
        # More complicated lambdas
        (lambda x: 0 < len(x) < 5, 1, 4),
        (lambda x: 0 < len(x) >= 1, 1, math.inf),
        (lambda x: 1 > len(x) <= 0, 0, 0),
        (lambda x: len(x) > 0 and len(x) > 0, 1, math.inf),
        (lambda x: len(x) < 1 and len(x) < 1, 0, 0),
        (lambda x: len(x) > 1 and len(x) > 0, 2, math.inf),
        (lambda x: len(x) < 1 and len(x) < 2, 0, 0),
    ],
    ids=get_pretty_function_description,
)
@pytest.mark.parametrize(
    "strategy",
    [
        st.text(),
        st.lists(st.integers()),
        st.lists(st.integers(), unique=True),
        st.lists(st.sampled_from([1, 2, 3])),
        # TODO: support more collection types.  Might require messing around with
        #       strategy internals, e.g. in MappedStrategy/FilteredStrategy.
        # st.binary(),
        # st.binary().map(bytearray),
        # st.sets(st.integers()),
        # st.dictionaries(st.integers(), st.none()),
    ],
    ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_lambda_len(data, strategy, predicate, start, end):
    """Lambda length filters tighten the size bounds of text and list strategies."""
    _check_len_lambda_rewrite(data, strategy, predicate, start, end)


@pytest.mark.parametrize(
    "predicate, start, end",
    [
        # Simple lambdas
        (lambda x: len(x) < 3, 0, 2),
        (lambda x: len(x) <= 3, 0, 3),
        (lambda x: len(x) == 3, 3, 3),
        (lambda x: len(x) >= 3, 3, 3),  # input max element_count=3
        # Simple lambdas, reverse comparison
        (lambda x: 3 > len(x), 0, 2),
        (lambda x: 3 >= len(x), 0, 3),
        (lambda x: 3 == len(x), 3, 3),
        (lambda x: 3 <= len(x), 3, 3),  # input max element_count=3
        # More complicated lambdas
        (lambda x: 0 < len(x) < 5, 1, 3),  # input max element_count=3
        (lambda x: 0 < len(x) >= 1, 1, 3),  # input max element_count=3
        (lambda x: 1 > len(x) <= 0, 0, 0),
        (lambda x: len(x) > 0 and len(x) > 0, 1, 3),  # input max element_count=3
        (lambda x: len(x) < 1 and len(x) < 1, 0, 0),
        (lambda x: len(x) > 1 and len(x) > 0, 2, 3),  # input max element_count=3
        (lambda x: len(x) < 1 and len(x) < 2, 0, 0),
    ],
    ids=get_pretty_function_description,
)
@pytest.mark.parametrize(
    "strategy",
    [
        st.lists(st.sampled_from([1, 2, 3]), unique=True),
    ],
    ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_lambda_len_unique_elements(
    data, strategy, predicate, start, end
):
    """Same checks for a unique list over three elements.

    The expected upper bound never exceeds 3 here, as the per-case comments
    note.
    """
    _check_len_lambda_rewrite(data, strategy, predicate, start, end)
5 changes: 3 additions & 2 deletions hypothesis-python/tests/cover/test_searchstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies import booleans, integers, just, none, tuples
from hypothesis.strategies._internal.utils import to_jsonable

Expand Down Expand Up @@ -77,12 +78,12 @@ def f(u, v):

def test_can_map_nameless():
f = nameless_const(2)
assert repr(f) in repr(integers().map(f))
assert get_pretty_function_description(f) in repr(integers().map(f))


def test_can_flatmap_nameless():
f = nameless_const(just(3))
assert repr(f) in repr(integers().flatmap(f))
assert get_pretty_function_description(f) in repr(integers().flatmap(f))


def test_flatmap_with_invalid_expand():
Expand Down