Skip to content

Commit

Permalink
More flexible size assertions for int arrays in test_arange
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Oct 28, 2021
1 parent 8a7e873 commit abafee1
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,25 +217,32 @@ def test_arange(dtype, data):
assert out.dtype == dtype
assert out.ndim == 1, f"{out.ndim=}, but should be 1 [linspace()]"
f_func = f"[linspace({start=}, {stop=}, {step=})]"
# We check size is roughly as expected to avoid edge cases e.g.
#
# >>> xp.arange(2, step=0.333333333333333)
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0]
# >>> xp.arange(2, step=0.3333333333333333)
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
#
# >>> start, stop, step = 0, 108086391056891901, 1080863910568919
# >>> x = xp.arange(start, stop, step, dtype=xp.uint64)
# >>> x.size
# 100
# >>> r = range(start, stop, step)
# >>> len(r)
# 101
#
min_size = math.floor(size * 0.9)
max_size = max(math.ceil(size * 1.1), 1)
assert (
min_size <= out.size <= max_size
), f"{out.size=}, but should be roughly {size} {f_func}"
if dh.is_int_dtype(_dtype):
assert out.size == size, f"{out.size=}, but should be {size} {f_func}"
else:
# We check size is roughly as expected to avoid edge cases e.g.
#
# >>> xp.arange(2, step=0.333333333333333)
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0]
# >>> xp.arange(2, step=0.3333333333333333)
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
#
min_size = math.floor(size * 0.9)
max_size = max(math.ceil(size * 1.1), 1)
assert (
min_size <= out.size <= max_size
), f"{out.size=}, but should be roughly {size} {f_func}"
assume(out.size == size)
if dh.is_int_dtype(_dtype):
ah.assert_exactly_equal(out, ah.asarray(list(r), dtype=_dtype))
elements = list(r)
assume(out.size == len(elements))
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
else:
assume(out.size == size)
if out.size > 0:
assert ah.equal(
out[0], ah.asarray(_start, dtype=out.dtype)
Expand Down

0 comments on commit abafee1

Please sign in to comment.