Skip to content

Commit

Permalink
Merge pull request #32 from honno/creation-refactor
Browse files Browse the repository at this point in the history
Refactor assertions in `test_creation.py`
  • Loading branch information
asmeurer committed Oct 29, 2021
2 parents 797537e + ca2ef81 commit 035e3f3
Show file tree
Hide file tree
Showing 8 changed files with 678 additions and 365 deletions.
48 changes: 34 additions & 14 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import itertools
from functools import reduce
from operator import mul
from math import sqrt
import itertools
from typing import Tuple, Optional, List
from operator import mul
from typing import Any, List, NamedTuple, Optional, Tuple

from hypothesis import assume
from hypothesis.strategies import (lists, integers, sampled_from,
shared, floats, just, composite, one_of,
none, booleans, SearchStrategy)
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
integers, just, lists, none, one_of,
sampled_from, shared)

from .pytest_helpers import nargs
from .array_helpers import ndindex
from .typing import DataType, Shape
from . import dtype_helpers as dh
from ._array_module import (full, float32, float64, bool as bool_dtype,
_UndefinedStub, eye, broadcast_to)
from . import _array_module as xp
from . import dtype_helpers as dh
from . import xps

from ._array_module import _UndefinedStub
from ._array_module import bool as bool_dtype
from ._array_module import broadcast_to, eye, float32, float64, full
from .array_helpers import ndindex
from .function_stubs import elementwise_functions

from .pytest_helpers import nargs
from .typing import DataType, Shape

# Set this to True to not fail tests just because a dtype isn't implemented.
# If no compatible dtype is implemented for a given test, the test will fail
Expand Down Expand Up @@ -382,3 +381,24 @@ def test_f(x, kw):
if draw(booleans()):
result[k] = draw(strat)
return result


class KVD(NamedTuple):
keyword: str
value: Any
default: Any


@composite
def specified_kwargs(draw, *keys_values_defaults: KVD):
"""Generates valid kwargs given expected defaults.
When we can't realistically use hh.kwargs() and thus test whether xp infact
defaults correctly, this strategy lets us remove generated arguments if they
are of the default value anyway.
"""
kw = {}
for keyword, value, default in keys_values_defaults:
if value is not default or draw(booleans()):
kw[keyword] = value
return kw
32 changes: 32 additions & 0 deletions array_api_tests/meta/test_hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hypothesis import given, strategies as st, settings

from .. import _array_module as xp
from .. import xps
from .._array_module import _UndefinedStub
from .. import array_helpers as ah
from .. import dtype_helpers as dh
Expand Down Expand Up @@ -76,6 +77,37 @@ def run(kw):
assert len(c_results) > 0
assert all(isinstance(kw["c"], str) for kw in c_results)


def test_specified_kwargs():
results = []

@given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data())
@settings(max_examples=100)
def run(n, d, data):
kw = data.draw(
hh.specified_kwargs(
hh.KVD("n", n, 0),
hh.KVD("d", d, None),
),
label="kw",
)
results.append(kw)
run()

assert all(isinstance(kw, dict) for kw in results)

assert any(len(kw) == 0 for kw in results)

assert any("n" not in kw.keys() for kw in results)
assert any("n" in kw.keys() and kw["n"] == 0 for kw in results)
assert any("n" in kw.keys() and kw["n"] != 0 for kw in results)

assert any("d" not in kw.keys() for kw in results)
assert any("d" in kw.keys() and kw["d"] is None for kw in results)
assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)



@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
finite=st.shared(st.booleans(), key='finite')),
dtype=hh.shared_floating_dtypes,
Expand Down
21 changes: 19 additions & 2 deletions array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
import pytest

from ..test_signatures import extension_module
from ..test_creation_functions import frange


def test_extension_module_is_extension():
assert extension_module('linalg')
assert extension_module("linalg")


def test_extension_func_is_not_extension():
assert not extension_module('linalg.cross')
assert not extension_module("linalg.cross")


@pytest.mark.parametrize(
"r, size, elements",
[
(frange(0, 1, 1), 1, [0]),
(frange(1, 0, -1), 1, [1]),
(frange(0, 1, -1), 0, []),
(frange(0, 1, 2), 1, [0]),
],
)
def test_frange(r, size, elements):
assert len(r) == size
assert list(r) == elements
84 changes: 79 additions & 5 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
import math
from inspect import getfullargspec
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

from . import array_helpers as ah
from . import dtype_helpers as dh
from . import function_stubs
from .typing import DataType
from .typing import Array, DataType, Scalar, Shape

__all__ = [
"raises",
"doesnt_raise",
"nargs",
"fmt_kw",
"assert_dtype",
"assert_kw_dtype",
"assert_default_float",
"assert_default_int",
"assert_shape",
"assert_fill",
]

def raises(exceptions, function, message=''):

def raises(exceptions, function, message=""):
"""
Like pytest.raises() except it allows custom error messages
"""
Expand All @@ -16,11 +31,14 @@ def raises(exceptions, function, message=''):
return
except Exception as e:
if message:
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions}): {message}")
raise AssertionError(
f"Unexpected exception {e!r} (expected {exceptions}): {message}"
)
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions})")
raise AssertionError(message)

def doesnt_raise(function, message=''):

def doesnt_raise(function, message=""):
"""
The inverse of raises().
Expand All @@ -36,10 +54,15 @@ def doesnt_raise(function, message=''):
raise AssertionError(f"Unexpected exception {e!r}: {message}")
raise AssertionError(f"Unexpected exception {e!r}")


def nargs(func_name):
return len(getfullargspec(getattr(function_stubs, func_name)).args)


def fmt_kw(kw: Dict[str, Any]) -> str:
return ", ".join(f"{k}={v}" for k, v in kw.items())


def assert_dtype(
func_name: str,
in_dtypes: Tuple[DataType, ...],
Expand All @@ -60,3 +83,54 @@ def assert_dtype(
assert out_dtype == expected, msg


def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
f_kw_dtype = dh.dtype_to_name[kw_dtype]
f_out_dtype = dh.dtype_to_name[out_dtype]
msg = (
f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} "
f"[{func_name}(dtype={f_kw_dtype})]"
)
assert out_dtype == kw_dtype, msg


def assert_default_float(func_name: str, dtype: DataType):
f_dtype = dh.dtype_to_name[dtype]
f_default = dh.dtype_to_name[dh.default_float]
msg = (
f"out.dtype={f_dtype}, should be default "
f"floating-point dtype {f_default} [{func_name}()]"
)
assert dtype == dh.default_float, msg


def assert_default_int(func_name: str, dtype: DataType):
f_dtype = dh.dtype_to_name[dtype]
f_default = dh.dtype_to_name[dh.default_int]
msg = (
f"out.dtype={f_dtype}, should be default "
f"integer dtype {f_default} [{func_name}()]"
)
assert dtype == dh.default_int, msg


def assert_shape(
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
):
if isinstance(out_shape, int):
out_shape = (out_shape,)
if isinstance(expected, int):
expected = (expected,)
msg = (
f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
)
assert out_shape == expected, msg


def assert_fill(
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
):
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
if math.isnan(fill_value):
assert ah.all(ah.isnan(out)), msg
else:
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg

0 comments on commit 035e3f3

Please sign in to comment.