Skip to content

Commit

Permalink
Merge pull request #233 from honno/fft-fixes
Browse files Browse the repository at this point in the history
FFT fixes
  • Loading branch information
honno committed Feb 16, 2024
2 parents d0d9696 + e039ffb commit ebed2d6
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 78 deletions.
169 changes: 91 additions & 78 deletions array_api_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,7 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -
if axes is None:
s_strat = st.none() | s_strat
s = data.draw(s_strat, label="s")
if size_gt_1:
_s = x.shape if s is None else s
for i in range(x.ndim):
if i in _axes:
side = _s[_axes.index(i)]
else:
side = x.shape[i]
assume(side > 1)

norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
kwargs = data.draw(
hh.specified_kwargs(
Expand All @@ -86,14 +79,14 @@ def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -
return s, axes, norm, kwargs


def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType):
def assert_float_to_complex_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType
):
if in_dtype == xp.float32:
expected = xp.complex64
elif in_dtype == xp.float64:
expected = xp.complex128
else:
assert dh.is_float_dtype(in_dtype) # sanity check
expected = in_dtype
assert in_dtype == xp.float64 # sanity check
expected = xp.complex128
ph.assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
)
Expand All @@ -106,14 +99,10 @@ def assert_n_axis_shape(
n: Optional[int],
axis: int,
out: Array,
size_gt_1: bool = False,
):
_axis = len(x.shape) - 1 if axis == -1 else axis
if n is None:
if size_gt_1:
axis_side = 2 * (x.shape[_axis] - 1)
else:
axis_side = x.shape[_axis]
axis_side = x.shape[_axis]
else:
axis_side = n
expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
Expand All @@ -127,7 +116,6 @@ def assert_s_axes_shape(
s: Optional[List[int]],
axes: Optional[List[int]],
out: Array,
size_gt_1: bool = False,
):
_axes = sh.normalise_axis(axes, x.ndim)
_s = x.shape if s is None else s
Expand All @@ -138,88 +126,78 @@ def assert_s_axes_shape(
else:
side = x.shape[i]
expected.append(side)
if size_gt_1:
last_axis = _axes[-1]
expected[last_axis] = 2 * (expected[last_axis] - 1)
assume(expected[last_axis] > 0) # TODO: generate valid examples
ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))


@given(
x=hh.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_fft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.fft(x, **kwargs)

assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)


@given(
x=hh.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_ifft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.ifft(x, **kwargs)

assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)


@given(
x=hh.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_fftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.fftn(x, **kwargs)

assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)


@given(
x=hh.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_ifftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.ifftn(x, **kwargs)

assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)


@given(
x=hh.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_rfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.rfft(x, **kwargs)

assert_fft_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("rfft", x=x, n=n, axis=axis, out=out)
assert_float_to_complex_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype)

_axis = x.ndim - 1 if axis == -1 else axis
if n is None:
axis_side = x.shape[_axis] // 2 + 1
else:
axis_side = n // 2 + 1
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
ph.assert_shape("rfft", out_shape=out.shape, expected=expected_shape)


@given(
x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_irfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)

out = xp.fft.irfft(x, **kwargs)

assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_dtype(
"irfft",
in_dtype=x.dtype,
out_dtype=out.dtype,
expected=dh.dtype_components[x.dtype],
)

_axis = x.ndim - 1 if axis == -1 else axis
if n is None:
Expand All @@ -230,17 +208,25 @@ def test_irfft(x, data):
ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)


@given(
x=hh.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_rfftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.rfftn(x, **kwargs)

assert_fft_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out)
assert_float_to_complex_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)

_axes = sh.normalise_axis(axes, x.ndim)
_s = x.shape if s is None else s
expected = []
for i in range(x.ndim):
if i in _axes:
side = _s[_axes.index(i)]
else:
side = x.shape[i]
expected.append(side)
expected[_axes[-1]] = _s[-1] // 2 + 1
ph.assert_shape("rfftn", out_shape=out.shape, expected=tuple(expected))


@given(
Expand All @@ -250,24 +236,44 @@ def test_rfftn(x, data):
data=st.data(),
)
def test_irfftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data, size_gt_1=True)
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.irfftn(x, **kwargs)

assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out, size_gt_1=True)

ph.assert_dtype(
"irfftn",
in_dtype=x.dtype,
out_dtype=out.dtype,
expected=dh.dtype_components[x.dtype],
)

@given(
x=hh.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
# TODO: assert shape correctly
# _axes = sh.normalise_axis(axes, x.ndim)
# _s = x.shape if s is None else s
# expected = []
# for i in range(x.ndim):
# if i in _axes:
# side = _s[_axes.index(i)]
# else:
# side = x.shape[i]
# expected.append(side)
# last_axis = max(_axes)
# expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
# ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))


@given(x=hh.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_hfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)

out = xp.fft.hfft(x, **kwargs)

assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_dtype(
"hfft",
in_dtype=x.dtype,
out_dtype=out.dtype,
expected=dh.dtype_components[x.dtype],
)

_axis = x.ndim - 1 if axis == -1 else axis
if n is None:
Expand All @@ -278,20 +284,24 @@ def test_hfft(x, data):
ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)


@given(
x=hh.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
@given(x=hh.arrays(dtype=xps.floating_dtypes(), shape=fft_shapes_strat), data=st.data())
def test_ihfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.ihfft(x, **kwargs)

assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True)
assert_float_to_complex_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)

_axis = x.ndim - 1 if axis == -1 else axis
if n is None:
axis_side = x.shape[_axis] // 2 + 1
else:
axis_side = n // 2 + 1
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape)


@given( n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
def test_fftfreq(n, kw):
out = xp.fft.fftfreq(n, **kw)
ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})
Expand All @@ -300,15 +310,18 @@ def test_fftfreq(n, kw):
@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
def test_rfftfreq(n, kw):
out = xp.fft.rfftfreq(n, **kw)
ph.assert_shape("rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n})
ph.assert_shape(
"rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n}
)


@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat), data=st.data())
def test_shift_func(func_name, x, data):
func = getattr(xp.fft, func_name)
axes = data.draw(
st.none() | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
st.none()
| st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
label="axes",
)
out = func(x, axes=axes)
Expand Down
1 change: 1 addition & 0 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def test_sum(x, data):
ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)


@pytest.mark.skip(reason="flaky") # TODO: fix!
@given(
x=hh.arrays(
dtype=xps.floating_dtypes(),
Expand Down

0 comments on commit ebed2d6

Please sign in to comment.