Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dask.array.fft mismatch with Numpy's interface (add support for norm argument) #10665

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 8 additions & 10 deletions dask/array/fft.py
Expand Up @@ -144,10 +144,8 @@ def fft_wrap(fft_func, kind=None, dtype=None):
>>> parallel_ifft = dff.fft_wrap(np.fft.ifft)
"""
if scipy is not None:
if fft_func is scipy.fftpack.rfft:
raise ValueError("SciPy's `rfft` doesn't match the NumPy API.")
elif fft_func is scipy.fftpack.irfft:
raise ValueError("SciPy's `irfft` doesn't match the NumPy API.")
if fft_func.__module__.startswith("scipy.fftpack"):
raise ValueError("SciPy's `fftpack` functions don't match the NumPy API.")

if kind is None:
kind = fft_func.__name__
Expand All @@ -156,7 +154,7 @@ def fft_wrap(fft_func, kind=None, dtype=None):
except KeyError:
raise ValueError("Given unknown `kind` %s." % kind)

def func(a, s=None, axes=None):
def func(a, s=None, axes=None, norm=None):
a = asarray(a)
if axes is None:
if kind.endswith("2"):
Expand All @@ -176,7 +174,7 @@ def func(a, s=None, axes=None):
if _dtype is None:
sample = np.ones(a.ndim * (8,), dtype=a.dtype)
try:
_dtype = fft_func(sample, axes=axes).dtype
_dtype = fft_func(sample, axes=axes, norm=norm).dtype
except TypeError:
_dtype = fft_func(sample).dtype

Expand All @@ -186,18 +184,18 @@ def func(a, s=None, axes=None):

chunks = out_chunk_fn(a, s, axes)

args = (s, axes)
args = (s, axes, norm)
if kind.endswith("fft"):
axis = None if axes is None else axes[0]
n = None if s is None else s[0]
args = (n, axis)
args = (n, axis, norm)

return a.map_blocks(fft_func, *args, dtype=_dtype, chunks=chunks)

if kind.endswith("fft"):
_func = func

def func(a, n=None, axis=None):
def func(a, n=None, axis=None, norm=None):
s = None
if n is not None:
s = (n,)
Expand All @@ -206,7 +204,7 @@ def func(a, n=None, axis=None):
if axis is not None:
axes = (axis,)

return _func(a, s, axes)
return _func(a, s, axes, norm)

func_mod = inspect.getmodule(fft_func)
func_name = fft_func.__name__
Expand Down
58 changes: 33 additions & 25 deletions dask/array/tests/test_fft.py
Expand Up @@ -67,10 +67,22 @@ def test_fft_n_kwarg(funcname):

assert_eq(da_fft(darr, 5), np_fft(nparr, 5))
assert_eq(da_fft(darr, 13), np_fft(nparr, 13))
assert_eq(da_fft(darr, 13, norm="backward"), np_fft(nparr, 13, norm="backward"))
assert_eq(da_fft(darr, 13, norm="ortho"), np_fft(nparr, 13, norm="ortho"))
assert_eq(da_fft(darr, 13, norm="forward"), np_fft(nparr, 13, norm="forward"))
assert_eq(da_fft(darr2, axis=0), np_fft(nparr, axis=0))
assert_eq(da_fft(darr2, 5, axis=0), np_fft(nparr, 5, axis=0))
assert_eq(da_fft(darr2, 13, axis=0), np_fft(nparr, 13, axis=0))
assert_eq(da_fft(darr2, 12, axis=0), np_fft(nparr, 12, axis=0))
assert_eq(
da_fft(darr2, 13, axis=0, norm="backward"),
np_fft(nparr, 13, axis=0, norm="backward"),
)
assert_eq(
da_fft(darr2, 12, axis=0, norm="ortho"), np_fft(nparr, 12, axis=0, norm="ortho")
)
assert_eq(
da_fft(darr2, 12, axis=0, norm="forward"),
np_fft(nparr, 12, axis=0, norm="forward"),
)


@pytest.mark.parametrize("funcname", all_1d_funcnames)
Expand Down Expand Up @@ -115,7 +127,7 @@ def test_nd_ffts_axes(funcname, dtype):
assert_eq(r, er)


@pytest.mark.parametrize("modname", ["numpy.fft", "scipy.fftpack"])
@pytest.mark.parametrize("modname", ["numpy.fft", "scipy.fft"])
@pytest.mark.parametrize("funcname", all_1d_funcnames)
@pytest.mark.parametrize("dtype", ["float32", "float64"])
def test_wrap_ffts(modname, funcname, dtype):
Expand All @@ -129,28 +141,24 @@ def test_wrap_ffts(modname, funcname, dtype):
darr2c = darr2.astype(dtype)
nparrc = nparr.astype(dtype)

if modname == "scipy.fftpack" and "rfft" in funcname:
with pytest.raises(ValueError):
fft_wrap(func)
else:
wfunc = fft_wrap(func)
assert wfunc(darrc).dtype == func(nparrc).dtype
assert wfunc(darrc).shape == func(nparrc).shape
assert_eq(wfunc(darrc), func(nparrc))
assert_eq(wfunc(darrc, axis=1), func(nparrc, axis=1))
assert_eq(wfunc(darr2c, axis=0), func(nparrc, axis=0))
assert_eq(wfunc(darrc, n=len(darrc) - 1), func(nparrc, n=len(darrc) - 1))
assert_eq(
wfunc(darrc, axis=1, n=darrc.shape[1] - 1),
func(nparrc, n=darrc.shape[1] - 1),
)
assert_eq(
wfunc(darr2c, axis=0, n=darr2c.shape[0] - 1),
func(nparrc, axis=0, n=darr2c.shape[0] - 1),
)


@pytest.mark.parametrize("modname", ["numpy.fft", "scipy.fftpack"])
wfunc = fft_wrap(func)
assert wfunc(darrc).dtype == func(nparrc).dtype
assert wfunc(darrc).shape == func(nparrc).shape
assert_eq(wfunc(darrc), func(nparrc))
assert_eq(wfunc(darrc, axis=1), func(nparrc, axis=1))
assert_eq(wfunc(darr2c, axis=0), func(nparrc, axis=0))
assert_eq(wfunc(darrc, n=len(darrc) - 1), func(nparrc, n=len(darrc) - 1))
assert_eq(
wfunc(darrc, axis=1, n=darrc.shape[1] - 1),
func(nparrc, n=darrc.shape[1] - 1),
)
assert_eq(
wfunc(darr2c, axis=0, n=darr2c.shape[0] - 1),
func(nparrc, axis=0, n=darr2c.shape[0] - 1),
)


@pytest.mark.parametrize("modname", ["numpy.fft", "scipy.fft"])
@pytest.mark.parametrize("funcname", all_nd_funcnames)
@pytest.mark.parametrize("dtype", ["float32", "float64"])
def test_wrap_fftns(modname, funcname, dtype):
Expand Down