/Users/ogrisel/code/scikit-learn/sklearn/utils/validation.py:109: UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. Such cross-library comparison is not supported by the standard.
if X.dtype == np.dtype("object") and not allow_nan:
/Users/ogrisel/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_strict/_indexing_functions.py:16: UserWarning: You are comparing a array_api_strict dtype against a NumPy native dtype object, but you probably don't want to do this. array_api_strict dtype objects compare unequal to their NumPy equivalents. Such cross-library comparison is not supported by the standard.
if indices.dtype not in _integer_dtypes:
Traceback (most recent call last):
Cell In[14], line 10
cross_validate(LinearDiscriminantAnalysis(), xp.asarray(X), xp.asarray(y))
File ~/code/scikit-learn/sklearn/utils/_param_validation.py:213 in wrapper
return func(*args, **kwargs)
File ~/code/scikit-learn/sklearn/model_selection/_validation.py:423 in cross_validate
results = parallel(
File ~/code/scikit-learn/sklearn/utils/parallel.py:67 in __call__
return super().__call__(iterable_with_config)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/joblib/parallel.py:1863 in __call__
return output if self.return_generator else list(output)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/joblib/parallel.py:1792 in _get_sequential_output
res = func(*args, **kwargs)
File ~/code/scikit-learn/sklearn/utils/parallel.py:129 in __call__
return self.function(*args, **kwargs)
File ~/code/scikit-learn/sklearn/model_selection/_validation.py:880 in _fit_and_score
X_train, y_train = _safe_split(estimator, X, y, train)
File ~/code/scikit-learn/sklearn/utils/metaestimators.py:158 in _safe_split
X_subset = _safe_indexing(X, indices)
File ~/code/scikit-learn/sklearn/utils/_indexing.py:264 in _safe_indexing
return _array_indexing(X, indices, indices_dtype, axis=axis)
File ~/code/scikit-learn/sklearn/utils/_indexing.py:27 in _array_indexing
return xp.take(array, key, axis=axis)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_strict/_indexing_functions.py:17 in take
raise TypeError("Only integer dtypes are allowed in indexing")
TypeError: Only integer dtypes are allowed in indexing
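The second warning (from array_api_strict/_indexing_functions.py:16) suggests that the train indices reaching _array_indexing are still NumPy integer arrays, which array_api_strict's take rejects because their dtype is a NumPy dtype rather than one of its own integer dtypes. Under that assumption, one possible direction is to convert the indices into the namespace of the data before calling take, as in the minimal sketch below. The helper name take_rows is hypothetical and not part of scikit-learn; it assumes xp.asarray accepts a NumPy index array, and for device-backed namespaces (e.g. PyTorch on GPU) the target device would probably also have to be passed, which this sketch leaves out.

```python
import numpy as np


def take_rows(xp, array, indices, axis=0):
    """Hypothetical helper: index ``array`` along ``axis`` with ``indices``.

    ``indices`` may be a NumPy array produced by a CV splitter. Converting it
    with ``xp.asarray`` first means ``xp.take`` receives an integer array from
    the same namespace as ``array`` instead of a NumPy array.
    """
    key = xp.asarray(indices)
    return xp.take(array, key, axis=axis)


# Usage with NumPy standing in for an array API namespace.
X = np.arange(10).reshape(5, 2)
train = np.array([0, 2])
print(take_rows(np, X, train))
```

With array_api_strict, the same call would be take_rows(xp, xp.asarray(X), train), where xp is the array_api_strict module.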
Extra notes:
Currently, the error messages we get when passing Array API inputs to functions or methods that have not yet been updated to support the Array API are very low-level and uninformative.
The UserWarning raised at sklearn/utils/validation.py:109 is not specific to cross-validation. It is also raised when calling LinearDiscriminantAnalysis().fit(xp.asarray(X), xp.asarray(y)) directly.
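The warning at sklearn/utils/validation.py:109 comes from comparing X.dtype with np.dtype("object"). One way to avoid that cross-library comparison, sketched here as an assumption rather than as the actual fix, is to run the object-dtype check only for NumPy arrays, since the Array API standard defines no object dtype and arrays from namespaces such as array_api_strict therefore cannot hold Python objects. The helper name has_object_dtype is hypothetical.

```python
import numpy as np


def has_object_dtype(X):
    """Hypothetical helper: True only for NumPy arrays with object dtype.

    Non-NumPy arrays (for instance array_api_strict arrays) fail the
    isinstance check first, so their dtype is never compared against a
    NumPy dtype object and no cross-library comparison warning is emitted.
    """
    return isinstance(X, np.ndarray) and X.dtype == np.dtype("object")


print(has_object_dtype(np.array([1, "a"], dtype=object)))  # True
print(has_object_dtype(np.array([1.0, 2.0])))  # False
```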
To minimize effort, I'd suggest not adding more informative errors, but adding array API support instead (which should become easier and easier, right?).
Now that #28407 was merged, we need to adapt other cross-validation and model selection tools, starting with cross_validate, which currently fails with the warnings and traceback shown above.