Skip to content

Commit

Permalink
MAINT: Windows distutils cdist/pdist shims (#18169)
Browse files Browse the repository at this point in the history
* MAINT: Windows distutils cdist/pdist shims

* the `out` and weight parameter input validation for `cdist`/`pdist` fails
in a very narrow set of scenarios on Windows--only
if you use the old `distutils`-based build system
(on Azure) and only in the context of the broader
testsuite (not if you select the test individually
with `pytest ... -k "..."` it seems)

* so, on Windows only, add temporary Python-based
guards for the `out` / `w` parameters, to compensate for the
mysterious failure to `raise` within `pybind11` machinery;
the comments should suffice to remove this later--I'm sure
we'll grep for `distutils` when we do the purge from that
build system

* deal with linter complaints about old-school
type hints

---------

Co-authored-by: peterbell10 <peterbell10@live.co.uk>
  • Loading branch information
tylerjereddy and peterbell10 committed Mar 23, 2023
1 parent 2397d2f commit a72a392
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
27 changes: 24 additions & 3 deletions scipy/spatial/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@
]


import os
import warnings
import numpy as np
import dataclasses

from typing import List, Optional, Set, Callable
from typing import Optional, Callable

from functools import partial
from scipy._lib._util import _asarray_validated
Expand All @@ -122,6 +123,24 @@
from . import _distance_pybind


def _extra_windows_error_checks(x, out, required_shape, **kwargs):
# TODO: remove this function when distutils
# build system is removed because pybind11 error
# handling should suffice per gh-18108
if os.name == "nt" and out is not None:
if out.shape != required_shape:
raise ValueError("Output array has incorrect shape.")
if not out.flags["C_CONTIGUOUS"]:
raise ValueError("Output array must be C-contiguous.")
if not np.can_cast(x.dtype, out.dtype):
raise ValueError("Wrong out dtype.")
if os.name == "nt" and "w" in kwargs:
w = kwargs["w"]
if w is not None:
if (w < 0).sum() > 0:
raise ValueError("Input weights should be all non-negative")


def _copy_array_if_base_present(a):
"""Copy the array if its base points to a parent array."""
if a.base is not None:
Expand Down Expand Up @@ -1682,7 +1701,7 @@ class MetricInfo:
# Name of python distance function
canonical_name: str
# All aliases, including canonical_name
aka: Set[str]
aka: set[str]
# unvectorized distance function
dist_func: Callable
# Optimized cdist function
Expand All @@ -1695,7 +1714,7 @@ class MetricInfo:
# list of supported types:
# X (pdist) and XA (cdist) are used to choose the type. if there is no
# match the first type is used. Default double
types: List[str] = dataclasses.field(default_factory=lambda: ['double'])
types: list[str] = dataclasses.field(default_factory=lambda: ['double'])
# true if out array must be C-contiguous
requires_contiguous_out: bool = True

Expand Down Expand Up @@ -2190,6 +2209,7 @@ def pdist(X, metric='euclidean', *, out=None, **kwargs):

if metric_info is not None:
pdist_fn = metric_info.pdist_func
_extra_windows_error_checks(X, out, (m * (m - 1) / 2,), **kwargs)
return pdist_fn(X, out=out, **kwargs)
elif mstr.startswith("test_"):
metric_info = _TEST_METRICS.get(mstr, None)
Expand Down Expand Up @@ -2975,6 +2995,7 @@ def cdist(XA, XB, metric='euclidean', *, out=None, **kwargs):
metric_info = _METRIC_ALIAS.get(mstr, None)
if metric_info is not None:
cdist_fn = metric_info.cdist_func
_extra_windows_error_checks(XA, out, (mA, mB), **kwargs)
return cdist_fn(XA, XB, out=out, **kwargs)
elif mstr.startswith("test_"):
metric_info = _TEST_METRICS.get(mstr, None)
Expand Down
3 changes: 0 additions & 3 deletions scipy/spatial/tests/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ def test_cdist_dtype_equivalence(self):
y2 = cdist(new_type(X1), new_type(X2), metric=metric)
assert_allclose(y1, y2, rtol=eps, verbose=verbose > 2)

@pytest.mark.skip("Failing on Windows Azure jobs; see gh-18108.")
def test_cdist_out(self):
# Test that out parameter works properly
eps = 1e-15
Expand Down Expand Up @@ -1469,7 +1468,6 @@ def test_pdist_dtype_equivalence(self):
y2 = pdist(new_type(X1), metric=metric)
assert_allclose(y1, y2, rtol=eps, verbose=verbose > 2)

@pytest.mark.skip("Failing on Windows Azure jobs; see gh-18108.")
def test_pdist_out(self):
# Test that out parameter works properly
eps = 1e-15
Expand Down Expand Up @@ -2073,7 +2071,6 @@ def test_Xdist_deprecated_args():
pdist(X1, metric, **kwargs)


@pytest.mark.skip("Failing on Windows Azure jobs; see gh-18108.")
def test_Xdist_non_negative_weights():
X = eo['random-float32-data'][::5, ::2]
w = np.ones(X.shape[1])
Expand Down

0 comments on commit a72a392

Please sign in to comment.