Skip to content

Commit

Permalink
add advanced integer indexing strategy and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rsokl authored and Zac-HD committed Jul 4, 2019
1 parent a943d18 commit 20e5c43
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 7 deletions.
6 changes: 6 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,6 @@
RELEASE_TYPE: patch

This release adds the strategy :func:`~hypothesis.extra.numpy.integer_array_indices`,
which generates tuples of Numpy arrays that can be used for
`advanced indexing <http://www.pythonlikeyoumeanit.com/Module3_IntroducingNumpy/AdvancedIndexing.html#Integer-Array-Indexing>`_
to select an array of a specified shape.
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/_strategies.py
Expand Up @@ -1318,8 +1318,8 @@ def everything_except(excluded_types):
For example, ``everything_except(int)`` returns a strategy that can
generate anything that ``from_type()`` can ever generate, except for
instances of :class:python:int, and excluding instances of types
added via :func:~hypothesis.strategies.register_type_strategy.
instances of :class:`python:int`, and excluding instances of types
added via :func:`~hypothesis.strategies.register_type_strategy`.
This is useful when writing tests which check that invalid input is
rejected in a certain way.
Expand Down
75 changes: 71 additions & 4 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Expand Up @@ -37,6 +37,8 @@
from typing import Any, Union, Sequence, Tuple, Optional # noqa
from hypothesis.searchstrategy.strategies import T # noqa

Shape = Tuple[int, ...] # noqa

TIME_RESOLUTIONS = tuple("Y M D h m s ms us ns ps fs as".split())


