Skip to content

Commit

Permalink
ENH Array API support for LabelEncoder (#27381)
Browse files Browse the repository at this point in the history
Co-authored-by: Omar Salman <omar.salman@arbisoft>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
  • Loading branch information
3 people committed May 16, 2024
1 parent 945273d commit acd2d90
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 38 deletions.
3 changes: 2 additions & 1 deletion doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ See :ref:`array_api` for more details.

**Classes:**

-
- :class:`preprocessing.LabelEncoder` now supports Array API compatible inputs.
:pr:`27381` by :user:`Omar Salman <OmarManzoor>`.

Metadata Routing
----------------
Expand Down
23 changes: 15 additions & 8 deletions sklearn/preprocessing/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..base import BaseEstimator, TransformerMixin, _fit_context
from ..utils import column_or_1d
from ..utils._array_api import _setdiff1d, device, get_namespace
from ..utils._encode import _encode, _unique
from ..utils._param_validation import Interval, validate_params
from ..utils.multiclass import type_of_target, unique_labels
Expand Down Expand Up @@ -129,10 +130,11 @@ def transform(self, y):
Labels as normalized encodings.
"""
check_is_fitted(self)
xp, _ = get_namespace(y)
y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)
# transform of empty array is empty array
if _num_samples(y) == 0:
return np.array([])
return xp.asarray([])

return _encode(y, uniques=self.classes_)

Expand All @@ -141,7 +143,7 @@ def inverse_transform(self, y):
Parameters
----------
y : ndarray of shape (n_samples,)
y : array-like of shape (n_samples,)
Target values.
Returns
Expand All @@ -150,19 +152,24 @@ def inverse_transform(self, y):
Original encoding.
"""
check_is_fitted(self)
xp, _ = get_namespace(y)
y = column_or_1d(y, warn=True)
# inverse transform of empty array is empty array
if _num_samples(y) == 0:
return np.array([])
return xp.asarray([])

diff = np.setdiff1d(y, np.arange(len(self.classes_)))
if len(diff):
diff = _setdiff1d(
ar1=y,
ar2=xp.arange(self.classes_.shape[0], device=device(y)),
xp=xp,
)
if diff.shape[0]:
raise ValueError("y contains previously unseen labels: %s" % str(diff))
y = np.asarray(y)
return self.classes_[y]
y = xp.asarray(y)
return xp.take(self.classes_, y, axis=0)

def _more_tags(self):
return {"X_types": ["1dlabels"]}
return {"X_types": ["1dlabels"], "array_api_support": True}


class LabelBinarizer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
Expand Down
52 changes: 50 additions & 2 deletions sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from scipy.sparse import issparse

from sklearn import datasets
from sklearn import config_context, datasets
from sklearn.preprocessing._label import (
LabelBinarizer,
LabelEncoder,
Expand All @@ -11,7 +11,16 @@
_inverse_binarize_thresholding,
label_binarize,
)
from sklearn.utils._testing import assert_array_equal, ignore_warnings
from sklearn.utils._array_api import (
_convert_to_numpy,
get_namespace,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
_array_api_for_tests,
assert_array_equal,
ignore_warnings,
)
from sklearn.utils.fixes import (
COO_CONTAINERS,
CSC_CONTAINERS,
Expand Down Expand Up @@ -697,3 +706,42 @@ def test_label_encoders_do_not_have_set_output(encoder):
y_encoded_with_kwarg = encoder.fit_transform(y=["a", "b", "c"])
y_encoded_positional = encoder.fit_transform(["a", "b", "c"])
assert_array_equal(y_encoded_with_kwarg, y_encoded_positional)


@pytest.mark.parametrize(
    "array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize(
    "y",
    [
        np.array([2, 1, 3, 1, 3]),
        np.array([1, 1, 4, 5, -1, 0]),
        np.array([3, 5, 9, 5, 9, 3]),
    ],
)
def test_label_encoder_array_api_compliance(y, array_namespace, device, dtype):
    """Compare `LabelEncoder` on array API inputs against plain NumPy inputs.

    For every namespace/device combination, the encoder is run on `y`
    converted to that namespace and on the NumPy `y` itself. The outputs
    (and the fitted `classes_`) must stay in the input namespace and must be
    equal to the NumPy results once converted back.

    Note: `dtype` is supplied by the parametrization but is not used by the
    test body.
    """
    xp = _array_api_for_tests(array_namespace, device)
    xp_y = xp.asarray(y, device=device)
    expected_namespace = xp.__name__

    with config_context(array_api_dispatch=True):
        # Scenario 1: fit, then transform / inverse_transform separately.
        xp_label = LabelEncoder().fit(xp_y)
        xp_transformed = xp_label.transform(xp_y)
        xp_inv_transformed = xp_label.inverse_transform(xp_transformed)

        np_label = LabelEncoder().fit(y)
        np_transformed = np_label.transform(y)

        # Every output must live in the namespace of the input.
        assert get_namespace(xp_transformed)[0].__name__ == expected_namespace
        assert get_namespace(xp_inv_transformed)[0].__name__ == expected_namespace
        assert get_namespace(xp_label.classes_)[0].__name__ == expected_namespace

        # Values must agree with the NumPy reference; the inverse transform
        # must round-trip back to the original labels.
        assert_array_equal(_convert_to_numpy(xp_transformed, xp), np_transformed)
        assert_array_equal(_convert_to_numpy(xp_inv_transformed, xp), y)
        assert_array_equal(
            _convert_to_numpy(xp_label.classes_, xp), np_label.classes_
        )

        # Scenario 2: fit_transform in a single call on fresh estimators.
        xp_label = LabelEncoder()
        np_label = LabelEncoder()
        xp_transformed = xp_label.fit_transform(xp_y)
        np_transformed = np_label.fit_transform(y)

        assert get_namespace(xp_transformed)[0].__name__ == expected_namespace
        assert get_namespace(xp_label.classes_)[0].__name__ == expected_namespace
        assert_array_equal(_convert_to_numpy(xp_transformed, xp), np_transformed)
        assert_array_equal(
            _convert_to_numpy(xp_label.classes_, xp), np_label.classes_
        )
123 changes: 123 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ def unique_counts(self, x):
def unique_values(self, x):
    """Return the unique elements of `x`, as computed by `numpy.unique`."""
    values = numpy.unique(x)
    return values

def unique_all(self, x):
    """Return unique values of `x` with first indices, inverse and counts.

    The result is whatever `numpy.unique` returns when `return_index`,
    `return_inverse` and `return_counts` are all enabled.
    """
    options = {
        "return_index": True,
        "return_inverse": True,
        "return_counts": True,
    }
    return numpy.unique(x, **options)

def concat(self, arrays, *, axis=None):
    """Join `arrays` along `axis` by delegating to `numpy.concatenate`."""
    joined = numpy.concatenate(arrays, axis=axis)
    return joined

Expand Down Expand Up @@ -839,3 +844,121 @@ def indexing_dtype(xp):
# TODO: once sufficiently adopted, we might want to instead rely on the
# newer inspection API: https://github.com/data-apis/array-api/issues/640
return xp.asarray(0).dtype


def _searchsorted(xp, a, v, *, side="left", sorter=None):
# Temporary workaround needed as long as searchsorted is not widely
# adopted by implementers of the Array API spec. This is a quite
# recent addition to the spec:
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html # noqa
if hasattr(xp, "searchsorted"):
return xp.searchsorted(a, v, side=side, sorter=sorter)

a_np = _convert_to_numpy(a, xp=xp)
v_np = _convert_to_numpy(v, xp=xp)
indices = numpy.searchsorted(a_np, v_np, side=side, sorter=sorter)
return xp.asarray(indices, device=device(a))


def _setdiff1d(ar1, ar2, xp, assume_unique=False):
    """Compute the set difference of two arrays.

    Returns the unique values of `ar1` that do not occur in `ar2`.

    For a NumPy namespace the work is delegated to `numpy.setdiff1d`. For any
    other namespace, both inputs are de-duplicated (unless `assume_unique` is
    set, in which case `ar1` is only flattened) and the entries of `ar1` that
    are absent from `ar2` are selected with a boolean mask.
    """
    if not _is_numpy_namespace(xp):
        if assume_unique:
            ar1 = xp.reshape(ar1, (-1,))
        else:
            ar1 = xp.unique_values(ar1)
            ar2 = xp.unique_values(ar2)
        # Mask is True where an entry of `ar1` is NOT found in `ar2`.
        not_in_ar2 = _in1d(ar1=ar1, ar2=ar2, xp=xp, assume_unique=True, invert=True)
        return ar1[not_in_ar2]

    difference = numpy.setdiff1d(ar1=ar1, ar2=ar2, assume_unique=assume_unique)
    return xp.asarray(difference)


def _isin(element, test_elements, xp, assume_unique=False, invert=False):
    """Test membership of each entry of `element` in `test_elements`.

    Broadcasts over `element` only: the result is a boolean array with the
    shape of `element`, True where the corresponding entry is found in
    `test_elements` and False otherwise (the other way round when `invert`
    is set).

    For a NumPy namespace the work is delegated to `numpy.isin`. Otherwise
    both inputs are flattened, compared with `_in1d`, and the resulting mask
    is reshaped back to the shape of `element`.
    """
    if not _is_numpy_namespace(xp):
        input_shape = element.shape
        flat_element = xp.reshape(element, (-1,))
        flat_test_elements = xp.reshape(test_elements, (-1,))
        flat_mask = _in1d(
            ar1=flat_element,
            ar2=flat_test_elements,
            xp=xp,
            assume_unique=assume_unique,
            invert=invert,
        )
        return xp.reshape(flat_mask, input_shape)

    membership = numpy.isin(
        element=element,
        test_elements=test_elements,
        assume_unique=assume_unique,
        invert=invert,
    )
    return xp.asarray(membership)


# Note: This is a helper for the functions `_isin` and
# `_setdiff1d`. It is not meant to be called directly.
def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
    """Test whether each element of a 1-D array is present in a second array.

    Returns a boolean array of the same length as `ar1`, True where the
    element of `ar1` occurs in `ar2` and False otherwise. With `invert=True`
    the mask is negated.

    Adapted from the original NumPy implementation:
    https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
    """
    xp, _ = get_namespace(ar1, ar2, xp=xp)

    n_ar1 = ar1.shape[0]

    # Fast path: when `ar2` is small relative to `ar1`, comparing `ar1`
    # against each value of `ar2` in turn is significantly faster than the
    # sort-based algorithm below.
    if ar2.shape[0] < 10 * n_ar1**0.145:
        ar1_device = device(ar1)
        if invert:
            mask = xp.ones(n_ar1, dtype=xp.bool, device=ar1_device)
            for candidate in ar2:
                mask &= ar1 != candidate
            return mask

        mask = xp.zeros(n_ar1, dtype=xp.bool, device=ar1_device)
        for candidate in ar2:
            mask |= ar1 == candidate
        return mask

    if not assume_unique:
        # Work on unique values only; `rev_idx` maps the result back onto the
        # original (possibly repeated) entries of `ar1` at the end.
        ar1, rev_idx = xp.unique_inverse(ar1)
        ar2 = xp.unique_values(ar2)

    merged = xp.concat((ar1, ar2))
    merged_device = device(merged)
    # The sort has to be stable: for a value present in both inputs, the copy
    # coming from `ar1` must end up right before the copy coming from `ar2`.
    order = xp.argsort(merged, stable=True)
    undo_order = xp.argsort(order, stable=True)
    sorted_merged = xp.take(merged, order, axis=0)

    # Compare each sorted element with its successor.
    following, current = sorted_merged[1:], sorted_merged[:-1]
    neighbour_flags = (following != current) if invert else (following == current)

    # The last sorted element has no successor; pad with `invert` so it is
    # reported as "not found" (or "found" once inverted).
    padding = xp.asarray([invert], device=merged_device)
    flags = xp.concat((neighbour_flags, padding))
    result = xp.take(flags, undo_order, axis=0)

    if assume_unique:
        return result[: ar1.shape[0]]
    return xp.take(result, rev_idx, axis=0)

0 comments on commit acd2d90

Please sign in to comment.