
Array API support with lazy evaluation #26724

Open
ogrisel opened this issue Jun 28, 2023 · 12 comments

@ogrisel
Member

ogrisel commented Jun 28, 2023

At the moment, our Array API integration implicitly assumes eager evaluation semantics.

Furthermore, we have not tested our code to see how it would behave with candidate Array API implementations that have lazy evaluation semantics (e.g. dask, jax, ...).

The purpose of this issue is to track what would be needed to make our Array API-enabled estimators work with lazy evaluation.

I think we should first investigate where this breaks, then decide whether or not we would like to support lazy evaluation semantics in scikit-learn via the Array API support. If so, we can open a meta-issue to add common tests with an estimator tag and progressively fix the Array API estimators to deal with lazy evaluation.

Note that this particular point about lazy vs eager evaluation for Array API consuming libraries is being discussed here:

@ogrisel
Member Author

ogrisel commented Jun 30, 2023

I ran a quick experiment with dask against the PCA branch:

As you can see in the last commit (a8ab0d3), I did some hackish monkey-patches to make it work.

In the end I could successfully call scikit-learn's PCA(n_components=2, power_iteration_normalizer="QR").fit_transform on a dask array and get a dask array as a result, and inspect the components_ and explained_variance_ratio_ attributes to check that they are also dask arrays (as expected). However, calling the .score method causes a freeze (I did not investigate).

Note that there is a meta-issue in dask to track Array API compliance:

I reported some of the problems I saw during my experiment and linked them to that meta-issue.

@asmeurer

I don't know what the status of dask/dask#8750 is, but we can look at adding dask to the compat library if that helps things move faster. On the other hand, it may be unnecessary to do that if they are going to add array_api as a separate namespace because there won't be any backwards compatibility concerns.

@ogrisel
Member Author

ogrisel commented Jul 3, 2023

Any opinion on the above @tomwhite / @jrbourbeau? Having temporary support for dask in array-api-compat while the official dask.array.array_api namespace gets finalized would probably help move things forward.

But exposing two competing ways to consume dask arrays from Array API-agnostic code might also introduce unnecessary complexity in the long run.

@tomwhite

tomwhite commented Jul 3, 2023

I would favour the array-api-compat route, given that NumPy is going to incorporate the Array API standard into its standard namespace (see dask/dask#8750 (comment), and the NumPy 2 work). This would mean that dask/dask#8750 is not needed at all, avoiding another namespace to support.

The issue for adding Dask support is data-apis/array-api-compat#17, and it should be fairly straightforward to use some of the work in dask/dask#8750 as a guide at least.

@betatim
Member

betatim commented Jul 3, 2023

I think explaining to users that a dask.array array won't work and that they have to rewrite their code to use a dask.array.array_api array is difficult. I think this means I am also in support of the array_api_compat route.

@lithomas1
Contributor

lithomas1 commented Apr 5, 2024

This was discussed a little in #28588. Here's the writeup detailing our conversation there.

Summary

The major issue with lazy Array API implementations in scikit-learn is that after certain operations (e.g. boolean masking, unique()), the lazy array will end up with a NaN shape, because the shape of an array resulting from a data-dependent operation is unknown until computation is performed.

This is allowed by the Array API spec, which permits dimensions to be None in the shape.
(xref https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.shape.html#shape)

However, because scikit-learn depends on the shape being a valid numerical value, this can lead to silent correctness problems or outright errors, depending on what the shape is used for.
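To make this concrete, here is a minimal sketch with eager NumPy of a data-dependent operation. A lazy implementation cannot know the resulting shape until the mask is actually computed, so it has to report the dimension as unknown (nan in dask, None per the spec):

```python
import numpy as np

x = np.array([1, 5, 2, 8, 3])
mask = x > 2     # boolean mask whose content depends on the data
y = x[mask]      # data-dependent output shape

# NumPy is eager, so the shape is immediately known:
print(y.shape)   # (3,)
# A lazy implementation would instead report an unknown dimension,
# e.g. (nan,) in dask, until computation is triggered.
```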

Scikit-learn currently uses .shape to:

  1. Allocate a result array whose shape matches the shape of the inputs in at least one dimension
  2. Use the length of the array for e.g. normalization
  3. Check shapes, raising a ValueError when the shape does not match expectations

(relevant Github code search: https://github.com/search?q=repo%3Ascikit-learn%2Fscikit-learn+.shape+path%3A%2F%5Esklearn%5C%2F%2F&type=code)

Normally, when writing code for e.g. dask or any other lazy implementation, users can call compute (or the equivalent function in their library) to perform computation so that the shape is known before it is needed.

This is not possible with the Array API, though, as there is no method/function to materialize lazy arrays in the Array API.

Furthermore, figuring out when compute needs to be called before/after an operation is non-trivial.

Current failures

  • PCA score method
  • LinearDiscriminantAnalysis
    • Reason: unique is called in this estimator resulting in data-dependent output shapes (represented by nan in dask).
    • Potential solution 1/2 would fix this estimator
  • r2_score
    • Reason: An array with a data-dependent output shape (from masked __getitem__) is passed as an operand to dask's __setitem__.
    • Potential solution 2 would probably fix this issue
      • (if we wrap accesses to the shape attribute to materialize, dask would not complain about the nan shapes)
  • is_multilabel
    • Reason: Call to unique made resulting in data-dependent output shapes (represented by nan in dask)
    • Potential solution 1/2 would fix this method

Motivation

Lazy Array API implementations can provide significant speedups by optimizing computation.
(e.g. computing only the parts of the input array that are needed)

Note: dask implements a limited subset of array optimizations.

Potential Solutions

  1. shape() helper

To be upstreamed into array-api-compat

Calls compute if necessary on the array before returning a non-null shape.

(Although compute is not part of the array API, array-api-compat can call the library specific version of compute for the lazy array implementations that it knows)

Note: This doesn't work in all cases.
Sometimes there are methods whose inputs need to be materialized beforehand.

  2. Wrap array classes in array-api-compat and override the shape property and other methods to implicitly materialize their inputs if necessary
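A rough, pure-Python sketch of what solution 1 might look like. The helper name and the materialization hook are illustrative assumptions, not the actual array-api-compat API; dask does expose compute_chunk_sizes() to resolve unknown chunk sizes, which is the kind of library-specific call array-api-compat could dispatch to:

```python
import math

def shape(x):
    """Hypothetical helper: return a fully resolved shape tuple,
    triggering computation for lazy arrays with unknown dimensions."""
    s = x.shape
    if any(d is None or (isinstance(d, float) and math.isnan(d)) for d in s):
        # array-api-compat would dispatch to a library-specific
        # materialization here, e.g. x.compute_chunk_sizes() for dask.
        x = x.compute_chunk_sizes()
        s = x.shape
    return tuple(int(d) for d in s)
```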

@lorentzenchr
Member

Which estimators depend on data-dependent shapes?

@lithomas1
Contributor

LDA does, at the very least; there's a call to unique.

This was removed from #28588, since there's another dask bug that would prevent it from working.

@asmeurer

asmeurer commented Apr 6, 2024

How does it result in silent correctness problems? Is it because dask uses nan instead of None?

@lithomas1
Contributor

Yeah, e.g. something like

return x.shape[0] < 3

from is_multilabel
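The silent failure mode comes from IEEE nan comparison semantics: every comparison against nan evaluates to False, so a shape check on a dask array with an unknown dimension never raises an error, it just takes the wrong branch. A minimal demonstration in plain Python:

```python
nan = float("nan")  # dask reports unknown dimensions as nan

# All comparisons with nan evaluate to False:
print(nan < 3)    # False
print(nan > 3)    # False
print(nan == 3)   # False
# So a guard like `x.shape[0] < 3` silently becomes False instead of
# raising, which is how the correctness problem stays hidden.
```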

@ogrisel
Member Author

ogrisel commented Apr 11, 2024

All classifiers use xp.unique_* methods to compute self.classes_ and self.n_classes_ which are data-dependent values. The latter is typically used to shape the fitted attribute (e.g. coef_ for linear models) for downstream computation.
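The pattern looks roughly like the following, shown with eager NumPy for illustration; on a lazy array, classes.shape[0] would be unknown and the allocation would fail or misbehave:

```python
import numpy as np

y = np.array([0, 2, 2, 1, 0])
classes = np.unique(y)        # data-dependent output shape
n_classes = classes.shape[0]  # would be nan/None on a lazy array

n_features = 4
# Shaping the fitted attribute (e.g. coef_ for linear models)
# requires n_classes to be a concrete integer:
coef = np.zeros((n_classes, n_features))
print(coef.shape)             # (3, 4)
```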

@ogrisel
Member Author

ogrisel commented Apr 12, 2024

Calls compute if necessary on the array before returning a non-null shape.
(Although compute is not part of the array API, array-api-compat can call the library specific version of compute for the lazy array implementations that it knows)

Note that dask also exposes an alternative to compute: .persist. This triggers a compute but does not collect the result back as a numpy array in the driver program. Unfortunately, calling persist does not resolve the shape values.

Anyway, before making a decision, we should have a look at how other candidate Array API libraries with lazy evaluation semantics behave on those patterns (shape of unique and masked assignment with a data-dependent boolean mask). Here are the libraries to investigate:

  • pytorch expressions in a function decorated with @torch.compile
  • jax
  • tensorflow?
  • maybe MLX from Apple.
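For reference, the two patterns to probe in each of the libraries above, shown here with eager NumPy, are:

```python
import numpy as np

x = np.array([3.0, -1.0, 2.0, -1.0])

# Pattern 1: shape of unique (data-dependent output shape)
u = np.unique(x)
print(u.shape)   # (3,)

# Pattern 2: masked assignment with a data-dependent boolean mask
y = x.copy()
y[x < 0] = 0.0
print(y)         # [3. 0. 2. 0.]
```

A lazy library passes the probe if it either resolves these shapes transparently or raises a clear error, rather than propagating an unknown dimension into silent misbehavior.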

Projects
Status: Todo
Development

No branches or pull requests

6 participants