Skip to content

Commit

Permalink
Added ufunc_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
izaid committed Apr 20, 2024
1 parent d7ea1cf commit 3bc2cf1
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 23 deletions.
27 changes: 11 additions & 16 deletions scipy/special/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
psi, hankel1, hankel2, yv, kv, poch, binom,
_stirling2_inexact)
from ._special_ufuncs import lpn as _lpn
from ._gufuncs import (lpn_all as _lpn_all, _lpmn, _clpmn, _lqn, _lqmn, _rctj, _rcty,
from ._gufuncs import (lpn_all as _lpn_all, lpn_all_until_jac as _lpn_all_until_jac, _lpmn, _clpmn, _lqn, _lqmn, _rctj, _rcty,
_sph_harm_all as _sph_harm_all_gufunc)
from . import _specfun
from ._comb import _comb_int
from scipy._lib.deprecation import _NoValue, _deprecate_positional_args

from ._ufunc_wrapper import ufunc_wrapper

__all__ = [
'ai_zeros',
Expand Down Expand Up @@ -2044,8 +2044,10 @@ def euler(n):
n1 = n
return _specfun.eulerb(n1)[:(n+1)]

lpn_all = ufunc_wrapper(_lpn_all, until_diffs = (_lpn_all_until_jac,))

def lpn_all(n, z):
@lpn_all.resolve_out_shapes
def _(n, shape):
"""Legendre function of the first kind.
Compute sequence of Legendre functions of the first kind (polynomials),
Expand All @@ -2060,25 +2062,18 @@ def lpn_all(n, z):
https://people.sc.fsu.edu/~jburkardt/f77_src/special_functions/special_functions.html
"""
n = _nonneg_int_or_fail(n, 'n', strict=False)

z = np.asarray(z)
if (not np.issubdtype(z.dtype, np.inexact)):
z = z.astype(np.float64)
n = _nonneg_int_or_fail(n, 'n', strict=False)

pn = np.empty((n + 1,) + z.shape, dtype=z.dtype)
pd = np.empty_like(pn)
if (z.ndim == 0):
_lpn_all(z, out = (pn, pd))
else:
_lpn_all(z, out = (np.moveaxis(pn, 0, -1),
np.moveaxis(pd, 0, -1))) # new axes must be last for the ufunc
return (n + 1,) + shape[0]

return pn, pd
@lpn_all.as_ufunc_out
def _(out):
return np.moveaxis(out, 0, -1)

def lpn(n, z, legacy = True):
if legacy:
return lpn_all(n, z)
return lpn_all.until_jac(n, z)

return _lpn(n, z)

Expand Down
18 changes: 15 additions & 3 deletions scipy/special/_gufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

using namespace std;

using func_f_f1_t = void (*)(float, mdspan<float, dextents<ptrdiff_t, 1>, layout_stride>);
using func_d_d1_t = void (*)(double, mdspan<double, dextents<ptrdiff_t, 1>, layout_stride>);
using func_F_F1_t = void (*)(complex<float>, mdspan<complex<float>, dextents<ptrdiff_t, 1>, layout_stride>);
using func_D_D1_t = void (*)(complex<double>, mdspan<complex<double>, dextents<ptrdiff_t, 1>, layout_stride>);

using func_f_f1f1_t =
void (*)(float, mdspan<float, dextents<ptrdiff_t, 1>, layout_stride>, mdspan<float, dextents<ptrdiff_t, 1>, layout_stride>);
using func_d_d1d1_t =
Expand Down Expand Up @@ -74,12 +79,19 @@ PyMODINIT_FUNC PyInit__gufuncs() {
}

PyObject *lpn_all = SpecFun_NewGUFunc(
{static_cast<func_f_f1f1_t>(::lpn_all), static_cast<func_d_d1d1_t>(::lpn_all),
static_cast<func_F_F1F1_t>(::lpn_all), static_cast<func_D_D1D1_t>(::lpn_all)},
2, "lpn_all", lpn_all_doc, "()->(np1),(np1)"
{static_cast<func_f_f1_t>(::lpn_all), static_cast<func_d_d1_t>(::lpn_all), static_cast<func_F_F1_t>(::lpn_all),
static_cast<func_D_D1_t>(::lpn_all)},
1, "lpn_all", lpn_all_doc, "()->(np1)"
);
PyModule_AddObjectRef(_gufuncs, "lpn_all", lpn_all);

PyObject *lpn_all_until_jac = SpecFun_NewGUFunc(
{static_cast<func_f_f1f1_t>(::lpn_all_until_jac), static_cast<func_d_d1d1_t>(::lpn_all_until_jac),
static_cast<func_F_F1F1_t>(::lpn_all_until_jac), static_cast<func_D_D1D1_t>(::lpn_all_until_jac)},
2, "lpn_all_until_jac", lpn_all_doc, "()->(np1),(np1)"
);
PyModule_AddObjectRef(_gufuncs, "lpn_all_until_jac", lpn_all_until_jac);

PyObject *_lpmn = SpecFun_NewGUFunc(
{static_cast<func_bf_f2f2_t>(::lpmn), static_cast<func_bd_d2d2_t>(::lpmn)}, 2, "_lpmn", lpmn_doc,
"(),()->(mp1,np1),(mp1,np1)"
Expand Down
46 changes: 46 additions & 0 deletions scipy/special/_ufunc_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

import numpy as np

class ufunc_wrapper(object):
def __init__(self, ufunc, diffs = (), until_diffs = ()):
if until_diffs:
until_diffs = tuple(ufunc_wrapper(until_diff) for until_diff in until_diffs)

self.ufunc = ufunc
self.until_diffs = until_diffs
self._as_ufunc_out = lambda out: out

@property
def until_jac(self):
return self.until_diffs[0]

def resolve_out_shapes(self, func):
for k, until_diff in enumerate(self.until_diffs, 1):
until_diff.resolve_out_shapes(lambda *args: (k + 1) * (func(*args),))

self._resolve_out_shapes = func

def as_ufunc_out(self, func):
for k, until_diff in enumerate(self.until_diffs, 1):
until_diff.as_ufunc_out(lambda out: tuple(func(out[i]) for i in range(len(out))))

self._as_ufunc_out = func

def __call__(self, *args):
resolve_out_shapes_args = args[:-self.ufunc.nin]
args = args[-self.ufunc.nin:]

arg_shapes = tuple(np.shape(arg) for arg in args)
out_shapes = self._resolve_out_shapes(*resolve_out_shapes_args, arg_shapes)

arg_dtypes = tuple(arg.dtype if hasattr(arg, 'dtype') else type(arg) for arg in args) + self.ufunc.nout * (None,)
dtypes = self.ufunc.resolve_dtypes(arg_dtypes)
out_dtypes = dtypes[-self.ufunc.nout:]

out = tuple(np.empty(out_shape, dtype = out_dtype) for out_shape, out_dtype in zip(out_shapes, out_dtypes))
self.ufunc(*args, out = self._as_ufunc_out(out))

if (len(out) == 1):
out, = out

return out
1 change: 1 addition & 0 deletions scipy/special/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ python_sources = [
'_test_internal.pyi',
'_testutils.py',
'_ufuncs.pyi',
'_ufunc_wrapper.py',
'add_newdocs.py',
'basic.py',
'cython_special.pxd',
Expand Down
7 changes: 6 additions & 1 deletion scipy/special/special.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@ void lpn(long n, T z, T &p, T &p_jac) {
special::legendre_p_jac(n, z, p, p_jac);
}

template <typename T, typename OutputVec>
void lpn_all(T z, OutputVec p) {
special::legendre_p_all(z, p);
}

template <typename T, typename OutputVec1, typename OutputVec2>
void lpn_all(T z, OutputVec1 p, OutputVec2 p_jac) {
void lpn_all_until_jac(T z, OutputVec1 p, OutputVec2 p_jac) {
special::legendre_p_jac_all(z, p, p_jac);
}

Expand Down
6 changes: 3 additions & 3 deletions scipy/special/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3621,10 +3621,10 @@ def test_lpmn(self):
-0.12500]]),
array([[0.00000,
1.00000,
1.50000]])),4)
1.50000]])), 4)

def test_lpn_all(self):
p, pd = special.lpn_all(2, 0.5)
def test_lpn(self):
p, pd = special.lpn(2, 0.5)
assert_array_almost_equal(p, [1.00000, 0.50000, -0.12500], 4)
assert_array_almost_equal(pd, [0.00000, 1.00000, 1.50000], 4)

Expand Down

0 comments on commit 3bc2cf1

Please sign in to comment.