Skip to content

Commit

Permalink
Support np.apply_over_axes.
Browse files Browse the repository at this point in the history
  • Loading branch information
mhvk committed Jun 21, 2019
1 parent b4f9ba0 commit 2eb5412
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
42 changes: 35 additions & 7 deletions astropy/units/quantity_helper/function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
in, which will be filled in-place.
For the DISPATCHED_FUNCTIONS `dict`, the value is a function that
implements the numpy functionality for Quantity input. It should return
a tuple of ``result, unit, out``, where ``result`` is a plain array
with the result, and ``unit`` and ``out`` are as above.
implements the numpy functionality for Quantity input. It should
return a tuple of ``result, unit, out``, where ``result`` is generally
a plain array with the result, and ``unit`` and ``out`` are as above.
If unit is `None`, result gets returned directly, so one can also
return a Quantity directly using ``quantity_result, None, None``.
"""

import functools
Expand Down Expand Up @@ -111,10 +114,6 @@
np.isclose, np.allclose,
np.array2string, np.array_repr, np.array_str}

# TODO: could be supported but need work & thought.
UNSUPPORTED_FUNCTIONS |= {
np.apply_over_axes}

# Nonsensical for quantities.
UNSUPPORTED_FUNCTIONS |= {
np.packbits, np.unpackbits, np.unravel_index,
Expand Down Expand Up @@ -801,3 +800,32 @@ def setcheckop(ar1, ar2, *args, **kwargs):
# a1 to that of a2.
(ar2, ar1), unit = _quantities2arrays(ar2, ar1)
return (ar1, ar2) + args, kwargs, None, None


@dispatched_function
def apply_over_axes(func, a, axes):
# Copied straight from numpy/lib/shape_base, just to omit its
# val = asarray(a); if only it had been asanyarray, or just not there
# since a is assumed to an an array in the next line...
# Which is what we do here - we can only get here if it is a Quantity.
val = a
N = a.ndim
if np.array(axes).ndim == 0:
axes = (axes,)
for axis in axes:
if axis < 0:
axis = N + axis
args = (val, axis)
res = func(*args)
if res.ndim == val.ndim:
val = res
else:
res = np.expand_dims(res, axis)
if res.ndim == val.ndim:
val = res
else:
raise ValueError("function is not returning "
"an array of the correct shape")
# Returning unit is None to signal nothing should happen to
# the output.
return val, None, None
16 changes: 13 additions & 3 deletions astropy/units/tests/test_quantity_non_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,18 @@ def test_apply_along_axis(self, axis):
self.q.value) * self.q.unit ** 2
assert_array_equal(out, expected)

@pytest.mark.xfail(NO_ARRAY_FUNCTION,
reason="Needs __array_function__ support")
@pytest.mark.parametrize('axes', ((1,), (0,), (0, 1)))
def test_apply_over_axes(self, axes):
def function(x, axis):
return np.sum(np.square(x), axis)

out = np.apply_over_axes(function, self.q, axes)
expected = np.apply_over_axes(function, self.q.value, axes)
expected = expected * self.q.unit ** (2 * len(axes))
assert_array_equal(out, expected)


class TestIndicesFrom(NoUnitTestSetup):
def test_diag_indices_from(self):
Expand Down Expand Up @@ -1734,9 +1746,7 @@ def test_is_busday(self):
reason="no __array_function__ wrapping in numpy<1.17")
def test_testing_completeness():
assert not CoverageMeta.covered.intersection(untested_functions)
assert all_wrapped == (CoverageMeta.covered |
should_be_tested_functions |
untested_functions)
assert all_wrapped == (CoverageMeta.covered | untested_functions)


class TestFunctionHelpersCompleteness:
Expand Down

0 comments on commit 2eb5412

Please sign in to comment.