Skip to content

Commit

Permalink
Merge pull request PyWavelets#723 from rgommers/fix-numpy2-support
Browse files Browse the repository at this point in the history
MAINT: use `numpy-config` and fix support for numpy 2.0
  • Loading branch information
rgommers committed Mar 13, 2024
2 parents 3c669e1 + 9ab7dac commit 2b5f587
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 52 deletions.
3 changes: 2 additions & 1 deletion pywt/_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
_check_dtype,
)
from ._functions import integrate_wavelet, scale2frequency
from ._utils import AxisError

__all__ = ["cwt"]

Expand Down Expand Up @@ -124,7 +125,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
raise ValueError("`scales` must only include positive values")

if not np.isscalar(axis):
raise np.AxisError("axis must be a scalar.")
raise AxisError("axis must be a scalar.")

dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
Expand Down
6 changes: 3 additions & 3 deletions pywt/_dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ._extensions._dwt import dwt_max_level as _dwt_max_level
from ._extensions._dwt import upcoef as _upcoef
from ._extensions._pywt import Modes, Wavelet, _check_dtype, wavelist
from ._utils import _as_wavelet
from ._utils import AxisError, _as_wavelet

__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
"dwt_coeff_len", "pad"]
Expand Down Expand Up @@ -176,7 +176,7 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

if data.ndim == 1:
cA, cD = dwt_single(data, wavelet, mode)
Expand Down Expand Up @@ -282,7 +282,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
if axis < 0:
axis = axis + ndim
if not 0 <= axis < ndim:
raise np.AxisError("Axis greater than coefficient dimensions")
raise AxisError("Axis greater than coefficient dimensions")

if ndim == 1:
rec = idwt_single(cA, cD, wavelet, mode)
Expand Down
63 changes: 31 additions & 32 deletions pywt/_extensions/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,44 @@ if m_dep.found()
add_project_link_arguments('-lm', language : 'c')
endif

# For cross-compilation it is often not possible to run the Python interpreter
# in order to retrieve numpy's include directory. It can be specified in the
# cross file instead:
# [properties]
# numpy-include-dir = /abspath/to/host-pythons/site-packages/numpy/core/include
#
# This uses the path as is, and avoids running the interpreter.
incdir_numpy = meson.get_external_property('numpy-include-dir', 'not-given')
if incdir_numpy == 'not-given'
incdir_numpy = run_command(py,
[
'-c',
'''import os
# Don't use the deprecated NumPy C API. Define this to a fixed version instead of
# NPY_API_VERSION in order not to break compilation for released PyWavelets
# versions when NumPy introduces a new deprecation.
numpy_nodepr_api = ['-DNPY_NO_DEPRECATED_API=NPY_1_22_API_VERSION']

# Uses the `numpy-config` executable (or a user's numpy.pc pkg-config file),
# will work for numpy>=2.0.0b1 and meson>=1.4.0
_numpy_dep = dependency('numpy', required: false)
if _numpy_dep.found()
np_dep = declare_dependency(dependencies: _numpy_dep, compile_args: numpy_nodepr_api)
else
# For cross-compilation it is often not possible to run the Python interpreter
# in order to retrieve numpy's include directory. It can be specified in the
# cross file instead:
# [properties]
# numpy-include-dir = /abspath/to/host-pythons/site-packages/numpy/core/include
#
# This uses the path as is, and avoids running the interpreter.
incdir_numpy = meson.get_external_property('numpy-include-dir', 'not-given')
if incdir_numpy == 'not-given'
incdir_numpy = run_command(py,
[
'-c',
'''import os
import numpy as np
try:
incdir = os.path.relpath(np.get_include())
except Exception:
incdir = np.get_include()
print(incdir)
'''
],
check: true
).stdout().strip()

# We do need an absolute path to feed to `cc.find_library` below
_incdir_numpy_abs = run_command(py,
['-c', 'import os; os.chdir(".."); import numpy; print(numpy.get_include())'],
check: true
).stdout().strip()
else
_incdir_numpy_abs = incdir_numpy
'''
],
check: true
).stdout().strip()
endif
inc_np = include_directories(incdir_numpy)
np_dep = declare_dependency(include_directories: inc_np, compile_args: numpy_nodepr_api)
endif
inc_np = include_directories(incdir_numpy)

# Don't use the deprecated NumPy C API. Define this to a fixed version instead of
# NPY_API_VERSION in order not to break compilation for released PyWavelets
# versions when NumPy introduces a new deprecation.
numpy_nodepr_api = ['-DNPY_NO_DEPRECATED_API=NPY_1_22_API_VERSION']
np_dep = declare_dependency(include_directories: inc_np, compile_args: numpy_nodepr_api)

