Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter-rewriting for length filters #3824

Merged
merged 18 commits into from
Jan 10, 2024
4 changes: 4 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
RELEASE_TYPE: minor

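This release adds filter-rewriting for length filters on ``text`` strategies:
length predicates written with ``functools.partial`` (e.g. ``partial(min_len, 3)``) or as a ``lambda`` (e.g. ``lambda s: len(s) >= 3``) are now used to tighten the strategy's ``min_size`` and ``max_size`` bounds.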
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
51 changes: 38 additions & 13 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from decimal import Decimal
from fractions import Fraction
from functools import partial
from typing import Any, Callable, Collection, Dict, NamedTuple, Optional, TypeVar
from typing import Any, Callable, Collection, Dict, NamedTuple, Optional, TypeVar, Union

from hypothesis.internal.compat import ceil, floor
from hypothesis.internal.floats import next_down, next_up
Expand Down Expand Up @@ -73,6 +73,10 @@ def convert(node: ast.AST, argname: str) -> object:
if node.id != argname:
raise ValueError("Non-local variable")
return ARG
if isinstance(node, ast.Call):
# if node.func.id == "len":
if isinstance(node.func, ast.Name) and node.func.id == "len":
return ARG
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
return ast.literal_eval(node)


Expand All @@ -86,27 +90,40 @@ def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict
# (and we can't even do `arg == arg`, because what if it's NaN?)
raise ValueError("Can't analyse this comparison")

kwargs: Dict[str, Union[Any, bool]] = {}
if isinstance(x, ast.Call) and isinstance(x.func, ast.Name) and x.func.id == "len":
kwargs["len_func"] = True
if isinstance(y, ast.Call) and isinstance(y.func, ast.Name) and y.func.id == "len":
kwargs["len_func"] = True
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(op, ast.Lt):
if a is ARG:
return {"max_value": b, "exclude_max": True}
return {"min_value": a, "exclude_min": True}
kwargs.update({"max_value": b, "exclude_max": True})
else:
kwargs.update({"min_value": a, "exclude_min": True})
elif isinstance(op, ast.LtE):
if a is ARG:
return {"max_value": b}
return {"min_value": a}
kwargs.update({"max_value": b})
else:
kwargs.update({"min_value": a})
elif isinstance(op, ast.Eq):
if a is ARG:
return {"min_value": b, "max_value": b}
return {"min_value": a, "max_value": a}
kwargs.update({"min_value": b, "max_value": b})
else:
kwargs.update({"min_value": a, "max_value": a})
elif isinstance(op, ast.GtE):
if a is ARG:
return {"min_value": b}
return {"max_value": a}
kwargs.update({"min_value": b})
else:
kwargs.update({"max_value": a})
elif isinstance(op, ast.Gt):
if a is ARG:
return {"min_value": b, "exclude_min": True}
return {"max_value": a, "exclude_max": True}
raise ValueError("Unhandled comparison operator") # e.g. ast.Ne
kwargs.update({"min_value": b, "exclude_min": True})
else:
kwargs.update({"max_value": a, "exclude_max": True})
else:
raise ValueError("Unhandled comparison operator") # e.g. ast.Ne
return kwargs


def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate:
Expand All @@ -117,6 +134,7 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
"max_value": math.inf,
"exclude_min": False,
"exclude_max": False,
"len_func": False,
}
predicate = None
for kw, p in con_predicates:
Expand All @@ -133,6 +151,8 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
base["max_value"] = kw["max_value"]
elif kw["max_value"] == base["max_value"]:
base["exclude_max"] |= kw.get("exclude_max", False)
if "len_func" in kw:
base["len_func"] |= kw.get("len_func", False)
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved

if not base["exclude_min"]:
del base["exclude_min"]
Expand All @@ -142,6 +162,8 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate
del base["exclude_max"]
if base["max_value"] == math.inf:
del base["max_value"]
if not base["len_func"]:
del base["len_func"]
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
return ConstructivePredicate(base, predicate)


Expand All @@ -154,6 +176,8 @@ def numeric_bounds_from_ast(
{"min_value": 0}, None
>>> lambda x: x < 10
{"max_value": 10, "exclude_max": True}, None
>>> lambda x: len(x) >= 5
{"min_value": 5, "len_func": True}, None
>>> lambda x: x >= y
{}, lambda x: x >= y

Expand Down Expand Up @@ -270,7 +294,8 @@ def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
elif kwargs.get("exclude_max", False):
kwargs["max_value"] = int(kwargs["max_value"]) - 1

kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
kw_categories = {"min_value", "max_value", "len_func"}
kwargs = {k: v for k, v in kwargs.items() if k in kw_categories}
return ConstructivePredicate(kwargs, predicate)


Expand Down
28 changes: 27 additions & 1 deletion hypothesis-python/src/hypothesis/strategies/_internal/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import copy
import re
import warnings
from functools import lru_cache
from functools import lru_cache, partial

from hypothesis.errors import HypothesisWarning, InvalidArgument
from hypothesis.internal import charmap
from hypothesis.internal.filtering import get_integer_predicate_bounds, max_len, min_len
from hypothesis.internal.intervalsets import IntervalSet
from hypothesis.strategies._internal.collections import ListStrategy
from hypothesis.strategies._internal.lazy import unwrap_strategies
Expand Down Expand Up @@ -142,7 +143,32 @@ def filter(self, condition):
HypothesisWarning,
stacklevel=2,
)

elems = unwrap_strategies(self.element_strategy)

kwargs, pred = get_integer_predicate_bounds(condition)

Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
min_value, max_value = None, None
if "len_func" in kwargs and kwargs["len_func"]:
min_value = kwargs.get("min_value")
max_value = kwargs.get("max_value")
if isinstance(condition, partial) and len(condition.args) == 1:
min_value = condition.args[0] if condition.func is min_len else None
max_value = condition.args[0] if condition.func is max_len else None
if min_value is not None or max_value is not None:
self.min_size = (
max(self.min_size, min_value)
if min_value is not None
else self.min_size
)
self.max_size = (
min(self.max_size, max_value)
if max_value is not None
else self.max_size
)
if isinstance(condition, partial):
return self

if (
condition is str.isidentifier
and self.max_size >= 1
Expand Down
71 changes: 71 additions & 0 deletions hypothesis-python/tests/cover/test_filter_rewriting.py
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

from hypothesis import given, strategies as st
from hypothesis.errors import HypothesisWarning, Unsatisfiable
from hypothesis.internal.filtering import max_len, min_len
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies._internal.core import data
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy
from hypothesis.strategies._internal.strings import TextStrategy

from tests.common.utils import fails_with

Expand Down Expand Up @@ -383,3 +385,72 @@ def test_isidentifer_filter_unsatisfiable(al):
def test_filter_floats_can_skip_subnormals(op, attr, value, expected):
base = st.floats(allow_subnormal=False).filter(partial(op, value))
assert getattr(base.wrapped_strategy, attr) == expected


@pytest.mark.parametrize(
    "strategy, predicate, start, end",
    [
        # Both size bounds supplied by the base strategy
        (st.text(min_size=1, max_size=5), partial(min_len, 3), 3, 5),
        (st.text(min_size=1, max_size=5), partial(max_len, 3), 1, 3),
        # Only one size bound supplied by the base strategy
        (st.text(min_size=1), partial(min_len, 3), 3, math.inf),
        (st.text(min_size=1), partial(max_len, 3), 1, 3),
        (st.text(max_size=5), partial(min_len, 3), 3, 5),
        (st.text(max_size=5), partial(max_len, 3), 0, 3),
        # No size bounds on the base strategy
        (st.text(), partial(min_len, 3), 3, math.inf),
        (st.text(), partial(max_len, 3), 0, 3),
    ],
)
@given(data=st.data())
def test_filter_rewriting_text_partial_len(data, strategy, predicate, start, end):
    """Check that ``partial(min_len, n)`` / ``partial(max_len, n)`` filters on
    ``st.text()`` are folded into the size bounds.

    After filtering, the lazily-wrapped strategy must still be a
    ``TextStrategy`` whose ``min_size``/``max_size`` equal ``start``/``end``,
    and any drawn example must satisfy the original predicate.
    """
    filtered = strategy.filter(predicate)

    assert isinstance(filtered, LazyStrategy)
    rewritten = filtered.wrapped_strategy
    assert isinstance(rewritten, TextStrategy)
    assert rewritten.min_size == start
    assert rewritten.max_size == end

    example = data.draw(filtered)
    assert predicate(example)


@pytest.mark.parametrize(
    "strategy, predicate, start, end",
    [
        # Single comparison, len(x) on the left
        (st.text(), lambda x: len(x) < 3, 0, 2),
        (st.text(), lambda x: len(x) <= 3, 0, 3),
        (st.text(), lambda x: len(x) == 3, 3, 3),
        (st.text(), lambda x: len(x) >= 3, 3, math.inf),
        (st.text(), lambda x: len(x) > 3, 4, math.inf),
        # Single comparison, len(x) on the right
        (st.text(), lambda x: 3 > len(x), 0, 2),
        (st.text(), lambda x: 3 >= len(x), 0, 3),
        (st.text(), lambda x: 3 == len(x), 3, 3),
        (st.text(), lambda x: 3 <= len(x), 3, math.inf),
        (st.text(), lambda x: 3 < len(x), 4, math.inf),
        # Chained comparisons and conjunctions
        (st.text(), lambda x: 0 < len(x) < 5, 1, 4),
        (st.text(), lambda x: 0 < len(x) >= 1, 1, math.inf),
        (st.text(), lambda x: 1 > len(x) <= 0, 0, 0),
        (st.text(), lambda x: len(x) > 0 and len(x) > 0, 1, math.inf),
        (st.text(), lambda x: len(x) < 1 and len(x) < 1, 0, 0),
        (st.text(), lambda x: len(x) > 1 and len(x) > 0, 2, math.inf),
        (st.text(), lambda x: len(x) < 1 and len(x) < 2, 0, 0),
    ],
    ids=get_pretty_function_description,
)
@given(data=st.data())
def test_filter_rewriting_text_lambda_len(data, strategy, predicate, start, end):
    """Check that ``lambda`` predicates comparing ``len(x)`` to a constant
    tighten the size bounds of ``st.text()``.

    The result is expected to be a ``FilteredStrategy`` around a
    ``TextStrategy``: every retained condition is still a lambda, the inner
    strategy's ``min_size``/``max_size`` equal ``start``/``end``, and any drawn
    example must satisfy the original predicate.
    """
    filtered = strategy.filter(predicate)
    outer = filtered.wrapped_strategy
    assert isinstance(outer, FilteredStrategy)

    inner = outer.filtered_strategy
    assert isinstance(inner, TextStrategy)
    for condition in outer.flat_conditions:
        assert condition.__name__ == "<lambda>"

    assert inner.min_size == start
    assert inner.max_size == end

    example = data.draw(filtered)
    assert predicate(example)