Skip to content

Commit

Permalink
ENH Add Array API compatibility to cosine_similarity (#29014)
Browse files Browse the repository at this point in the history
  • Loading branch information
EdAbati committed May 17, 2024
1 parent e825502 commit b461547
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 26 deletions.
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ Metrics
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.pairwise.cosine_similarity`
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`

Expand Down
6 changes: 4 additions & 2 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ See :ref:`array_api` for more details.

- :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
inputs.
:pr:`28106` by :user:`Thomas Li <lithomas1>`
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`.
:pr:`28106` by :user:`Thomas Li <lithomas1>`;
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`.


**Classes:**

Expand Down
11 changes: 10 additions & 1 deletion sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
gen_batches,
gen_even_slices,
)
from ..utils._array_api import (
_find_matching_floating_dtype,
_is_numpy_namespace,
get_namespace,
)
from ..utils._chunking import get_chunk_n_rows
from ..utils._mask import _get_mask
from ..utils._missing import is_scalar_nan
Expand Down Expand Up @@ -154,7 +159,11 @@ def check_pairwise_arrays(
An array equal to Y if Y was not None, guaranteed to be a numpy array.
If Y was None, safe_Y will be a pointer to X.
"""
X, Y, dtype_float = _return_float_dtype(X, Y)
xp, _ = get_namespace(X, Y)
if any([issparse(X), issparse(Y)]) or _is_numpy_namespace(xp):
X, Y, dtype_float = _return_float_dtype(X, Y)
else:
dtype_float = _find_matching_floating_dtype(X, Y, xp=xp)

estimator = "check_pairwise_arrays"
if dtype == "infer_float":
Expand Down
64 changes: 41 additions & 23 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
zero_one_loss,
)
from sklearn.metrics._base import _average_binary_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
from sklearn.utils._array_api import (
Expand Down Expand Up @@ -1743,20 +1744,22 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):


def check_array_api_metric(
metric, array_namespace, device, dtype_name, y_true_np, y_pred_np, sample_weight
metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs
):
xp = _array_api_for_tests(array_namespace, device)

y_true_xp = xp.asarray(y_true_np, device=device)
y_pred_xp = xp.asarray(y_pred_np, device=device)
a_xp = xp.asarray(a_np, device=device)
b_xp = xp.asarray(b_np, device=device)

metric_np = metric(y_true_np, y_pred_np, sample_weight=sample_weight)
metric_np = metric(a_np, b_np, **metric_kwargs)

if sample_weight is not None:
sample_weight = xp.asarray(sample_weight, device=device)
if metric_kwargs.get("sample_weight") is not None:
metric_kwargs["sample_weight"] = xp.asarray(
metric_kwargs["sample_weight"], device=device
)

with config_context(array_api_dispatch=True):
metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
metric_xp = metric(a_xp, b_xp, **metric_kwargs)

assert_allclose(
_convert_to_numpy(xp.asarray(metric_xp), xp),
Expand All @@ -1776,8 +1779,8 @@ def check_array_api_binary_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1788,8 +1791,8 @@ def check_array_api_binary_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1805,8 +1808,8 @@ def check_array_api_multiclass_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1817,8 +1820,8 @@ def check_array_api_multiclass_classification_metric(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1832,8 +1835,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1844,8 +1847,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1861,8 +1864,8 @@ def check_array_api_regression_metric_multioutput(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=None,
)

Expand All @@ -1873,8 +1876,8 @@ def check_array_api_regression_metric_multioutput(
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
a_np=y_true_np,
b_np=y_pred_np,
sample_weight=sample_weight,
)

Expand All @@ -1886,6 +1889,20 @@ def check_array_api_multioutput_regression_metric(
check_array_api_regression_metric(metric, array_namespace, device, dtype_name)


def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name):
    """Check a pairwise metric against its NumPy result under Array API dispatch.

    Two small 2x3 floating-point matrices are built with dtype ``dtype_name``
    and handed to ``check_array_api_metric`` as the ``a_np`` / ``b_np`` inputs.

    Parameters
    ----------
    metric : callable
        Pairwise metric under test, called as ``metric(X, Y, **kwargs)``.
    array_namespace : str
        Forwarded unchanged to ``check_array_api_metric``.
    device : str or None
        Forwarded unchanged to ``check_array_api_metric``.
    dtype_name : str
        dtype used to build the NumPy inputs; also forwarded unchanged.
    """
    X_np = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=dtype_name)
    Y_np = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], dtype=dtype_name)

    # Only request a dense result from metrics whose signature exposes the
    # option; other pairwise metrics are called without extra keyword args.
    metric_kwargs = {}
    if "dense_output" in signature(metric).parameters:
        metric_kwargs["dense_output"] = True

    # The collected keyword arguments must actually be forwarded; previously
    # they were built and then silently dropped.
    check_array_api_metric(
        metric,
        array_namespace,
        device,
        dtype_name,
        a_np=X_np,
        b_np=Y_np,
        **metric_kwargs,
    )


array_api_metric_checkers = {
accuracy_score: [
check_array_api_binary_classification_metric,
Expand All @@ -1900,6 +1917,7 @@ def check_array_api_multioutput_regression_metric(
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
cosine_similarity: [check_array_api_metric_pairwise],
mean_absolute_error: [
check_array_api_regression_metric,
check_array_api_multioutput_regression_metric,
Expand Down

0 comments on commit b461547

Please sign in to comment.