-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: ENH: fft: support array API #2
Conversation
494f07f
to
ea0075e
Compare
adds test utility for checking that namespaces match Co-authored-by: Tyler Reddy <tyler.je.reddy@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @lucascolley 🚀 This is moving good.
I am thinking about the real PR you are going to open. I anticipate that the main critic you will get will be around the refactoring (move/rename) part. Ideally, the move part would be done in a single commit, the first one. This way we can tell maintainers to review this commit separately from the rest (in GitHub you can filter the commits you are looking at and here it will greatly help.)
I am likely missing a reason and/or previous discussion here, but: why are these large code moves / file renames necessary? It seems to me like it's splitting out the
To verify the docstring behavior:
Also a minor comment on code style: try to not do |
Great, this now looks much easier to read and review, and I like that the |
d050836
to
5f45410
Compare
@lucascolley in every # The functions in this file decorated with `@_dispatch` will only be called
# from the correspondingly named public functions with NumPy arrays or
# array-likes, not with CuPy, PyTorch and other array API standard supporting
# objects. See the docstrings in `_backend.py` for more details and examples. Let's leave out the longer comment in |
Closing now as the full PR has been opened on the main repo: scipy#19005 |
Reference issue
RFC for array API support in SciPy: scipy#18286
Tracker for array API support in SciPy: scipy#18867
Merged PR adding machinery and covering
scipy.cluster
: scipy#18668(This PR follows on from issue tupui#26)
What does this implement/fix?
This PR extends the array API support of scipy#18668 to also cover
scipy.fft
.fft
is an array API standard extension module and has matching APIs in both CuPy and PyTorch. The goal is to fully coverscipy.fft
to conform to the array API standard, as specified in the array API standard documentation:For more context on the standard and the goal of SciPy adoption, please see the RFC.
For explanation of the machinery of which this PR makes use, please see the
cluster
PR.Changes:
uarray
dispatch structure of_basic.py
has been moved to a new file,_basic_uarray.py
, which is used for NumPy arrays and array-likes. All input arrays are now validated byarray_namespace
, so all valid input arrays either support the array API standard or are converted to NumPy arrays. The functions of_basic.py
fall into two categories in their handling of non-NumPy arrays:fft
module. If it does, then that is used. If not then a conversion to NumPy is attempted in order to useuarray
. If successful, the result is converted back to the same namespace as the input array. For an example, seefft.fft
.uarray
. If successful, the result is converted back to the same namespace as the input array. For an example, seefft.fft2
. This means that GPU support is restricted to functions which are part of the standard for now, as otherwise they are hitting compiled code made for NumPy arrays. This design decision is explained in the RFC:RFC excerpt (expandable)
_realtransforms.py
and_realtransforms_uarray.py
, save for the fact that there are no standard functions in these files.fftfreq
,rfftfreq
,fftshift
andifftshift
were previously just imported fromnumpy
. They now have implementations in_helper.py
, and follow the pattern for standard functions._fftlog.py
has been moved to a new file,_fftlog_np.py
. Itsuarray
dispatch structure remains in the file_fftlog_multimethods.py
. The functions now follow the pattern for non-standard functions.fftfreq
etc. have been copied over from NumPy totest_helper.py
.TestNamespaces
has been added totest_basic.py
andtest_helper.py
to check that output arrays are of the same namespace as input arrays.test_non_standard_params
has been added totest_basic.py
to check that exceptions are raised for unsupported parameters.conftest.py
:set_assert_allclose(xp)
, which returns anassert_allclose
equivalent for the namespacexp
(currently supportscupy
andtorch
), to allow testing on GPU._array_api.py
:is_numpy(xp)
which checks whetherxp.__name__
is"scipy._lib.array_api_compat.array_api_compat.numpy"
, and_assert_matching_namespace
by Tyler Reddy.Additional information
To-do:
scipy/scipy/fft/tests/test_fftlog.py
Lines 118 to 122 in f5bc71a
Difficulties:
rtol
are appropriate, that would be greatly appreciated.test_basic.py::test_multiprocess
is currently skipped for array API since it hangs for any array API library (evennumpy.array_api
). Generally, multithreading tests (including those using the unsupportedworkers
keyword intest_multithreading.py
are skipped apart fromTestFFTThreadSafe
intest_basic.py
.@skip_if_array_api_backend
and@pytest.mark.parametrize
don't seem to work together, raising exceptions for the backends which aren't skipped, no matter which decorator comes first. If anyone could get this working, that would be great!@skip_if_array_api_gpu
currently skips on every backend whenpython dev.py test -b all
is used, due to a fault in its logic. This is not critical since the tests can be run separately for each backend, but this is unintended behaviour which should be fixed.Implementation Details:
fftfreq
andrfftfreq
have changed to(n, d=1.0, *, xp=None, device=None)
. The addition ofdevice
is to match the signature in the standard. The addition ofxp
is to specify the namespace for the output array. This is required since these are "array-generating functions" which output an array without taking one as input. Since array-api-compat andnumpy.array_api
are yet to implementfft
, these functions and their tests currently require workarounds/pytest.xfail
's. A lot of cleanup will be possible at a later date.scipy/scipy/fft/_helper.py
Lines 147 to 152 in 2cb6268
scipy/scipy/fft/tests/test_helper.py
Lines 444 to 449 in a097d91
Perhaps a warning would be good to state that a NumPy array is returned unless the
xp
parameter is provided?test_basic.py::test_fft_with_order
is skipped since orders are not supported in the array API. Likewise,test_identity_1d_overwrite
andtest_identity_nd_overwrite
intest_real_transforms.py
are skipped since theoverwrite_x
keyword is not supported.test_basic.py::TestFFT1D
andtest_helper.py::TestFFTShift
are failing ontorch
backend due to differences in the function signatures, but they will pass once array-api-compat has implementedtorch.fft
. Currently either@skip_if_array_api_backend('torch')
orpytest.xfail
is used.set_assert_allclose
instead of just converting to NumPy and usingassert_allclose
: to_device() -- any way to force back to host "portably?" data-apis/array-api#626. Additional dtype casts are needed fortorch.assert_close
e.g. intest_float32
(a part of the previoustest_dtypes
) and intest_basic.py::TestFFTThreadSafe
.Follow-up:
scipy/scipy/fft/tests/test_basic.py
Lines 583 to 586 in 5f45410
fftfreq
andrfftfreq
.device
keyword offftfreq
andrfftfreq
are not currently tested since the keyword only works with PyTorch currently. But this should probably be tested once more array libraries support it._fftlog_multimethods.py
to_fftlog_uarray.py
. Currently omitted to keep the diff clean.fft2
could be rewritten in terms of the standard functions in order to avoid conversion to NumPy, thus allowing them to run on GPU.Many thanks to Pamphile, Irwin and Ralf who have given really helpful guidance in the making of this PR!
References:
numpy.array_api
: https://numpy.org/devdocs/reference/array_api.html