Array API support for cross_validation and friends #28677

Open
ogrisel opened this issue Mar 21, 2024 · 1 comment
@ogrisel
Member

ogrisel commented Mar 21, 2024

Now that #28407 has been merged, we need to adapt the other cross-validation and model selection tools, starting with cross_validate. It currently fails with:

```python
import array_api_strict as xp

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn import set_config
set_config(array_api_dispatch=True)

X, y = make_classification()
cross_validate(LinearDiscriminantAnalysis(), xp.asarray(X), xp.asarray(y))
```
```
/Users/ogrisel/code/scikit-learn/sklearn/utils/validation.py:109: UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. Such cross-library comparison is not supported by the standard.
  if X.dtype == np.dtype("object") and not allow_nan:
/Users/ogrisel/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_strict/_indexing_functions.py:16: UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. Such cross-library comparison is not supported by the standard.
  if indices.dtype not in _integer_dtypes:
Traceback (most recent call last):
  Cell In[14], line 10
    cross_validate(LinearDiscriminantAnalysis(), xp.asarray(X), xp.asarray(y))
  File ~/code/scikit-learn/sklearn/utils/_param_validation.py:213 in wrapper
    return func(*args, **kwargs)
  File ~/code/scikit-learn/sklearn/model_selection/_validation.py:423 in cross_validate
    results = parallel(
  File ~/code/scikit-learn/sklearn/utils/parallel.py:67 in __call__
    return super().__call__(iterable_with_config)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/joblib/parallel.py:1863 in __call__
    return output if self.return_generator else list(output)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/joblib/parallel.py:1792 in _get_sequential_output
    res = func(*args, **kwargs)
  File ~/code/scikit-learn/sklearn/utils/parallel.py:129 in __call__
    return self.function(*args, **kwargs)
  File ~/code/scikit-learn/sklearn/model_selection/_validation.py:880 in _fit_and_score
    X_train, y_train = _safe_split(estimator, X, y, train)
  File ~/code/scikit-learn/sklearn/utils/metaestimators.py:158 in _safe_split
    X_subset = _safe_indexing(X, indices)
  File ~/code/scikit-learn/sklearn/utils/_indexing.py:264 in _safe_indexing
    return _array_indexing(X, indices, indices_dtype, axis=axis)
  File ~/code/scikit-learn/sklearn/utils/_indexing.py:27 in _array_indexing
    return xp.take(array, key, axis=axis)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_strict/_indexing_functions.py:17 in take
    raise TypeError("Only integer dtypes are allowed in indexing")
TypeError: Only integer dtypes are allowed in indexing
```
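
Judging from the second warning and the final frame, the failure appears to come from passing NumPy integer indices to `xp.take` on an array from another namespace, so the integer-dtype check compares dtypes across libraries. A minimal sketch of one possible direction, assuming the indices can simply be converted into the data's namespace first; `take_rows` is a hypothetical helper for illustration, not an existing scikit-learn function, and the NumPy fallback is only there so the snippet runs without array_api_strict installed:

```python
import numpy as np

try:
    import array_api_strict as xp  # the namespace used in the report above
except ImportError:
    xp = np  # fallback so the sketch still runs; NumPy also provides asarray/take


def take_rows(X, indices, xp):
    """Hypothetical helper: select rows of X using indices from another library."""
    # Convert the (NumPy) indices into the namespace of X so that the
    # integer-dtype check inside xp.take sees a dtype from the same library.
    idx = xp.asarray(indices)
    return xp.take(X, idx, axis=0)


X = xp.asarray(np.arange(10.0).reshape(5, 2))
train = np.array([0, 2, 4])  # CV splitters typically yield NumPy integer indices
X_train = take_rows(X, train, xp)
print(X_train.shape)
```

This ignores device placement and boolean masks, both of which a real fix would likely need to handle.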

Extra notes:

  • Currently, the error messages we get when passing Array API inputs to functions or methods that have not been updated to support the Array API are very low-level and uninformative.
  • the UserWarning raised at sklearn/utils/validation.py:109 is not specific to cross-validation. It is also raised when calling LinearDiscriminantAnalysis().fit(xp.asarray(X), xp.asarray(y)) directly.
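
To illustrate the second note: the warning is triggered by comparing a non-NumPy dtype with `np.dtype("object")`. A minimal sketch of a guard that avoids the cross-library comparison, assuming the object-dtype check only matters for NumPy arrays; `has_object_dtype` is a hypothetical name, not the actual code in sklearn/utils/validation.py:

```python
import numpy as np


def has_object_dtype(X):
    """Hypothetical guard: only compare dtypes when X is a NumPy array."""
    # Short-circuit on the type check so that dtypes from other array
    # libraries are never compared against NumPy dtype objects.
    return isinstance(X, np.ndarray) and X.dtype == np.dtype("object")


print(has_object_dtype(np.array(["a", 1], dtype=object)))
print(has_object_dtype(np.array([1.0, 2.0])))
```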
@lorentzenchr
Member

In order to minimize effort, I'd suggest not adding more informative errors, but instead adding array API support (which should become easier and easier, right?).
