ENH Array API support for f1_score #27369
base: main
Conversation
@ogrisel Could you kindly have a look at this PR?
Overall this looks good to me. I am surprised it works without being very specific about devices and dtypes, but as long as the tests pass (and they do), I am happy.
sklearn/metrics/_classification.py
Outdated
tp = np.array(tp)
fp = np.array(fp)
fn = np.array(fn)
sample_weight = xp.asarray(sample_weight)
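Aside: the snippet above mixes `np.array` with `xp.asarray`. A namespace-consistent version might look like the sketch below, with hypothetical per-class counts and `xp` bound to NumPy purely for illustration (in scikit-learn, `xp` would be the namespace resolved from the input arrays):

```python
import numpy as np

# Illustration only: bind xp to NumPy. In scikit-learn, xp would be the
# array namespace resolved from the input arrays.
xp = np

# Hypothetical per-class confusion-matrix components and weights.
tp = [3, 1]
fp = [0, 2]
fn = [1, 1]
sample_weight = [0.5, 1.5]

# Convert everything through the same namespace instead of np.array:
tp = xp.asarray(tp)
fp = xp.asarray(fp)
fn = xp.asarray(fn)
sample_weight = xp.asarray(sample_weight)
```

With a non-NumPy namespace (e.g. `array_api_strict` or `torch`), the same four conversions would then produce arrays of that namespace rather than silently falling back to NumPy.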
We should probably make sure that it matches the device of the inputs, no? It's curious that the existing tests do not fail with PyTorch on an MPS device (or CUDA devices).
I am also wondering whether we should convert to a specific dtype. However, looking at the tests, I never see any case where we pass non-integer sample weights. And even for integer weights, it's only done to check an error message, not to check an actual computation. So I am not sure our sample_weight support is correct, even outside of array API concerns.
I guess this is only indirectly tested by the classification metrics that rely on multilabel_confusion_matrix internally. But then the array API compliance tests for F1 score do not fail with floating point weights (I just checked) and I am not sure why.
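To make the device and dtype concern above concrete, here is a minimal sketch of a helper that converts `sample_weight` using the namespace, dtype, and (when available) device of a reference input array. The helper name and fallback logic are hypothetical, not scikit-learn's actual implementation:

```python
import numpy as np

def _weights_like(sample_weight, ref, xp):
    # Match the dtype of the reference array so weighted averages are
    # computed in a consistent precision.
    dtype = ref.dtype
    # Not every namespace/array exposes `device` (NumPy < 2.0 does not),
    # so fall back to a plain conversion when it is absent.
    device = getattr(ref, "device", None)
    if device is not None:
        return xp.asarray(sample_weight, dtype=dtype, device=device)
    return xp.asarray(sample_weight, dtype=dtype)

ref = np.zeros(3, dtype=np.float32)
w = _weights_like([1.0, 2.0, 0.5], ref, np)
```

With a GPU namespace such as `torch`, passing `device=ref.device` is what avoids a cross-device error when the weights are later multiplied against the confusion-matrix components.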
Here is the output of my cuda run on this PR (updated to check that boolean array indexing also works, but this should be orthogonal):
$ pytest -vlx -k "array_api and f1_score" sklearn/
================================================================================================== test session starts ===================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
collected 34881 items / 34863 deselected / 2 skipped / 18 selected
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-numpy-None-None] PASSED [ 5%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-array_api_strict-None-None] PASSED [ 11%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy-None-None] PASSED [ 16%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy.array_api-None-None] PASSED [ 22%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float64] PASSED [ 27%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float32] PASSED [ 33%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float64] PASSED [ 38%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float32] PASSED [ 44%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK...) [ 50%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-numpy-None-None] PASSED [ 55%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-array_api_strict-None-None] PASSED [ 61%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy-None-None] PASSED [ 66%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy.array_api-None-None] PASSED [ 72%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float64] PASSED [ 77%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float32] PASSED [ 83%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float64] PASSED [ 88%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float32] PASSED [ 94%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALL...) [100%]
============================================================================= 16 passed, 4 skipped, 34863 deselected, 105 warnings in 15.59s =============================================================================
I do not have time to finish the review today but here is some quick feedback:
I merged. EDIT: tests are green.
LGTM once the following is addressed:
@betatim this is ready for a second review.
I launched the CUDA GPU CI at: EDIT: CUDA tests are green.
Reference Issues/PRs
Towards #26024
What does this implement/fix? Explain your changes.
Any other comments?
CC: @ogrisel @betatim