Skip to content

Commit

Permalink
specified_kwargs() strategy to test default kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Oct 27, 2021
1 parent 7fe9c96 commit 1b30535
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Union
from typing import Union, Any, Tuple
from itertools import takewhile, count

from hypothesis import assume, given, strategies as st
Expand All @@ -13,6 +13,21 @@
from .typing import Shape, DataType, Array, Scalar


@st.composite
def specified_kwargs(draw, *keys_values_defaults: Tuple[str, Any, Any]):
"""Generates valid kwargs given expected defaults.
When we can't realistically use hh.kwargs() and thus test whether xp infact
defaults correctly, this strategy lets us remove generated arguments if they
are of the default value anyway.
"""
kw = {}
for key, value, default in keys_values_defaults:
if value is not default or draw(st.booleans()):
kw[key] = value
return kw


def assert_default_float(func_name: str, dtype: DataType):
f_dtype = dh.dtype_to_name[dtype]
f_default = dh.dtype_to_name[dh.default_float]
Expand Down Expand Up @@ -168,7 +183,15 @@ def test_arange(dtype, data):
size <= hh.MAX_ARRAY_SIZE
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check

out = xp.arange(start, stop=stop, step=step, dtype=dtype)
kw = data.draw(
specified_kwargs(
("stop", stop, None),
("step", step, None),
("dtype", dtype, None),
),
label="kw",
)
out = xp.arange(start, **kw)

if dtype is None:
if all_int:
Expand Down Expand Up @@ -356,15 +379,22 @@ def test_linspace(num, dtype, endpoint, data):
m, M = dh.dtype_ranges[_dtype]
stop = data.draw(int_stops(start, min_gap, m, M), label="stop")

out = xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
kw = data.draw(
specified_kwargs(
("dtype", dtype, None),
("endpoint", endpoint, True),
),
label="kw",
)
out = xp.linspace(start, stop, num, **kw)

assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)

if endpoint:
if num > 1:
assert ah.equal(
out[-1], ah.asarray(stop, dtype=out.dtype)
), f"out[-1]={out[-1]}, but should be {stop=} [linspace()]"
), f"out[-1]={out[-1]}, but should be {stop=} [linspace({start=}, {num=})]"
else:
# linspace(..., num, endpoint=True) should return an array equivalent to
# the first num elements when endpoint=False
Expand All @@ -375,8 +405,9 @@ def test_linspace(num, dtype, endpoint, data):
if num > 0:
assert ah.equal(
out[0], ah.asarray(start, dtype=out.dtype)
), f"out[0]={out[0]}, but should be {start=} [linspace()]"
# TODO: array assertions ala test_arange
), f"out[0]={out[0]}, but should be {start=} [linspace({stop=}, {num=})]"

# TODO: array assertions ala test_arange


def make_one(dtype: DataType) -> Scalar:
Expand Down

0 comments on commit 1b30535

Please sign in to comment.