
ENH Array API support for f1_score #27369

Open · wants to merge 20 commits into main
Conversation

OmarManzoor
Contributor

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

  • Adds array API support for f1_score and the functions it relies on.

Any other comments?

CC: @ogrisel @betatim
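For context, the metric this PR generalizes can be computed from per-class confusion-matrix counts. Here is a minimal NumPy sketch (the function name and the zero-division guard are illustrative, not scikit-learn's API); only functions in the array API standard subset are used (`asarray`, `where`, `maximum`), so swapping `np` for another namespace is mechanical:

```python
import numpy as np

def f1_from_counts(tp, fp, fn):
    """F1 = 2*tp / (2*tp + fp + fn), computed per class.

    tp, fp, fn are arrays of true-positive, false-positive and
    false-negative counts; a class with an empty denominator gets
    F1 = 0, matching the zero_division=0 convention.
    """
    tp = np.asarray(tp, dtype=np.float64)
    fp = np.asarray(fp, dtype=np.float64)
    fn = np.asarray(fn, dtype=np.float64)
    denom = 2 * tp + fp + fn
    # np.maximum avoids a divide-by-zero warning; np.where then
    # selects the 0.0 branch for empty classes.
    return np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)
```

For example, `f1_from_counts([5], [2], [3])` evaluates 2*5 / (10 + 2 + 3) = 2/3.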

@github-actions

github-actions bot commented Sep 14, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5cd9a11. Link to the linter CI: here

@OmarManzoor OmarManzoor marked this pull request as ready for review May 20, 2024 10:07
@OmarManzoor
Copy link
Contributor Author

@ogrisel Could you kindly have a look at this PR?

Member

@ogrisel ogrisel left a comment


Overall this looks good to me. I am surprised it works without being very specific about devices and dtypes, but as long as the tests pass (and they do), I am happy.

tp = np.array(tp)
fp = np.array(fp)
fn = np.array(fn)
sample_weight = xp.asarray(sample_weight)
Member


We should probably make sure that it matches the device of the inputs, no? It's curious that the existing tests do not fail with PyTorch on an MPS device (or on CUDA devices).

I am also wondering whether we should convert to a specific dtype. However, looking at the tests, I never see a case where we pass non-integer sample weights. And even for integer weights, it is only done to check an error message, not an actual computation. So I am not sure our sample_weight support is correct, even outside of array API concerns.

I guess this is only indirectly tested by classification metrics that rely on multilabel_confusion_matrix internally. But then the array API compliance tests for F1 score do not fail with floating point weights (I just checked) and I am not sure why.
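One hedged way to address the device/dtype concern above (the helper name is hypothetical and not part of this PR): convert `sample_weight` with an explicit dtype taken from a reference input array and, when the namespace supports it, that array's device:

```python
import numpy as np

def asarray_like(values, reference, xp):
    """Hypothetical sketch: convert `values` to an array in namespace
    `xp` with the dtype of `reference` and, when supported, its device."""
    device = getattr(reference, "device", None)
    if device is not None:
        try:
            # Array API namespaces accept a `device` keyword in asarray.
            return xp.asarray(values, dtype=reference.dtype, device=device)
        except TypeError:
            pass  # this namespace's asarray has no device keyword
    return xp.asarray(values, dtype=reference.dtype)
```

With PyTorch inputs this would place the weights on the same CUDA/MPS device as `y_true`; with older NumPy (no `.device` attribute) it degrades to a dtype-only conversion.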

Member


Here is the output of my cuda run on this PR (updated to check that boolean array indexing also works, but this should be orthogonal):

$ pytest -vlx -k "array_api and f1_score" sklearn/ 

================================================================================================== test session starts ===================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
collected 34881 items / 34863 deselected / 2 skipped / 18 selected                                                                                                                                                       

sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-numpy-None-None] PASSED                                                                      [  5%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-array_api_strict-None-None] PASSED                                                           [ 11%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy-None-None] PASSED                                                                       [ 16%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy.array_api-None-None] PASSED                                                             [ 22%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float64] PASSED                                                                    [ 27%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float32] PASSED                                                                    [ 33%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float64] PASSED                                                                   [ 38%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float32] PASSED                                                                   [ 44%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK...) [ 50%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-numpy-None-None] PASSED                                                                  [ 55%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-array_api_strict-None-None] PASSED                                                       [ 61%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy-None-None] PASSED                                                                   [ 66%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy.array_api-None-None] PASSED                                                         [ 72%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float64] PASSED                                                                [ 77%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float32] PASSED                                                                [ 83%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float64] PASSED                                                               [ 88%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float32] PASSED                                                               [ 94%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALL...) [100%]

============================================================================= 16 passed, 4 skipped, 34863 deselected, 105 warnings in 15.59s =============================================================================

@OmarManzoor
Contributor Author

@ogrisel @betatim Does this look okay now?

Member

@ogrisel ogrisel left a comment


I do not have time to finish the review today but here is some quick feedback:

@ogrisel
Copy link
Member

ogrisel commented Jun 5, 2024

I merged main to be able to launch the new CUDA GPU CI workflow on this PR:

EDIT: tests are green.

Member

@ogrisel ogrisel left a comment


LGTM once the following is addressed:

@ogrisel
Member

ogrisel commented Jun 6, 2024

@betatim this is ready for a second review.

@ogrisel ogrisel added the Waiting for Second Reviewer First reviewer is done, need a second one! label Jun 6, 2024
@ogrisel
Member

ogrisel commented Jun 6, 2024

I launched the CUDA GPU CI at:

EDIT: CUDA tests are green.

Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues: none yet.

2 participants