Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats.monte_carlo_test: add array API support #20604

Merged
merged 10 commits into from
May 5, 2024
1 change: 1 addition & 0 deletions .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,4 @@ jobs:
python dev.py --no-build test -b all -t scipy._lib.tests.test__util -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_stats -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_morestats -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_resampling -- --durations 3 --timeout=60
22 changes: 21 additions & 1 deletion scipy/_lib/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,4 +398,24 @@ def xp_minimum(x1, x2):
res = xp.asarray(x1, copy=True, dtype=dtype)
i = (x2 < x1) | xp.isnan(x2)
res[i] = x2[i]
return res
return res[()] if res.ndim == 0 else res


# temporary substitute for xp.clip, which is not yet in all backends
# or covered by array_api_compat.
def xp_clip(x, a, b, xp=None):
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
xp = array_namespace(xp) if xp is None else xp
y = xp.asarray(x, copy=True)
y[y < a] = a
y[y > b] = b
return y[()] if y.ndim == 0 else y


# temporary substitute for xp.moveaxis, which is not yet in all backends
# or covered by array_api_compat.
def _move_axis_to_end(x, source, xp=None):
rgommers marked this conversation as resolved.
Show resolved Hide resolved
xp = array_namespace(xp) if xp is None else xp
axes = list(range(x.ndim))
temp = axes.pop(source)
axes = axes + [temp]
return xp.permute_dims(x, axes)
23 changes: 9 additions & 14 deletions scipy/stats/_axis_nan_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,20 @@
import inspect