Expand Down Expand Up @@ -295,7 +297,7 @@ def fill_for(elements, unique, fill, name=""):
@st.defines_strategy
def arrays(
dtype, # type: Any
shape, # type: Union[int, Sequence[int], st.SearchStrategy[Sequence[int]]]
shape, # type: Union[int, Shape, st.SearchStrategy[Shape]]
elements=None, # type: st.SearchStrategy[Any]
fill=None, # type: st.SearchStrategy[Any]
unique=False, # type: bool
Expand Down Expand Up @@ -401,7 +403,7 @@ def arrays(

@st.defines_strategy
def array_shapes(min_dims=1, max_dims=None, min_side=1, max_side=None):
# type: (int, int, int, int) -> st.SearchStrategy[Tuple[int, ...]]
# type: (int, int, int, int) -> st.SearchStrategy[Shape]
"""Return a strategy for array shapes (tuples of int >= 1)."""
check_type(integer_types, min_dims, "min_dims")
check_type(integer_types, min_side, "min_side")
Expand Down Expand Up @@ -672,7 +674,7 @@ def nested_dtypes(

@st.defines_strategy
def valid_tuple_axes(ndim, min_size=0, max_size=None):
# type: (int, int, int) -> st.SearchStrategy[Tuple[int, ...]]
# type: (int, int, int) -> st.SearchStrategy[Shape]
"""Return a strategy for generating permissible tuple-values for the
``axis`` argument for a numpy sequential function (e.g.
:func:`numpy:numpy.sum`), given an array of the specified
Expand Down Expand Up @@ -763,7 +765,7 @@ def do_draw(self, data):

@st.defines_strategy
def broadcastable_shapes(shape, min_dims=0, max_dims=None, min_side=1, max_side=None):
# type: (Sequence[int], int, Optional[int], int, Optional[int]) -> st.SearchStrategy[Tuple[int, ...]]
# type: (Shape, int, int, int, int) -> st.SearchStrategy[Shape]
"""Return a strategy for generating shapes that are broadcast-compatible
with the provided shape.
Expand Down Expand Up @@ -846,3 +848,68 @@ def broadcastable_shapes(shape, min_dims=0, max_dims=None, min_side=1, max_side=
min_side=min_side,
max_side=max_side,
)


@st.defines_strategy
def integer_array_indices(shape, result_shape=array_shapes(), dtype="int"):
# type: (Shape, SearchStrategy[Shape], np.dtype) -> st.SearchStrategy[Tuple[np.ndarray, ...]]
"""Return a search strategy for tuples of integer-arrays that, when used
to index into an array of shape ``shape``, given an array whose shape
was drawn from ``result_shape``.
Examples from this strategy shrink towards the tuple of index-arrays::
len(shape) * (np.zeros(drawn_result_shape, dtype), )
* ``shape`` a tuple of integers that indicates the shape of the array,
whose indices are being generated.
* ``result_shape`` a strategy for generating tuples of integers, which
describe the shape of the resulting index arrays. The default is
:func:`~hypothesis.extra.numpy.array_shapes`. The shape drawn from
this strategy determines the shape of the array that will be produced
when the corresponding example from ``integer_array_indices`` is used
as an index.
* ``dtype`` the integer data type of the generated index-arrays. Negative
integer indices can be generated if a signed integer type is specified.
Recall that an array can be indexed using a tuple of integer-arrays to
access its members in an arbitrary order, producing an array with an
arbitrary shape. For example:
.. code-block:: pycon
>>> from numpy import array
>>> x = array([-0, -1, -2, -3, -4])
>>> ind = (array([[4, 0], [0, 1]]),) # a tuple containing a 2D integer-array
>>> x[ind] # the resulting array is commensurate with the indexing array(s)
array([[-4, 0],
[0, -1]])
Note that this strategy does not accommodate all variations of so-called
'advanced indexing', as prescribed by NumPy's nomenclature. Combinations
of basic and advanced indexes are too complex to usefully define in a
standard strategy; we leave application-specific strategies to the user.
Advanced-boolean indexing can be defined as ``arrays(shape=..., dtype=bool)``,
and is similarly left to the user.
"""
check_type(tuple, shape, "shape")
check_argument(
shape and all(isinstance(x, integer_types) and x > 0 for x in shape),
"shape=%r must be a non-empty tuple of integers > 0" % (shape,),
)
check_type(SearchStrategy, result_shape, "result_shape")
check_argument(
np.issubdtype(dtype, np.integer), "dtype=%r must be an integer dtype" % (dtype,)
)
signed = np.issubdtype(dtype, np.signedinteger)

def array_for(index_shape, size):
return arrays(
dtype=dtype,
shape=index_shape,
elements=st.integers(-size if signed else 0, size - 1),
)

return result_shape.flatmap(
lambda index_shape: st.tuples(*[array_for(index_shape, size) for size in shape])
)
5 changes: 5 additions & 0 deletions hypothesis-python/tests/numpy/test_argument_validation.py
Expand Up @@ -116,6 +116,11 @@ def e(a, **kwargs):
min_side=2,
max_side=3,
),
e(nps.integer_array_indices, shape=()),
e(nps.integer_array_indices, shape=(2, 0)),
e(nps.integer_array_indices, shape="a"),
e(nps.integer_array_indices, shape=(2,), result_shape=(2, 2)),
e(nps.integer_array_indices, shape=(2,), dtype=float),
],
)
def test_raise_invalid_argument(function, kwargs):
Expand Down
98 changes: 98 additions & 0 deletions hypothesis-python/tests/numpy/test_gen_data.py
Expand Up @@ -702,3 +702,101 @@ def test_broadcastable_shape_can_generate_arbitrary_ndims(shape, max_dims, data)
lambda x: len(x) == desired_ndim,
settings(max_examples=10 ** 6),
)


@settings(deadline=None)
@given(
shape=nps.array_shapes(min_dims=1, min_side=1),
dtype=st.one_of(nps.unsigned_integer_dtypes(), nps.integer_dtypes()),
data=st.data(),
)
def test_advanced_integer_index_is_valid_with_default_result_shape(shape, dtype, data):
index = data.draw(nps.integer_array_indices(shape, dtype=dtype))
x = np.zeros(shape)
out = x[index] # raises if the index is invalid
assert not np.shares_memory(x, out) # advanced indexing should not return a view
assert all(dtype == x.dtype for x in index)


