Skip to content

Commit

Permalink
Merge pull request #3824 from reaganjlee/length-filter
Browse files Browse the repository at this point in the history
Filter-rewriting for length filters
  • Loading branch information
Zac-HD committed Jan 10, 2024
2 parents 25dd7e5 + 50a3c8f commit 7960d08
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 16 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

This introduces the rewriting of length filters on some collection strategies (:issue:`3791`).

Thanks to Reagan Lee for implementing this feature!
51 changes: 39 additions & 12 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def convert(node: ast.AST, argname: str) -> object:
if node.id != argname:
raise ValueError("Non-local variable")
return ARG
if isinstance(node, ast.Call):
if (
isinstance(node.func, ast.Name)
and node.func.id == "len"
and len(node.args) == 1
):
# error unless comparison is to the len *of the lambda arg*
return convert(node.args[0], argname)
return ast.literal_eval(node)


Expand All @@ -86,26 +94,28 @@ def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict
# (and we can't even do `arg == arg`, because what if it's NaN?)
raise ValueError("Can't analyse this comparison")

of_len = {"len": True} if isinstance(x, ast.Call) or isinstance(y, ast.Call) else {}

if isinstance(op, ast.Lt):
if a is ARG:
return {"max_value": b, "exclude_max": True}
return {"min_value": a, "exclude_min": True}
return {"max_value": b, "exclude_max": True, **of_len}
return {"min_value": a, "exclude_min": True, **of_len}
elif isinstance(op, ast.LtE):
if a is ARG:
return {"max_value": b}
return {"min_value": a}
return {"max_value": b, **of_len}
return {"min_value": a, **of_len}
elif isinstance(op, ast.Eq):
if a is ARG:
return {"min_value": b, "max_value": b}
return {"min_value": a, "max_value": a}
return {"min_value": b, "max_value": b, **of_len}
return {"min_value": a, "max_value": a, **of_len}
elif isinstance(op, ast.GtE):
if a is ARG:
return {"min_value": b}
return {"max_value": a}
return {"min_value": b, **of_len}
return {"max_value": a, **of_len}
elif isinstance(op, ast.Gt):
if a is ARG:
return {"min_value": b, "exclude_min": True}
return {"max_value": a, "exclude_max": True}
return {"min_value": b, "exclude_min": True, **of_len}
return {"max_value": a, "exclude_max": True, **of_len}
raise ValueError("Unhandled comparison operator") # e.g. ast.Ne


Expand All @@ -120,6 +130,9 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
}
predicate = None
for kw, p in con_predicates:
assert (
not p or not predicate or p is predicate
), "Can't merge two partially-constructive preds"
predicate = p or predicate
if "min_value" in kw:
if kw["min_value"] > base["min_value"]:
Expand All @@ -134,6 +147,11 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
elif kw["max_value"] == base["max_value"]:
base["exclude_max"] |= kw.get("exclude_max", False)

has_len = {"len" in kw for kw, _ in con_predicates}
assert len(has_len) == 1, "can't mix numeric with length constraints"
if has_len == {True}:
base["len"] = True