def _broadcast_arrays(arrays, axis=None):
def _broadcast_arrays(arrays, axis=None, xp=None):
"""
Broadcast shapes of arrays, ignoring incompatibility of specified axes
"""
new_shapes = _broadcast_array_shapes(arrays, axis=axis)
xp = array_namespace(*arrays) if xp is None else xp
arrays = [xp.asarray(arr) for arr in arrays]
shapes = [arr.shape for arr in arrays]
new_shapes = _broadcast_shapes(shapes, axis)
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
if axis is None:
new_shapes = [new_shapes]*len(arrays)
return [np.broadcast_to(array, new_shape)
return [xp.broadcast_to(array, new_shape)
for array, new_shape in zip(arrays, new_shapes)]


def _broadcast_array_shapes(arrays, axis=None):
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
"""
Broadcast shapes of arrays, ignoring incompatibility of specified axes
"""
shapes = [np.asarray(arr).shape for arr in arrays]
return _broadcast_shapes(shapes, axis)


def _broadcast_shapes(shapes, axis=None):
"""
Broadcast shapes, ignoring incompatibility of specified axes
Expand Down Expand Up @@ -103,10 +98,10 @@ def _broadcast_array_shapes_remove_axis(arrays, axis=None):
Examples
--------
>>> import numpy as np
>>> from scipy.stats._axis_nan_policy import _broadcast_array_shapes
>>> from scipy.stats._axis_nan_policy import _broadcast_array_shapes_remove_axis
>>> a = np.zeros((5, 2, 1))
>>> b = np.zeros((9, 3))
>>> _broadcast_array_shapes((a, b), 1)
>>> _broadcast_array_shapes_remove_axis((a, b), 1)
(5, 3)
"""
# Note that here, `axis=None` means do not consume/drop any axes - _not_
Expand All @@ -119,7 +114,7 @@ def _broadcast_shapes_remove_axis(shapes, axis=None):
"""
Broadcast shapes, dropping specified axes

Same as _broadcast_array_shapes, but given a sequence
Same as _broadcast_array_shapes_remove_axis, but given a sequence
of array shapes `shapes` instead of the arrays themselves.
"""
shapes = _broadcast_shapes(shapes, axis)
Expand Down
73 changes: 50 additions & 23 deletions scipy/stats/_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import numpy as np
from itertools import combinations, permutations, product
from collections.abc import Sequence
from dataclasses import dataclass
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
import inspect

from scipy._lib._util import check_random_state, _rename_parameter
from scipy._lib._util import check_random_state, _rename_parameter, rng_integers
from scipy._lib._array_api import (array_namespace, is_numpy, xp_minimum,
xp_clip, _move_axis_to_end)
from scipy.special import ndtr, ndtri, comb, factorial
from scipy._lib._util import rng_integers
from dataclasses import dataclass

from ._common import ConfidenceInterval
from ._axis_nan_policy import _broadcast_concatenate, _broadcast_arrays
from ._warnings_errors import DegenerateDataWarning
Expand Down Expand Up @@ -662,7 +664,6 @@ def percentile_fun(a, q):
def _monte_carlo_test_iv(data, rvs, statistic, vectorized, n_resamples,
batch, alternative, axis):
"""Input validation for `monte_carlo_test`."""

axis_int = int(axis)
if axis != axis_int:
raise ValueError("`axis` must be an integer.")
Expand All @@ -677,26 +678,45 @@ def _monte_carlo_test_iv(data, rvs, statistic, vectorized, n_resamples,
if not callable(rvs_i):
raise TypeError("`rvs` must be callable or sequence of callables.")

# At this point, `data` should be a sequence
# If it isn't, the user passed a sequence for `rvs` but not `data`
message = "If `rvs` is a sequence, `len(rvs)` must equal `len(data)`."
try:
len(data)
except TypeError as e:
raise ValueError(message) from e
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
if not len(rvs) == len(data):
message = "If `rvs` is a sequence, `len(rvs)` must equal `len(data)`."
raise ValueError(message)

if not callable(statistic):
raise TypeError("`statistic` must be callable.")

if vectorized is None:
vectorized = 'axis' in inspect.signature(statistic).parameters
try:
signature = inspect.signature(statistic).parameters
except ValueError as e:
message = (f"Signature inspection of {statistic=} failed; "
"pass `vectorized` explicitly.")
raise ValueError(message) from e
vectorized = 'axis' in signature
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved

xp = array_namespace(*data)

if not vectorized:
statistic_vectorized = _vectorize_statistic(statistic)
if is_numpy(xp):
statistic_vectorized = _vectorize_statistic(statistic)
else:
message = ("`statistic` must be vectorized (i.e. support an `axis` "
f"argument) when `data` contains {xp.__name__} arrays.")
raise ValueError(message)
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
else:
statistic_vectorized = statistic

data = _broadcast_arrays(data, axis)
data = _broadcast_arrays(data, axis, xp=xp)
data_iv = []
for sample in data:
sample = np.atleast_1d(sample)
sample = np.moveaxis(sample, axis_int, -1)
sample = xp.broadcast_to(sample, (1,)) if sample.ndim == 0 else sample
sample = _move_axis_to_end(sample, axis_int, xp=xp)
data_iv.append(sample)

n_resamples_int = int(n_resamples)
Expand All @@ -715,8 +735,12 @@ def _monte_carlo_test_iv(data, rvs, statistic, vectorized, n_resamples,
if alternative not in alternatives:
raise ValueError(f"`alternative` must be in {alternatives}")

# Infer the desired p-value dtype based on the input types
min_float = getattr(xp, 'float16', xp.float32)
dtype = xp.result_type(*data_iv, min_float)

return (data_iv, rvs, statistic_vectorized, vectorized, n_resamples_int,
batch_iv, alternative, axis_int)
batch_iv, alternative, axis_int, dtype, xp)


@dataclass
Expand Down Expand Up @@ -908,11 +932,12 @@ def monte_carlo_test(data, rvs, statistic, *, vectorized=None,
"""
args = _monte_carlo_test_iv(data, rvs, statistic, vectorized,
n_resamples, batch, alternative, axis)
(data, rvs, statistic, vectorized,
n_resamples, batch, alternative, axis) = args
(data, rvs, statistic, vectorized, n_resamples,
batch, alternative, axis, dtype, xp) = args

# Some statistics return plain floats; ensure the result is at least a scalar/0-d array
observed = np.asarray(statistic(*data, axis=-1))[()]
observed = xp.asarray(statistic(*data, axis=-1))
observed = observed[()] if observed.ndim == 0 else observed

n_observations = [sample.shape[-1] for sample in data]
batch_nominal = batch or n_resamples
Expand All @@ -922,37 +947,39 @@ def monte_carlo_test(data, rvs, statistic, *, vectorized=None,
resamples = [rvs_i(size=(batch_actual, n_observations_i))
for rvs_i, n_observations_i in zip(rvs, n_observations)]
null_distribution.append(statistic(*resamples, axis=-1))
null_distribution = np.concatenate(null_distribution)
null_distribution = null_distribution.reshape([-1] + [1]*observed.ndim)
null_distribution = xp.concat(null_distribution)
null_distribution = xp.reshape(null_distribution, [-1] + [1]*observed.ndim)

# relative tolerance for detecting numerically distinct but
# theoretically equal values in the null distribution
eps = (0 if not np.issubdtype(observed.dtype, np.inexact)
else np.finfo(observed.dtype).eps*100)
gamma = np.abs(eps * observed)
eps = (0 if not xp.isdtype(observed.dtype, ('real floating'))
else xp.finfo(observed.dtype).eps*100)
gamma = xp.abs(eps * observed)

def less(null_distribution, observed):
cmps = null_distribution <= observed + gamma
pvalues = (cmps.sum(axis=0) + 1) / (n_resamples + 1) # see [1]
cmps = xp.asarray(cmps, dtype=dtype)
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
pvalues = (xp.sum(cmps, axis=0, dtype=dtype) + 1.) / (n_resamples + 1.)
return pvalues

def greater(null_distribution, observed):
cmps = null_distribution >= observed - gamma
pvalues = (cmps.sum(axis=0) + 1) / (n_resamples + 1) # see [1]
cmps = xp.asarray(cmps, dtype=dtype)
pvalues = (xp.sum(cmps, axis=0, dtype=dtype) + 1.) / (n_resamples + 1.)
return pvalues

def two_sided(null_distribution, observed):
pvalues_less = less(null_distribution, observed)
pvalues_greater = greater(null_distribution, observed)
pvalues = np.minimum(pvalues_less, pvalues_greater) * 2
pvalues = xp_minimum(pvalues_less, pvalues_greater) * 2
return pvalues

compare = {"less": less,
"greater": greater,
"two-sided": two_sided}

pvalues = compare[alternative](null_distribution, observed)
pvalues = np.clip(pvalues, 0, 1)
pvalues = xp_clip(pvalues, 0., 1., xp=xp)

return MonteCarloTestResult(observed, pvalues, null_distribution)

Expand Down
19 changes: 3 additions & 16 deletions scipy/stats/_stats_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@
from scipy import stats
from scipy.optimize import root_scalar
from scipy._lib._util import normalize_axis_index
from scipy._lib._array_api import array_namespace, is_numpy, atleast_nd
from scipy._lib._array_api import (array_namespace, is_numpy, atleast_nd,
xp_clip, _move_axis_to_end)
from scipy._lib.array_api_compat import size as xp_size

# In __all__ but deprecated for removal in SciPy 1.13.0
Expand Down Expand Up @@ -4535,20 +4536,6 @@ def confidence_interval(self, confidence_level=0.95, method=None):
return ci


def _move_axis_to_end(x, source, xp):
axes = list(range(x.ndim))
temp = axes.pop(source)
axes = axes + [temp]
return xp.permute_dims(x, axes)


def _clip(x, a, b, xp):
y = xp.asarray(x, copy=True)
y[y < a] = a
y[y > b] = b
return y


def pearsonr(x, y, *, alternative='two-sided', method=None, axis=0):
r"""
Pearson correlation coefficient and p-value for testing non-correlation.
Expand Down Expand Up @@ -4934,7 +4921,7 @@ def statistic(x, y, axis):
one = xp.asarray(1, dtype=dtype)
# `clip` only recently added to array API, so it's not yet available in
# array_api_strict. Replace with e.g. `xp.clip(r, -one, one)` when available.
r = xp.asarray(_clip(r, -one, one, xp))
r = xp.asarray(xp_clip(r, -one, one, xp))
r[const_xy] = xp.nan

# As explained in the docstring, the distribution of `r` under the null
Expand Down