@settings(deadline=None)
@given(
shape=nps.array_shapes(min_dims=1, min_side=1),
min_dims=st.integers(0, 3),
min_side=st.integers(0, 3),
dtype=st.one_of(nps.unsigned_integer_dtypes(), nps.integer_dtypes()),
data=st.data(),
)
def test_advanced_integer_index_is_valid_and_satisfies_bounds(
shape, min_dims, min_side, dtype, data
):
max_side = data.draw(st.integers(min_side, min_side + 2), label="max_side")
max_dims = data.draw(st.integers(min_dims, min_dims + 2), label="max_dims")
index = data.draw(
nps.integer_array_indices(
shape,
result_shape=nps.array_shapes(
min_dims=min_dims,
max_dims=max_dims,
min_side=min_side,
max_side=max_side,
),
dtype=dtype,
)
)
x = np.zeros(shape)
out = x[index] # raises if the index is invalid
assert all(min_side <= s <= max_side for s in out.shape)
assert min_dims <= out.ndim <= max_dims
assert not np.shares_memory(x, out) # advanced indexing should not return a view
assert all(dtype == x.dtype for x in index)


@settings(deadline=None)
@given(
shape=nps.array_shapes(min_dims=1, min_side=1),
min_dims=st.integers(0, 3),
min_side=st.integers(0, 3),
dtype=st.sampled_from(["uint8", "int8"]),
data=st.data(),
)
def test_advanced_integer_index_minimizes_as_documented(
shape, min_dims, min_side, dtype, data
):
max_side = data.draw(st.integers(min_side, min_side + 2), label="max_side")
max_dims = data.draw(st.integers(min_dims, min_dims + 2), label="max_dims")
result_shape = nps.array_shapes(
min_dims=min_dims, max_dims=max_dims, min_side=min_side, max_side=max_side
)
smallest = minimal(
nps.integer_array_indices(shape, result_shape=result_shape, dtype=dtype)
)
desired = len(shape) * (np.zeros(min_dims * [min_side]),)
assert len(smallest) == len(desired)
for s, d in zip(smallest, desired):
np.testing.assert_array_equal(s, d)


@settings(deadline=None, max_examples=10)
@given(
shape=nps.array_shapes(min_dims=1, max_dims=2, min_side=1, max_side=3),
data=st.data(),
)
def test_advanced_integer_index_can_generate_any_pattern(shape, data):
# ensures that generated index-arrays can be used to yield any pattern of elements from an array
x = np.arange(np.product(shape)).reshape(shape)

target = data.draw(
nps.arrays(
shape=nps.array_shapes(min_dims=1, max_dims=2, min_side=1, max_side=2),
elements=st.sampled_from(x.flatten()),
dtype=x.dtype,
),
label="target",
)
find_any(
nps.integer_array_indices(
shape, result_shape=st.just(target.shape), dtype=np.dtype("int8")
),
lambda index: np.all(target == x[index]),
settings(max_examples=10 ** 6),
)
2 changes: 1 addition & 1 deletion tooling/src/hypothesistooling/projects/hypothesispython.py
Expand Up @@ -193,7 +193,7 @@ def upload_distribution():
entries = [i for i, l in enumerate(lines) if CHANGELOG_HEADER.match(l)]
changelog_body = "".join(lines[entries[0] + 2 : entries[1]]).strip() + (
"\n\n*[The canonical version of these notes (with links) is on readthedocs.]"
"(https://hypothesis.readthedocs.io/en/latest/changes.html#v%s).*"
"(https://hypothesis.readthedocs.io/en/latest/changes.html#v%s)*"
% (current_version().replace(".", "-"),)
)

Expand Down

0 comments on commit 20e5c43

Please sign in to comment.