config_pxi = configure_file(
input: 'config.pxi.in',
Expand Down
4 changes: 2 additions & 2 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ._c99_config import _have_c99_complex
from ._extensions._dwt import dwt_axis, idwt_axis
from ._utils import _modes_per_axis, _wavelets_per_axis
from ._utils import AxisError, _modes_per_axis, _wavelets_per_axis

__all__ = ['dwt2', 'idwt2', 'dwtn', 'idwtn']

Expand Down Expand Up @@ -288,7 +288,7 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
for key_length, (axis, wav, mode) in reversed(
list(enumerate(zip(axes, wavelets, modes)))):
if axis < 0 or axis >= ndim:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

new_coeffs = {}
new_keys = [''.join(coef) for coef in product('ad', repeat=key_length)]
Expand Down
14 changes: 7 additions & 7 deletions pywt/_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._extensions._dwt import dwt_max_level
from ._extensions._pywt import Modes, Wavelet
from ._multidim import _fix_coeffs, dwt2, dwtn, idwt2, idwtn
from ._utils import _as_wavelet, _modes_per_axis, _wavelets_per_axis
from ._utils import AxisError, _as_wavelet, _modes_per_axis, _wavelets_per_axis

__all__ = ['wavedec', 'waverec', 'wavedec2', 'waverec2', 'wavedecn',
'waverecn', 'coeffs_to_array', 'array_to_coeffs', 'ravel_coeffs',
Expand Down Expand Up @@ -93,7 +93,7 @@ def wavedec(data, wavelet, mode='symmetric', level=None, axis=-1):
try:
axes_shape = data.shape[axis]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
level = _check_level(axes_shape, wavelet.dec_len, level)

coeffs_list = []
Expand Down Expand Up @@ -170,7 +170,7 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
elif a.shape[axis] != d.shape[axis]:
raise ValueError("coefficient shape mismatch")
except IndexError:
raise np.AxisError("Axis greater than coefficient dimensions")
raise AxisError("Axis greater than coefficient dimensions")
a = idwt(a, d, wavelet, mode, axis)

return a
Expand Down Expand Up @@ -233,7 +233,7 @@ def wavedec2(data, wavelet, mode='symmetric', level=None, axes=(-2, -1)):
try:
axes_sizes = [data.shape[ax] for ax in axes]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

wavelets = _wavelets_per_axis(wavelet, axes)
dec_lengths = [w.dec_len for w in wavelets]
Expand Down Expand Up @@ -352,7 +352,7 @@ def _prep_axes_wavedecn(shape, axes):
try:
axes_shapes = [shape[ax] for ax in axes]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
ndim_transform = len(axes)
return axes, axes_shapes, ndim_transform

Expand Down Expand Up @@ -1194,11 +1194,11 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'):
def _check_fswavedecn_axes(data, axes):
"""Axes checks common to fswavedecn, fswaverecn."""
if len(axes) != len(set(axes)):
raise np.AxisError("The axes passed to fswavedecn must be unique.")
raise AxisError("The axes passed to fswavedecn must be unique.")
try:
[data.shape[ax] for ax in axes]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")


class FswavedecnResult:
Expand Down
8 changes: 4 additions & 4 deletions pywt/_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._extensions._swt import swt_axis as _swt_axis
from ._extensions._swt import swt_max_level
from ._multidim import idwt2, idwtn
from ._utils import _as_wavelet, _wavelets_per_axis
from ._utils import AxisError, _as_wavelet, _wavelets_per_axis

__all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn']

Expand Down Expand Up @@ -141,7 +141,7 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1,
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

if level is None:
level = swt_max_level(data.shape[axis])
Expand Down Expand Up @@ -196,7 +196,7 @@ def iswt(coeffs, wavelet, norm=False, axis=-1):
coeffs_nd = [{'a': a, 'd': d} for a, d in coeffs]
return iswtn(coeffs_nd, wavelet, axes=(axis,), norm=norm)
elif axis != 0 and axis != -1:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
if not _have_c99_complex and np.iscomplexobj(cA):
if trim_approx:
coeffs_real = [c.real for c in coeffs]
Expand Down Expand Up @@ -639,7 +639,7 @@ def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False,
axes = range(data.ndim)
axes = [a + data.ndim if a < 0 else a for a in axes]
if any(a < 0 or a >= data.ndim for a in axes):
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
if len(axes) != len(set(axes)):
raise ValueError("The axes passed to swtn must be unique.")
num_axes = len(axes)
Expand Down
6 changes: 6 additions & 0 deletions pywt/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
Wavelet,
)

AxisError: type[Exception]
if np.lib.NumpyVersion(np.__version__) >= '1.25.0':
from numpy.exceptions import AxisError
else:
from numpy import AxisError


def _as_wavelet(wavelet):
"""Convert wavelet name to a Wavelet object."""
Expand Down
7 changes: 4 additions & 3 deletions pywt/tests/test_mra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pywt
from pywt import data
from pywt._utils import AxisError

# tolerances used in accuracy comparisons
tol_single = 1e-6
Expand Down Expand Up @@ -84,7 +85,7 @@ def test_mra_axis(transform, ndim, axis, dtype):

# out of range axis
if axis < -x.ndim or axis >= x.ndim:
with pytest.raises(np.AxisError):
with pytest.raises(AxisError):
pywt.mra(x, 'db1', transform=transform, axis=axis)
return

Expand Down Expand Up @@ -160,7 +161,7 @@ def test_mra2_axes(transform, axes, ndim, dtype):

# out of range axis
if any(axis < -x.ndim or axis >= x.ndim for axis in axes):
with pytest.raises(np.AxisError):
with pytest.raises(AxisError):
pywt.mra2(x, 'db1', transform=transform, axes=axes)
return

Expand Down Expand Up @@ -246,7 +247,7 @@ def test_mran_axes(axes, transform):

# out of range axis
if any(axis < -x.ndim or axis >= x.ndim for axis in axes):
with pytest.raises(np.AxisError):
with pytest.raises(AxisError):
pywt.mran(x, 'db1', transform='dwtn', axes=axes)
return

Expand Down

0 comments on commit 2b5f587

Please sign in to comment.