Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter-rewriting for st.floats() #3385

Merged
merged 2 commits into from Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,10 @@
RELEASE_TYPE: patch

This release automatically rewrites some simple filters, such as
``floats().filter(lambda x: x >= 10)`` to the more efficient
``floats(min_value=10)``, based on the AST of the predicate.

We continue to recommend using the efficient form directly wherever
possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``"
where you already have a simple predicate and translating manually
is really annoying. See :issue:`2701` for details.
24 changes: 24 additions & 0 deletions hypothesis-python/src/hypothesis/internal/filtering.py
Expand Up @@ -32,6 +32,7 @@
from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar

from hypothesis.internal.compat import ceil, floor
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import extract_lambda_source

Ex = TypeVar("Ex")
Expand Down Expand Up @@ -274,3 +275,26 @@ def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:

kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
return ConstructivePredicate(kwargs, predicate)


def get_float_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
kwargs, predicate = get_numeric_predicate_bounds(predicate) # type: ignore

if "min_value" in kwargs:
min_value = kwargs["min_value"]
kwargs["min_value"] = float(kwargs["min_value"])
if min_value < kwargs["min_value"] or (
min_value == kwargs["min_value"] and kwargs.get("exclude_min", False)
):
kwargs["min_value"] = next_up(kwargs["min_value"])

if "max_value" in kwargs:
max_value = kwargs["max_value"]
kwargs["max_value"] = float(kwargs["max_value"])
if max_value > kwargs["max_value"] or (
max_value == kwargs["max_value"] and kwargs.get("exclude_max", False)
):
kwargs["max_value"] = next_down(kwargs["max_value"])