if not base["exclude_min"]:
del base["exclude_min"]
if base["min_value"] == -math.inf:
Expand All @@ -154,6 +172,8 @@ def numeric_bounds_from_ast(
{"min_value": 0}, None
>>> lambda x: x < 10
{"max_value": 10, "exclude_max": True}, None
>>> lambda x: len(x) >= 5
{"min_value": 5, "len": True}, None
>>> lambda x: x >= y
{}, lambda x: x >= y
Expand All @@ -169,7 +189,10 @@ def numeric_bounds_from_ast(
for comp in comparisons:
try:
kwargs = comp_to_kwargs(*comp, argname=argname)
bounds.append(ConstructivePredicate(kwargs, None))
# Because `len` could be redefined in the enclosing scope, we *always*
# have to apply the condition as a filter, in addition to rewriting.
pred = fallback.predicate if "len" in kwargs else None
bounds.append(ConstructivePredicate(kwargs, pred))
except ValueError:
bounds.append(fallback)
return merge_preds(*bounds)
Expand Down Expand Up @@ -209,6 +232,9 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
operator.eq: {"min_value": arg, "max_value": arg}, # lambda x: arg == x
operator.ge: {"max_value": arg}, # lambda x: arg >= x
operator.gt: {"max_value": arg, "exclude_max": True}, # lambda x: arg > x
# Special-case our default predicates for length bounds
min_len: {"min_value": arg, "len": True},
max_len: {"max_value": arg, "len": True},
}
if predicate.func in options:
return ConstructivePredicate(options[predicate.func], None)
Expand Down Expand Up @@ -270,7 +296,8 @@ def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
elif kwargs.get("exclude_max", False):
kwargs["max_value"] = int(kwargs["max_value"]) - 1

kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
kw_categories = {"min_value", "max_value", "len"}
kwargs = {k: v for k, v in kwargs.items() if k in kw_categories}
return ConstructivePredicate(kwargs, predicate)


Expand Down
4 changes: 3 additions & 1 deletion hypothesis-python/src/hypothesis/internal/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import sys
import textwrap
import types
from functools import wraps
from functools import partial, wraps
from io import StringIO
from keyword import iskeyword
from tokenize import COMMENT, detect_encoding, generate_tokens, untokenize
Expand Down Expand Up @@ -432,6 +432,8 @@ def extract_lambda_source(f):


def get_pretty_function_description(f):
if isinstance(f, partial):
return pretty(f)
if not hasattr(f, "__name__"):
return repr(f)
name = f.__name__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from hypothesis.internal.conjecture import utils as cu
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy
from hypothesis.internal.conjecture.utils import combine_labels
from hypothesis.internal.filtering import get_integer_predicate_bounds
from hypothesis.internal.reflection import is_identity_function
from hypothesis.strategies._internal.strategies import (
T3,
Expand Down Expand Up @@ -199,7 +200,22 @@ def filter(self, condition):
new = copy.copy(self)
new.min_size = 1
return new
return super().filter(condition)

kwargs, pred = get_integer_predicate_bounds(condition)
if kwargs.get("len") and ("min_value" in kwargs or "max_value" in kwargs):
new = copy.copy(self)
new.min_size = max(self.min_size, kwargs.get("min_value", self.min_size))
new.max_size = min(self.max_size, kwargs.get("max_value", self.max_size))
# Recompute average size; this is cheaper than making it into a property.
new.average_size = min(
max(new.min_size * 2, new.min_size + 5),
0.5 * (new.min_size + new.max_size),
)
if pred is None:
return new
return SearchStrategy.filter(new, condition)

return SearchStrategy.filter(self, condition)


class UniqueListStrategy(ListStrategy):
Expand Down
153 changes: 153 additions & 0 deletions hypothesis-python/tests/cover/test_filter_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

from hypothesis import given, strategies as st
from hypothesis.errors import HypothesisWarning, Unsatisfiable
from hypothesis.internal.filtering import max_len, min_len
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies._internal.core import data
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy
from hypothesis.strategies._internal.strings import TextStrategy

from tests.common.utils import fails_with

Expand Down Expand Up @@ -383,3 +385,154 @@ def test_isidentifer_filter_unsatisfiable(al):
def test_filter_floats_can_skip_subnormals(op, attr, value, expected):
base = st.floats(allow_subnormal=False).filter(partial(op, value))
assert getattr(base.wrapped_strategy, attr) == expected


@pytest.mark.parametrize(
"strategy, predicate, start, end",
[
# text with integer bounds
(st.text(min_size=1, max_size=5), partial(min_len, 3), 3, 5),
(st.text(min_size=1, max_size=5), partial(max_len, 3), 1, 3),
# text with only one bound
(st.text(min_size=1), partial(min_len, 3), 3, math.inf),
(st.text(min_size=1), partial(max_len, 3), 1, 3),
(st.text(max_size=5), partial(min_len, 3), 3, 5),
(st.text(max_size=5), partial(max_len, 3), 0, 3),
# Unbounded text
(st.text(), partial(min_len, 3), 3, math.inf),
(st.text(), partial(max_len, 3), 0, 3),
],
ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_partial_len(data, strategy, predicate, start, end):
s = strategy.filter(predicate)

assert isinstance(s, LazyStrategy)
assert isinstance(inner := unwrap_strategies(s), TextStrategy)
assert inner.min_size == start
assert inner.max_size == end
value = data.draw(s)
assert predicate(value)


@given(data=st.data())
def test_can_rewrite_multiple_length_filters_if_not_lambdas(data):
# This is a key capability for efficient rewriting using the `annotated-types`
# package, although unfortunately we can't do it for lambdas.
s = (
st.text(min_size=1, max_size=5)
.filter(partial(min_len, 2))
.filter(partial(max_len, 4))
)
assert isinstance(s, LazyStrategy)
assert isinstance(inner := unwrap_strategies(s), TextStrategy)
assert inner.min_size == 2
assert inner.max_size == 4
value = data.draw(s)
assert 2 <= len(value) <= 4


@pytest.mark.parametrize(
"predicate, start, end",
[
# Simple lambdas
(lambda x: len(x) < 3, 0, 2),
(lambda x: len(x) <= 3, 0, 3),
(lambda x: len(x) == 3, 3, 3),
(lambda x: len(x) >= 3, 3, math.inf),
(lambda x: len(x) > 3, 4, math.inf),
# Simple lambdas, reverse comparison
(lambda x: 3 > len(x), 0, 2),
(lambda x: 3 >= len(x), 0, 3),
(lambda x: 3 == len(x), 3, 3),
(lambda x: 3 <= len(x), 3, math.inf),
(lambda x: 3 < len(x), 4, math.inf),
# More complicated lambdas
(lambda x: 0 < len(x) < 5, 1, 4),
(lambda x: 0 < len(x) >= 1, 1, math.inf),
(lambda x: 1 > len(x) <= 0, 0, 0),
(lambda x: len(x) > 0 and len(x) > 0, 1, math.inf),
(lambda x: len(x) < 1 and len(x) < 1, 0, 0),
(lambda x: len(x) > 1 and len(x) > 0, 2, math.inf),
(lambda x: len(x) < 1 and len(x) < 2, 0, 0),
],
ids=get_pretty_function_description,
)
@pytest.mark.parametrize(
"strategy",
[
st.text(),
st.lists(st.integers()),
st.lists(st.integers(), unique=True),
st.lists(st.sampled_from([1, 2, 3])),
# TODO: support more collection types. Might require messing around with
# strategy internals, e.g. in MappedStrategy/FilteredStrategy.
# st.binary(),
# st.binary.map(bytearray),
# st.sets(st.integers()),
# st.dictionaries(st.integers(), st.none()),
],
ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_lambda_len(data, strategy, predicate, start, end):
s = strategy.filter(predicate)
unwrapped = unwrap_strategies(s)
assert isinstance(unwrapped, FilteredStrategy)
assert isinstance(unwrapped.filtered_strategy, type(unwrap_strategies(strategy)))
for pred in unwrapped.flat_conditions:
assert pred.__name__ == "<lambda>"

assert unwrapped.filtered_strategy.min_size == start
assert unwrapped.filtered_strategy.max_size == end
value = data.draw(s)
assert predicate(value)


@pytest.mark.parametrize(
"predicate, start, end",
[
# Simple lambdas
(lambda x: len(x) < 3, 0, 2),
(lambda x: len(x) <= 3, 0, 3),
(lambda x: len(x) == 3, 3, 3),
(lambda x: len(x) >= 3, 3, 3), # input max element_count=3
# Simple lambdas, reverse comparison
(lambda x: 3 > len(x), 0, 2),
(lambda x: 3 >= len(x), 0, 3),
(lambda x: 3 == len(x), 3, 3),
(lambda x: 3 <= len(x), 3, 3), # input max element_count=3
# More complicated lambdas
(lambda x: 0 < len(x) < 5, 1, 3), # input max element_count=3
(lambda x: 0 < len(x) >= 1, 1, 3), # input max element_count=3
(lambda x: 1 > len(x) <= 0, 0, 0),
(lambda x: len(x) > 0 and len(x) > 0, 1, 3), # input max element_count=3
(lambda x: len(x) < 1 and len(x) < 1, 0, 0),
(lambda x: len(x) > 1 and len(x) > 0, 2, 3), # input max element_count=3
(lambda x: len(x) < 1 and len(x) < 2, 0, 0),
],
ids=get_pretty_function_description,
)
@pytest.mark.parametrize(
"strategy",
[
st.lists(st.sampled_from([1, 2, 3]), unique=True),
],
ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_lambda_len_unique_elements(
data, strategy, predicate, start, end
):
s = strategy.filter(predicate)
unwrapped = unwrap_strategies(s)
assert isinstance(unwrapped, FilteredStrategy)
assert isinstance(unwrapped.filtered_strategy, type(unwrap_strategies(strategy)))
for pred in unwrapped.flat_conditions:
assert pred.__name__ == "<lambda>"

assert unwrapped.filtered_strategy.min_size == start
assert unwrapped.filtered_strategy.max_size == end
value = data.draw(s)
assert predicate(value)
5 changes: 3 additions & 2 deletions hypothesis-python/tests/cover/test_searchstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies import booleans, integers, just, none, tuples
from hypothesis.strategies._internal.utils import to_jsonable

Expand Down Expand Up @@ -77,12 +78,12 @@ def f(u, v):

def test_can_map_nameless():
f = nameless_const(2)
assert repr(f) in repr(integers().map(f))
assert get_pretty_function_description(f) in repr(integers().map(f))


def test_can_flatmap_nameless():
f = nameless_const(just(3))
assert repr(f) in repr(integers().flatmap(f))
assert get_pretty_function_description(f) in repr(integers().flatmap(f))


def test_flatmap_with_invalid_expand():
Expand Down

0 comments on commit 7960d08

Please sign in to comment.