Skip to content

Commit

Permalink
Merge pull request #2031 from HypothesisWorks/DRMacIver/unique-sample…
Browse files Browse the repository at this point in the history
…d-from

Improve performance of unique lists with `elements=sampled_from(...)`
  • Loading branch information
DRMacIver committed Jul 4, 2019
2 parents a943d18 + f52b324 commit 589deb4
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 10 deletions.
8 changes: 8 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,8 @@
RELEASE_TYPE: minor

This release significantly improves the performance of drawing unique collections whose
elements are drawn from :func:`~hypothesis.strategies.sampled_from` strategies.

As a side effect, this detects an error condition that would previously have
passed silently: When the ``min_size`` argument on a collection with distinct elements
is greater than the number of elements being sampled, this will now raise an error.
21 changes: 21 additions & 0 deletions hypothesis-python/src/hypothesis/_strategies.py
Expand Up @@ -87,6 +87,7 @@
ListStrategy,
TupleStrategy,
UniqueListStrategy,
UniqueSampledListStrategy,
)
from hypothesis.searchstrategy.datetime import (
DateStrategy,
Expand Down Expand Up @@ -775,6 +776,26 @@ def unique_by(x):
for i, f in enumerate(unique_by):
if not callable(f):
raise InvalidArgument("unique_by[%i]=%r is not a callable" % (i, f))
# Note that lazy strategies automatically unwrap when passed to a defines_strategy
# function.
if isinstance(elements, SampledFromStrategy):
element_count = len(elements.elements)
if min_size > element_count:
raise InvalidArgument(
"Cannot create a collection of min_size=%r unique elements with "
"values drawn from only %d distinct elements"
% (min_size, element_count)
)

if max_size is not None:
max_size = min(max_size, element_count)
else:
max_size = element_count

return UniqueSampledListStrategy(
elements=elements, max_size=max_size, min_size=min_size, keys=unique_by
)

return UniqueListStrategy(
elements=elements, max_size=max_size, min_size=min_size, keys=unique_by
)
Expand Down
48 changes: 48 additions & 0 deletions hypothesis-python/src/hypothesis/internal/conjecture/junkdrawer.py
Expand Up @@ -173,3 +173,51 @@ def binary_search(lo, hi, f):
def uniform(random, n):
"""Returns an hbytes of length n, distributed uniformly at random."""
return int_to_bytes(random.getrandbits(n * 8), n)


class LazySequenceCopy(object):
"""A "copy" of a sequence that works by inserting a mask in front
of the underlying sequence, so that you can mutate it without changing
the underlying sequence. Effectively behaves as if you could do list(x)
in O(1) time. The full list API is not supported yet but there's no reason
in principle it couldn't be."""

def __init__(self, values):
self.__values = values
self.__len = len(values)
self.__mask = None

def __len__(self):
return self.__len

def pop(self):
if len(self) == 0:
raise IndexError("Cannot pop from empty list")
result = self[-1]
self.__len -= 1
if self.__mask is not None:
self.__mask.pop(self.__len, None)
return result

def __getitem__(self, i):
i = self.__check_index(i)
default = self.__values[i]
if self.__mask is None:
return default
else:
return self.__mask.get(i, default)

def __setitem__(self, i, v):
i = self.__check_index(i)
if self.__mask is None:
self.__mask = {}
self.__mask[i] = v

def __check_index(self, i):
n = len(self)
if i < -n or i >= n:
raise IndexError("Index %d out of range [0, %d)" % (i, n))
if i < 0:
i += n
assert 0 <= i < n
return i
35 changes: 35 additions & 0 deletions hypothesis-python/src/hypothesis/searchstrategy/collections.py
Expand Up @@ -20,6 +20,7 @@
import hypothesis.internal.conjecture.utils as cu
from hypothesis.errors import InvalidArgument
from hypothesis.internal.compat import OrderedDict
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy
from hypothesis.internal.conjecture.utils import combine_labels
from hypothesis.searchstrategy.strategies import (
MappedSearchStrategy,
Expand Down Expand Up @@ -169,6 +170,40 @@ def do_draw(self, data):
return result


class UniqueSampledListStrategy(ListStrategy):
def __init__(self, elements, min_size, max_size, keys):
super(UniqueSampledListStrategy, self).__init__(elements, min_size, max_size)
self.keys = keys

def do_draw(self, data):
should_draw = cu.many(
data,
min_size=self.min_size,
max_size=self.max_size,
average_size=self.average_size,
)
seen_sets = tuple(set() for _ in self.keys)
result = []

remaining = LazySequenceCopy(self.element_strategy.elements)

while should_draw.more():
i = len(remaining) - 1
j = cu.integer_range(data, 0, i)
if j != i:
remaining[i], remaining[j] = remaining[j], remaining[i]
value = remaining.pop()

if all(key(value) not in seen for (key, seen) in zip(self.keys, seen_sets)):
for key, seen in zip(self.keys, seen_sets):
seen.add(key(value))
result.append(value)
else:
should_draw.reject()
assert self.max_size >= len(result) >= self.min_size
return result


class FixedKeysDictStrategy(MappedSearchStrategy):
"""A strategy which produces dicts with a fixed set of keys, given a
strategy for each of their equivalent values.
Expand Down
16 changes: 8 additions & 8 deletions hypothesis-python/src/hypothesis/searchstrategy/lazy.py
Expand Up @@ -80,7 +80,7 @@ def __init__(self, function, args, kwargs):
SearchStrategy.__init__(self)
self.__wrapped_strategy = None
self.__representation = None
self.__function = function
self.function = function
self.__args = args
self.__kwargs = kwargs

Expand Down Expand Up @@ -109,11 +109,11 @@ def wrapped_strategy(self):
k: unwrap_strategies(v) for k, v in self.__kwargs.items()
}

base = self.__function(*self.__args, **self.__kwargs)
base = self.function(*self.__args, **self.__kwargs)
if unwrapped_args == self.__args and unwrapped_kwargs == self.__kwargs:
self.__wrapped_strategy = base
else:
self.__wrapped_strategy = self.__function(
self.__wrapped_strategy = self.function(
*unwrapped_args, **unwrapped_kwargs
)
return self.__wrapped_strategy
Expand All @@ -127,7 +127,7 @@ def __repr__(self):
if self.__representation is None:
_args = self.__args
_kwargs = self.__kwargs
argspec = getfullargspec(self.__function)
argspec = getfullargspec(self.function)
defaults = dict(argspec.kwonlydefaults or {})
if argspec.defaults is not None:
for name, value in zip(
Expand All @@ -136,19 +136,19 @@ def __repr__(self):
defaults[name] = value
if len(argspec.args) > 1 or argspec.defaults:
_args, _kwargs = convert_positional_arguments(
self.__function, _args, _kwargs
self.function, _args, _kwargs
)
else:
_args, _kwargs = convert_keyword_arguments(
self.__function, _args, _kwargs
self.function, _args, _kwargs
)
kwargs_for_repr = dict(_kwargs)
for k, v in defaults.items():
if k in kwargs_for_repr and kwargs_for_repr[k] is defaults[k]:
del kwargs_for_repr[k]
self.__representation = "%s(%s)" % (
self.__function.__name__,
arg_string(self.__function, _args, kwargs_for_repr, reorder=False),
self.function.__name__,
arg_string(self.function, _args, kwargs_for_repr, reorder=False),
)
return self.__representation

Expand Down
56 changes: 56 additions & 0 deletions hypothesis-python/tests/cover/test_conjecture_junkdrawer.py
@@ -0,0 +1,56 @@
# coding=utf-8
#
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Most of this work is copyright (C) 2013-2019 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# CONTRIBUTING.rst for a full list of people who may hold copyright, and
# consult the git log if you need to determine who owns an individual
# contribution.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
#
# END HEADER

from __future__ import absolute_import, division, print_function

import pytest

from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy


def test_out_of_range():
x = LazySequenceCopy([1, 2, 3])

with pytest.raises(IndexError):
x[3]

with pytest.raises(IndexError):
x[-4]


def test_pass_through():
x = LazySequenceCopy([1, 2, 3])
assert x[0] == 1
assert x[1] == 2
assert x[2] == 3


def test_can_assign_without_changing_underlying():
underlying = [1, 2, 3]
x = LazySequenceCopy(underlying)
x[1] = 10
assert x[1] == 10
assert underlying[1] == 2


def test_pop():
x = LazySequenceCopy([2, 3])
assert x.pop() == 3
assert x.pop() == 2

with pytest.raises(IndexError):
x.pop()
1 change: 1 addition & 0 deletions hypothesis-python/tests/cover/test_direct_strategies.py
Expand Up @@ -119,6 +119,7 @@ def fn_ktest(*fnkwargs):
(ds.lists, {"elements": ds.integers(), "unique_by": 1}),
(ds.lists, {"elements": ds.integers(), "unique_by": ()}),
(ds.lists, {"elements": ds.integers(), "unique_by": (1,)}),
(ds.lists, {"elements": ds.sampled_from([0, 1]), "min_size": 3, "unique": True}),
(ds.text, {"min_size": 10, "max_size": 9}),
(ds.text, {"alphabet": [1]}),
(ds.text, {"alphabet": ["abc"]}),
Expand Down
16 changes: 16 additions & 0 deletions hypothesis-python/tests/cover/test_sampled_from.py
Expand Up @@ -20,6 +20,7 @@
import collections
import enum

import hypothesis.strategies as st
from hypothesis import given
from hypothesis.errors import FailedHealthCheck, InvalidArgument, Unsatisfiable
from hypothesis.internal.compat import hrange
Expand Down Expand Up @@ -76,3 +77,18 @@ def test_easy_filtered_sampling():
@given(sampled_from(hrange(100)).filter(lambda x: x == 99))
def test_filtered_sampling_finds_rare_value(x):
assert x == 99


@given(st.sets(st.sampled_from(range(50)), min_size=50))
def test_efficient_sets_of_samples(x):
assert x == set(range(50))


@given(st.lists(st.sampled_from([0] * 100), unique=True))
def test_does_not_include_duplicates_even_when_duplicated_in_collection(ls):
assert len(ls) <= 1


@given(st.lists(st.sampled_from(hrange(100)), max_size=3, unique=True))
def test_max_size_is_respected_with_unique_sampled_from(ls):
assert len(ls) <= 3
4 changes: 2 additions & 2 deletions hypothesis-python/tests/nocover/test_sampled_from.py
Expand Up @@ -21,7 +21,7 @@

import hypothesis.strategies as st
from hypothesis import given
from hypothesis.errors import FailedHealthCheck
from hypothesis.errors import InvalidArgument
from hypothesis.internal.compat import hrange
from tests.common.utils import counts_calls, fails_with

Expand Down Expand Up @@ -64,7 +64,7 @@ def test_chained_filters_find_rare_value(x):
assert x == 80


@fails_with(FailedHealthCheck)
@fails_with(InvalidArgument)
@given(st.sets(st.sampled_from(range(10)), min_size=11))
def test_unsat_sets_of_samples(x):
assert False
Expand Down

0 comments on commit 589deb4

Please sign in to comment.