From beb434ac89c9f132158281eb6c8400e21cf9072f Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Mon, 7 Oct 2019 00:10:32 +1100 Subject: [PATCH] just(x) as a sampled_from strategy While kinda weird on the face of it, this makes generating unique collections a *lot* more efficient in certain unusual cases, and the overhead is roughly nil. --- hypothesis-python/RELEASE.rst | 6 ++ .../src/hypothesis/_strategies.py | 8 +-- .../src/hypothesis/searchstrategy/misc.py | 57 ++++++++----------- hypothesis-python/src/hypothesis/stateful.py | 8 +-- .../tests/cover/test_simple_collections.py | 5 +- .../tests/numpy/test_argument_validation.py | 2 +- 6 files changed, 36 insertions(+), 50 deletions(-) create mode 100644 hypothesis-python/RELEASE.rst diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..0377b7b460 --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,6 @@ +RELEASE_TYPE: patch + +This patch improves the performance of unique collections such as +:func:`~hypothesis.strategies.sets` of :func:`~hypothesis.strategies.just` +or :func:`~hypothesis.strategies.boolean` strategies. They were already +pretty good though, so you're unlikely to notice much! diff --git a/hypothesis-python/src/hypothesis/_strategies.py b/hypothesis-python/src/hypothesis/_strategies.py index 0c1338af8b..6c075a08ff 100644 --- a/hypothesis-python/src/hypothesis/_strategies.py +++ b/hypothesis-python/src/hypothesis/_strategies.py @@ -96,11 +96,7 @@ from hypothesis.searchstrategy.deferred import DeferredStrategy from hypothesis.searchstrategy.functions import FunctionStrategy from hypothesis.searchstrategy.lazy import LazyStrategy -from hypothesis.searchstrategy.misc import ( - BoolStrategy, - JustStrategy, - SampledFromStrategy, -) +from hypothesis.searchstrategy.misc import JustStrategy, SampledFromStrategy from hypothesis.searchstrategy.numbers import ( BoundedIntStrategy, FixedBoundedFloatStrategy, @@ -425,7 +421,7 @@ def booleans(): Examples from this strategy will shrink towards False (i.e. shrinking will try to replace True with False where possible). """ - return BoolStrategy() + return sampled_from([False, True]) @cacheable diff --git a/hypothesis-python/src/hypothesis/searchstrategy/misc.py b/hypothesis-python/src/hypothesis/searchstrategy/misc.py index 668acdbd8b..fdd3fd0cc9 100644 --- a/hypothesis-python/src/hypothesis/searchstrategy/misc.py +++ b/hypothesis-python/src/hypothesis/searchstrategy/misc.py @@ -22,20 +22,6 @@ from hypothesis.searchstrategy.strategies import SearchStrategy, filter_not_satisfied -class BoolStrategy(SearchStrategy): - """A strategy that produces Booleans with a Bernoulli conditional - distribution.""" - - def __repr__(self): - return "BoolStrategy()" - - def calc_has_reusable_values(self, recur): - return True - - def do_draw(self, data): - return d.boolean(data) - - def is_simple_data(value): try: hash(value) @@ -44,26 +30,6 @@ def is_simple_data(value): return False -class JustStrategy(SearchStrategy): - """A strategy which always returns a single fixed value.""" - - def __init__(self, value): - SearchStrategy.__init__(self) - self.value = value - - def __repr__(self): - return "just(%r)" % (self.value,) - - def calc_has_reusable_values(self, recur): - return True - - def calc_is_cacheable(self, recur): - return is_simple_data(self.value) - - def do_draw(self, data): - return self.value - - class SampledFromStrategy(SearchStrategy): """A strategy which samples from a set of elements. This is essentially equivalent to using a OneOfStrategy over Just strategies but may be more @@ -154,3 +120,26 @@ def check_index(i): # If there are no allowed indices, the filter couldn't be satisfied. return filter_not_satisfied + + +class JustStrategy(SampledFromStrategy): + """A strategy which always returns a single fixed value. + + It's implemented as a length-one SampledFromStrategy so that all our + special-case logic for filtering and sets applies also to just(x). + """ + + def __init__(self, value): + SampledFromStrategy.__init__(self, [value]) + + def __repr__(self): + return "just(%r)" % (self.elements[0],) + + def calc_has_reusable_values(self, recur): + return True + + def calc_is_cacheable(self, recur): + return is_simple_data(self.elements[0]) + + def do_draw(self, data): + return self.elements[0] diff --git a/hypothesis-python/src/hypothesis/stateful.py b/hypothesis-python/src/hypothesis/stateful.py index 91f9c3b224..aceb9b7f65 100644 --- a/hypothesis-python/src/hypothesis/stateful.py +++ b/hypothesis-python/src/hypothesis/stateful.py @@ -42,7 +42,7 @@ ) from hypothesis.control import current_build_context from hypothesis.core import given -from hypothesis.errors import HypothesisException, InvalidArgument, InvalidDefinition +from hypothesis.errors import InvalidArgument, InvalidDefinition from hypothesis.internal.compat import quiet_raise, string_types from hypothesis.internal.reflection import function_digest, nicerepr, proxies from hypothesis.internal.validation import check_type @@ -606,12 +606,10 @@ def __init__(self, machine): ) def do_draw(self, data): - try: - rule = data.draw(st.sampled_from(self.rules).filter(self.is_valid)) - except HypothesisException: - # FailedHealthCheck or UnsatisfiedAssumption depending on user settings. + if not any(self.is_valid(rule) for rule in self.rules): msg = u"No progress can be made from state %r" % (self.machine,) quiet_raise(InvalidDefinition(msg)) + rule = data.draw(st.sampled_from(self.rules).filter(self.is_valid)) return (rule, data.draw(rule.arguments_strategy)) def is_valid(self, rule): diff --git a/hypothesis-python/tests/cover/test_simple_collections.py b/hypothesis-python/tests/cover/test_simple_collections.py index 7bc8f2192c..0bd7a4bdf0 100644 --- a/hypothesis-python/tests/cover/test_simple_collections.py +++ b/hypothesis-python/tests/cover/test_simple_collections.py @@ -155,10 +155,7 @@ def test_can_draw_empty_set_from_unsatisfiable_strategy(): assert find_any(sets(integers().filter(lambda s: False))) == set() -small_set = sets(none()) - - -@given(lists(small_set, min_size=10)) +@given(lists(sets(none()), min_size=10)) def test_small_sized_sets(x): pass diff --git a/hypothesis-python/tests/numpy/test_argument_validation.py b/hypothesis-python/tests/numpy/test_argument_validation.py index 439aba5224..0a0c592261 100644 --- a/hypothesis-python/tests/numpy/test_argument_validation.py +++ b/hypothesis-python/tests/numpy/test_argument_validation.py @@ -139,7 +139,7 @@ def test_bad_dtype_strategy(capsys, data): s = bad_dtype_strategy() with pytest.raises(ValueError): data.draw(s) - val = s.wrapped_strategy.mapped_strategy.value + val = s.wrapped_strategy.mapped_strategy.elements[0] assert capsys.readouterr().out.startswith( "Got invalid dtype value=%r from strategy=just(%r), function=" % (val, val) )