kwargs = {k: v for k, v in kwargs.items() if k in {"min_value", "max_value"}}
return ConstructivePredicate(kwargs, predicate)
55 changes: 47 additions & 8 deletions hypothesis-python/src/hypothesis/strategies/_internal/numbers.py
Expand Up @@ -18,7 +18,10 @@
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture import floats as flt, utils as d
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.filtering import get_integer_predicate_bounds
from hypothesis.internal.filtering import (
get_float_predicate_bounds,
get_integer_predicate_bounds,
)
from hypothesis.internal.floats import (
float_of,
int_to_float,
Expand Down Expand Up @@ -195,9 +198,10 @@ class FloatStrategy(SearchStrategy):

def __init__(
self,
min_value: float = -math.inf,
max_value: float = math.inf,
allow_nan: bool = True,
*,
min_value: float,
max_value: float,
allow_nan: bool,
# The smallest nonzero number we can represent is usually a subnormal, but may
# be the smallest normal if we're running in unsafe denormals-are-zero mode.
# While that's usually an explicit error, we do need to handle the case where
Expand Down Expand Up @@ -292,6 +296,40 @@ def do_draw(self, data):
data.stop_example() # (FLOAT_STRATEGY_DO_DRAW_LABEL)
return result

def filter(self, condition):
kwargs, pred = get_float_predicate_bounds(condition)
if not kwargs:
return super().filter(pred)
min_bound = max(kwargs.get("min_value", -math.inf), self.min_value)
max_bound = min(kwargs.get("max_value", math.inf), self.max_value)

# Adjustments for allow_subnormal=False, if any need to be made
if -self.smallest_nonzero_magnitude < min_bound < 0:
min_bound = -0.0
elif 0 < min_bound < self.smallest_nonzero_magnitude:
min_bound = self.smallest_nonzero_magnitude
if -self.smallest_nonzero_magnitude < max_bound < 0:
max_bound = -self.smallest_nonzero_magnitude
elif 0 < max_bound < self.smallest_nonzero_magnitude:
max_bound = 0.0

if min_bound > max_bound:
return nothing()
if (
min_bound > self.min_value
or self.max_value > max_bound
or (self.allow_nan and (-math.inf < min_bound or max_bound < math.inf))
):
self = type(self)(
min_value=min_bound,
max_value=max_bound,
allow_nan=False,
smallest_nonzero_magnitude=self.smallest_nonzero_magnitude,
)
if pred is None:
return self
return super().filter(pred)


@cacheable
@defines_strategy(force_reusable_values=True)
Expand Down Expand Up @@ -508,14 +546,17 @@ def floats(
min_value = float("-inf")
if max_value is None:
max_value = float("inf")
if not allow_infinity:
min_value = max(min_value, next_up(float("-inf")))
max_value = min(max_value, next_down(float("inf")))
assert isinstance(min_value, float)
assert isinstance(max_value, float)
smallest_nonzero_magnitude = (
SMALLEST_SUBNORMAL if allow_subnormal else smallest_normal
)
result: SearchStrategy = FloatStrategy(
min_value,
max_value,
min_value=min_value,
max_value=max_value,
allow_nan=allow_nan,
smallest_nonzero_magnitude=smallest_nonzero_magnitude,
)
Expand All @@ -529,6 +570,4 @@ def downcast(x):
reject()

result = result.map(downcast)
if not allow_infinity:
result = result.filter(lambda x: not math.isinf(x))
return result
116 changes: 94 additions & 22 deletions hypothesis-python/tests/cover/test_filter_rewriting.py
Expand Up @@ -11,15 +11,18 @@
import decimal
import math
import operator
from fractions import Fraction
from functools import partial
from sys import float_info

import pytest

from hypothesis import given, strategies as st
from hypothesis.errors import HypothesisWarning, Unsatisfiable
from hypothesis.internal.floats import next_down, next_up
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies
from hypothesis.strategies._internal.numbers import IntegersStrategy
from hypothesis.strategies._internal.numbers import FloatStrategy, IntegersStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy

from tests.common.utils import fails_with
Expand Down Expand Up @@ -87,20 +90,81 @@ def test_filter_rewriting(data, strategy, predicate, start, end):


@pytest.mark.parametrize(
"s",
"strategy, predicate, min_value, max_value",
[
st.integers(1, 5).filter(partial(operator.lt, 6)),
st.integers(1, 5).filter(partial(operator.eq, 3.5)),
st.integers(1, 5).filter(partial(operator.eq, "can't compare to strings")),
st.integers(1, 5).filter(partial(operator.ge, 0)),
st.integers(1, 5).filter(partial(operator.lt, math.inf)),
st.integers(1, 5).filter(partial(operator.gt, -math.inf)),
# Floats with integer bounds
(st.floats(1, 5), partial(operator.lt, 3), next_up(3.0), 5), # 3 < x
(st.floats(1, 5), partial(operator.le, 3), 3, 5), # lambda x: 3 <= x
(st.floats(1, 5), partial(operator.eq, 3), 3, 3), # lambda x: 3 == x
(st.floats(1, 5), partial(operator.ge, 3), 1, 3), # lambda x: 3 >= x
(st.floats(1, 5), partial(operator.gt, 3), 1, next_down(3.0)), # 3 > x
# Floats with non-integer bounds
(st.floats(1, 5), partial(operator.lt, 3.5), next_up(3.5), 5),
(st.floats(1, 5), partial(operator.le, 3.5), 3.5, 5),
(st.floats(1, 5), partial(operator.ge, 3.5), 1, 3.5),
(st.floats(1, 5), partial(operator.gt, 3.5), 1, next_down(3.5)),
(st.floats(1, 5), partial(operator.lt, -math.inf), 1, 5),
(st.floats(1, 5), partial(operator.gt, math.inf), 1, 5),
# Floats with only one bound
(st.floats(min_value=1), partial(operator.lt, 3), next_up(3.0), math.inf),
(st.floats(min_value=1), partial(operator.le, 3), 3, math.inf),
(st.floats(max_value=5), partial(operator.ge, 3), -math.inf, 3),
(st.floats(max_value=5), partial(operator.gt, 3), -math.inf, next_down(3.0)),
# Unbounded floats
(st.floats(), partial(operator.lt, 3), next_up(3.0), math.inf),
(st.floats(), partial(operator.le, 3), 3, math.inf),
(st.floats(), partial(operator.eq, 3), 3, 3),
(st.floats(), partial(operator.ge, 3), -math.inf, 3),
(st.floats(), partial(operator.gt, 3), -math.inf, next_down(3.0)),
# Simple lambdas
(st.floats(), lambda x: x < 3, -math.inf, next_down(3.0)),
(st.floats(), lambda x: x <= 3, -math.inf, 3),
(st.floats(), lambda x: x == 3, 3, 3),
(st.floats(), lambda x: x >= 3, 3, math.inf),
(st.floats(), lambda x: x > 3, next_up(3.0), math.inf),
# Simple lambdas, reverse comparison
(st.floats(), lambda x: 3 > x, -math.inf, next_down(3.0)),
(st.floats(), lambda x: 3 >= x, -math.inf, 3),
(st.floats(), lambda x: 3 == x, 3, 3),
(st.floats(), lambda x: 3 <= x, 3, math.inf),
(st.floats(), lambda x: 3 < x, next_up(3.0), math.inf),
# More complicated lambdas
(st.floats(), lambda x: 0 < x < 5, next_up(0.0), next_down(5.0)),
(st.floats(), lambda x: 0 < x >= 1, 1, math.inf),
(st.floats(), lambda x: 1 > x <= 0, -math.inf, 0),
(st.floats(), lambda x: x > 0 and x > 0, next_up(0.0), math.inf),
(st.floats(), lambda x: x < 1 and x < 1, -math.inf, next_down(1.0)),
(st.floats(), lambda x: x > 1 and x > 0, next_up(1.0), math.inf),
(st.floats(), lambda x: x < 1 and x < 2, -math.inf, next_down(1.0)),
],
ids=get_pretty_function_description,
)
@fails_with(Unsatisfiable)
@given(data=st.data())
def test_rewrite_unsatisfiable_filter(data, s):
data.draw(s)
def test_filter_rewriting_floats(data, strategy, predicate, min_value, max_value):
s = strategy.filter(predicate)
assert isinstance(s, LazyStrategy)
assert isinstance(s.wrapped_strategy, FloatStrategy)
assert s.wrapped_strategy.min_value == min_value
assert s.wrapped_strategy.max_value == max_value
value = data.draw(s)
assert predicate(value)


@pytest.mark.parametrize(
"pred",
[
partial(operator.lt, 6),
partial(operator.eq, Fraction(10, 3)),
partial(operator.eq, "can't compare to strings"),
partial(operator.ge, 0),
partial(operator.lt, math.inf),
partial(operator.gt, -math.inf),
],
)
@pytest.mark.parametrize("s", [st.integers(1, 5), st.floats(1, 5)])
@fails_with(Unsatisfiable)
def test_rewrite_unsatisfiable_filter(s, pred):
s.filter(pred).example()


@given(st.integers(0, 2).filter(partial(operator.ne, 1)))
Expand All @@ -115,14 +179,8 @@ def test_rewriting_does_not_compare_decimal_snan():
s.example()


@pytest.mark.parametrize(
"strategy, lo, hi",
[
(st.integers(0, 1), -1, 2),
],
ids=repr,
)
def test_applying_noop_filter_returns_self(strategy, lo, hi):
@pytest.mark.parametrize("strategy", [st.integers(0, 1), st.floats(0, 1)], ids=repr)
def test_applying_noop_filter_returns_self(strategy):
s = strategy.wrapped_strategy
s2 = s.filter(partial(operator.le, -1)).filter(partial(operator.ge, 2))
assert s is s2
Expand All @@ -135,6 +193,7 @@ def mod2(x):
Y = 2**20


@pytest.mark.parametrize("s", [st.integers(1, 5), st.floats(1, 5)])
@given(
data=st.data(),
predicates=st.permutations(
Expand All @@ -149,9 +208,8 @@ def mod2(x):
]
),
)
def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
def test_rewrite_filter_chains_with_some_unhandled(data, predicates, s):
# Set up our strategy
s = st.integers(1, 5)
for p in predicates:
s = s.filter(p)

Expand All @@ -163,7 +221,7 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
# No matter the order of the filters, we get the same resulting structure
unwrapped = s.wrapped_strategy
assert isinstance(unwrapped, FilteredStrategy)
assert isinstance(unwrapped.filtered_strategy, IntegersStrategy)
assert isinstance(unwrapped.filtered_strategy, (IntegersStrategy, FloatStrategy))
for pred in unwrapped.flat_conditions:
assert pred is mod2 or pred.__name__ == "<lambda>"

Expand Down Expand Up @@ -246,3 +304,17 @@ def test_bumps_min_size_and_filters_for_content_str_methods(method):
fs = s.filter(method)
assert fs.filtered_strategy.min_size == 1
assert fs.flat_conditions == (method,)


@pytest.mark.parametrize(
"op, attr, value, expected",
[
(operator.lt, "min_value", -float_info.min / 2, 0),
(operator.lt, "min_value", float_info.min / 2, float_info.min),
(operator.gt, "max_value", float_info.min / 2, 0),
(operator.gt, "max_value", -float_info.min / 2, -float_info.min),
],
)
def test_filter_floats_can_skip_subnormals(op, attr, value, expected):
base = st.floats(allow_subnormal=False).filter(partial(op, value))
assert getattr(base.wrapped_strategy, attr) == expected