
ENH: _array_api.Generator: unified RNG interface #20549

Draft · wants to merge 1 commit into main
Conversation

mdhaber (Contributor) commented Apr 22, 2024

Reference issue

gh-18286

What does this implement/fix?

There is no random extension for the array API, but we need to generate random numbers of each array type in tests and in some stats functions. This PR drafts a unified interface to random number generation for supported backends. I've tested end-to-end with NumPy and PyTorch locally, and I've tested the subclasses for CuPy and JAX on Colab (but it would be great if somebody could run the test suite for them).

Additional information

I'm not sure if this is what we want, but I figured the best way to come to a consensus was to put something concrete out there.

I considered calling it xp_Generator... not sure how to balance the desire to put xp in front of it against wanting a class name to start with a capital letter.

Just as special provides a unified interface to some functions before we have a special function extension of the array API, it would be useful to have a unified interface to RNGs, even if the stateful generator object goes against the paradigm of some libraries. Moreover, we will need a way for users to specify the RNG to use in functions like scipy.stats.bootstrap. This should be private initially, but it's worth considering making it public at some point.

Example:

import numpy as np
import torch
from scipy._lib._array_api import Generator

rng = Generator(np, seed=2259842524510)
rng.random()  # 0.2544984223225736

rng = Generator(torch)
rng.random((2, 3), torch.float64)
# tensor([[0.9019, 0.8082, 0.6819],
#         [0.9434, 0.3858, 0.6348]], dtype=torch.float64)

Of course, we can add additional methods and options (e.g. device, a way to instantiate from an existing backend-specific generator object, etc.) as needed; this would just lay the groundwork.
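For reference, the dispatch pattern can be sketched roughly as below. This is illustrative only, not the PR's exact code; is_numpy here stands in for the backend-detection helpers in scipy._lib._array_api (the PR dispatches on is_jax(xp) in the same way), and the torch/CuPy/JAX subclasses would wrap the corresponding native generator objects in the same shape.

# Rough sketch of the dispatch pattern (illustrative, not the PR's exact code).
import numpy as np

def is_numpy(xp):
    # Stand-in for the backend-detection helpers in scipy._lib._array_api.
    return xp is np

class Generator:
    def __new__(cls, xp, seed=None):
        # Route to a backend-specific subclass based on the namespace passed in.
        if is_numpy(xp):
            return super().__new__(_Generator_numpy)
        raise NotImplementedError(f"no RNG wrapper for {xp.__name__}")

class _Generator_numpy(Generator):
    def __init__(self, xp, seed=None):
        self._xp = xp
        self._rng = np.random.default_rng(seed)

    def random(self, shape=None, dtype=None):
        shape = () if shape is None else shape
        dtype = np.float64 if dtype is None else dtype
        return self._rng.random(shape, dtype=dtype)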

mdhaber added the enhancement (A new feature or improvement) and array types (Items related to array API support and input array validation; see gh-18286) labels on Apr 22, 2024
Comment on lines +369 to +370
# elif is_jax(xp):
# return super().__new__(_Generator_jax)
mdhaber (Contributor, Author) commented:

uncomment after gh-20085 merges.

mdhaber marked this pull request as ready for review on April 22, 2024.
rkern (Member) commented Apr 22, 2024

If it's just to internally have some kind of pseudorandomness that outputs the xp-appropriate array type, I'd probably suggest just wrapping a plain np.random.Generator with something that converts the np arrays to the appropriate xp array type, eating the cost of the memory copy. That's probably good enough for scipy's internal uses, and it gives you access to the whole np.random.Generator API surface.
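For concreteness, the wrapping described here might look something like the following sketch (random_for_backend is a hypothetical helper, not part of the PR):

import numpy as np

def random_for_backend(xp, shape, seed=None):
    # Draw with a plain NumPy Generator, then convert to the target namespace,
    # accepting the memory-copy/transfer cost. The random stream is identical
    # for every backend.
    rng = np.random.default_rng(seed)
    return xp.asarray(rng.random(shape))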

mdhaber (Contributor, Author) commented Apr 22, 2024

Sure, that is what we do right now for tests, so I considered it. However, the resampling functions need only a handful of RNG methods, and we're trying to avoid conversions between array formats and especially device transfers (whether or not they appear to be the bottleneck), so I thought it would be worth trying this. Fine with me if there is a strong consensus for always using NumPy generators and converting as necessary. That said, I wish that something like this were available somewhere if the array API isn't going to standardize it (or until it does).

lucascolley (Member) commented:
Either way seems fine for now; it's definitely nice to have this work in case we want/need it in the future.

ev-br (Member) commented Apr 22, 2024

Please add a note that the random streams may differ for different backends, even for the same seed.

rgommers (Member) commented:
I think this is worth doing indeed to avoid conversions. It seems thin/simple enough to maintain, and since it's private it's easy to experiment with.

rkern (Member) commented Apr 23, 2024

Can we get a brief (representative, but nonexhaustive) catalog of the places in scipy where you want to use this? You mention that you need a handful of methods, but only one is implemented. I'd also like to see an example of what user code that uses one of the scipy functions that is implemented with this shim would look like, particularly with JAX. Shims like this have been investigated before and integration with JAX has some caveats and constraints. What exactly are the use cases that we are trying to enable?

That said, I wish that something like this were available somewhere if the array API isn't going to standardize it (or until it does).

That's a thing that's not going to happen, for good reasons, so I would encourage you not to think of this effort as a stand-in for that. Those reasons would still apply here. Instead, if we can apply more constraints to address a more focused problem, we might be able to solve that problem instead, which is why I'm asking these questions to clarify the use cases.

tylerjereddy (Contributor) left a comment:

I do find the simplicity of just coercing what we need from the NumPy implementation appealing, but don't have a sense of how much that would cost us in time vs. maintenance burden of growing the shims from this kind of design.

I ran on some GPU hardware locally and made a suggestion based on observed failures.


def random(self, shape=None, dtype=None):
    shape = () if shape is None else shape
    return self._xp.rand(shape, generator=self._rng, dtype=dtype)
tylerjereddy (Contributor) commented:
I'm seeing a few new test failures in the array API suite with these changes and torch on the GPU per SCIPY_DEVICE=cuda python dev.py test -j 32 -b all.

FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-None-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-None-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-shape1-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-shape1-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-None-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-None-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-shape1-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-shape1-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

A small patch seems to return us to baseline:

--- a/scipy/_lib/_array_api.py
+++ b/scipy/_lib/_array_api.py
@@ -393,7 +393,7 @@ class _Generator_jax(Generator):
 
 class _Generator_torch(Generator):
     def __init__(self, xp, seed=None):
-        rng = xp.Generator()
+        rng = xp.Generator(device=SCIPY_DEVICE)
         seed = rng.seed() if seed is None else seed
         rng.manual_seed(seed)
         self._rng = rng

mdhaber (Contributor, Author) commented:
Yup, that's why I suggested that we might want to add a device argument. This would be relevant for JAX, too.
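For illustration, a device-aware torch wrapper along the lines of the reviewer's patch might look like the sketch below (the class name and exact signature are hypothetical, not the PR's code):

import torch

class TorchRNG:
    # Hypothetical device-aware analogue of the PR's _Generator_torch.
    def __init__(self, seed=None, device=None):
        device = "cpu" if device is None else device
        rng = torch.Generator(device=device)
        seed = rng.seed() if seed is None else seed
        rng.manual_seed(seed)
        self._rng = rng

    def random(self, shape=None, dtype=None):
        shape = () if shape is None else shape
        # Keep the output tensor on the same device as the generator to avoid
        # the "Expected a 'cuda' device type for generator" failures seen above.
        return torch.rand(shape, generator=self._rng, dtype=dtype,
                          device=self._rng.device)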

mdhaber (Contributor, Author) commented Apr 24, 2024

What exactly are the use cases that we are trying to enable?

  1. For now, any place in tests decorated with @array_api_compatible that generates random numbers with NumPy before converting them to the desired array type. For example:
rng = numpy.random.default_rng(seed)
x = rng.random(shape, dtype=dtype)
x = xp.asarray(x)

would become

rng = Generator(xp, seed=seed)
x = rng.random(shape, dtype=dtype)

For specific examples, see e.g. tests in https://github.com/scipy/scipy/pull/19005/files. Making functions array API compatible is an active area, so many more tests are being converted like this.

  2. In the not-so-distant future, any other function in scipy that accepts a seed, random_state, or rng argument that we can otherwise make compatible with the array API might benefit from a way to generate samples on the target device. For example, the global optimizer differential_evolution accepts a seed argument so it can generate random data in each iteration of the algorithm, and it can perform many evaluations of the objective function in parallel. We might want to use a GPU to speed up evaluations of the objective function, and in that case, it would be great to also generate the random data on the GPU rather than generating it with NumPy and copying it over each iteration.

If the only objective is to save device transfers, and if we only care about CPU and GPU, we could just use CuPy or NumPy Generators depending on the target device. I'd be fine with that instead of this PR (since NumPy and CuPy generators can be used pretty much interchangeably). But this would require that the user have CuPy if they're using a GPU, and of course this would leave TPUs out.
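For illustration, that CPU/GPU-only alternative could be as simple as the following sketch (default_rng_for_device is a hypothetical helper; it assumes CuPy is installed whenever a CUDA device is requested):

import numpy as np

def default_rng_for_device(device=None, seed=None):
    if device == "cuda":
        import cupy as cp
        # CuPy's Generator mirrors the NumPy Generator API closely.
        return cp.random.default_rng(seed)
    return np.random.default_rng(seed)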

rkern (Member) commented Apr 24, 2024

For generating test data in test cases, you should generate it using numpy and convert the arrays to the appropriate xp type. We generally want the same array data to be used as the test case, and you can't get that by using native PRNGs. We aren't performance sensitive here, so we don't gain much with xp-native PRNGs.

For (2), let's actually list them out. Are all of the distributions in scipy.stats in scope? Because that's probably going to take more than a "handful" of Generator methods. Let's figure out what each of these functions is actually consuming in terms of PRNG data. Some of them only draw relatively small arrays. What's the cutoff before the device transfer makes a significant difference? The most significant benefit of using a shimmed native JAX PRNG would be to enable JAX to work its JIT magic, but I wonder whether that would even work for these functions, setting aside the PRNG.

The changes in semantics for seeding are significant between numpy and a shimmed JAX, which will make one of our sore points for documentation even sorer.

It's hard to evaluate this implementation on its own. This PR should convert a function or two to see how it works out in practice.

mdhaber (Contributor, Author) commented Apr 24, 2024

This PR should convert a function or two to see how it works out in practice.

I put this out as-is to test the temperature of the water before diving in. It seems very hot in some places, so I might want to stay out!

I'll let others decide whether we want to do this in tests. If not, then I don't think we need this right now; I'll close and we can reconsider as we get closer. Otherwise, I can add the other methods we're using and apply this to the existing test cases.

mdhaber marked this pull request as draft on April 24, 2024.
rkern (Member) commented Apr 25, 2024

I put this out as-is to test the temperature of the water before diving in.

Well, the thermometer is an actual usage. I have no strong opinion about this implementation until I can see it in its intended use. It might be great! It might be solving the wrong problem. It might not be possible. It might cost too much. We can't predict which from just the bare unused implementation.

mdhaber (Contributor, Author) commented Apr 25, 2024

What's the cutoff before the device transfer makes a significant difference?

TLDR:
Overall, I don't think we'd be doing this for significant time savings in tests. It probably wouldn't help performance noticeably, but it doesn't look like it would hurt much either.

For other applications, like bootstrap, in which we need to generate a lot of random integers (say, 9999 resamples x data size), there would definitely be performance benefits to using the target backend. I'm not sure what fraction of the overall speedup this would account for, and I'm not ready to update the experiments in mdhaber#63 right now.

I'll add this to some tests if that would still be relevant, but I can't use it in any real functions right now. There is lower hanging fruit to pick in array API conversions before I get to those functions.


Details:
I ran some experiments under the assumption that this is for unit tests that instantiate the Generator and use it once to generate data with the custom Generator.random. This sort of use is pretty common: to ensure that they get the same data every time, existing tests often instantiate a NumPy Generator or RandomState (or set the seed of the global RandomState) and use it only once at the beginning to generate data. I haven't read about whether the libraries use lazy techniques, so I tried both a) just using the random function and b) using the random function, summing the elements, and converting the result to a float.

For example, when generating data with NumPy and transferring to CuPy, the relevant code looks like:

rng_np = Generator(np, 2349832)
x = cp.asarray(rng_np.random(size, dtype=cp.float64))
# result_np.append(float(cp.sum(x))) # where result_np is a list

and for generating the data directly with CuPy:

rng_cp = Generator(cp, 2349832)
x = rng_cp.random(size, dtype=cp.float64)
# result_cp.append(float(cp.sum(x)))
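For reference, a timing harness along these lines could be used for such a comparison (a sketch; time_draw is not part of the PR, and GPU backends need a device synchronization before reading the clock):

import time
import cupy as cp

def time_draw(draw, repeats=100):
    cp.cuda.Device().synchronize()      # flush pending GPU work first
    start = time.perf_counter()
    for _ in range(repeats):
        draw()
    cp.cuda.Device().synchronize()      # wait for the last kernel to finish
    return (time.perf_counter() - start) / repeats

# e.g. time_draw(lambda: rng_cp.random(10**5, dtype=cp.float64))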

The conclusions were not affected by whether I included the commented line.

For CuPy, the cutoff at which CuPy begins to have a speed advantage is ~10^5 elements. The overhead of instantiating the CuPy generator is ~500μs vs 30μs for the NumPy Generator, so overall, this would slow things down slightly over the course of hundreds or thousands of tests.

Torch is a bit faster for up to 5e3 elements before NumPy catches up; past that, NumPy stays faster by a roughly constant amount. I think this would speed things up a little for Torch in tests, but the difference would almost certainly be too small to notice in a single run.

For JAX, the story is similar to CuPy under the same conditions. However, with JAX, the thing that takes the time is generating the initial key. We could potentially generate the key only once per file and reuse it. Still, this would complicate things a bit, and it wouldn't save much time.
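For context, the key-reuse pattern referred to here looks roughly like the following in JAX (a sketch; the seed and shape are arbitrary):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # needed for float64 draws

key = jax.random.PRNGKey(2349832)           # the expensive part: create once per file
key, subkey = jax.random.split(key)         # cheap: derive a fresh subkey per draw
x = jax.random.uniform(subkey, (1000,), dtype=jnp.float64)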

rkern (Member) commented Apr 25, 2024

Cool. For generating test data in unit tests, I think the desideratum of creating the same array data regardless of the xp array type matters more than anything else, and it seems like there isn't much we actually gain for this use case (the performance seems mostly a wash, given the high frequency of Generator instantiation vs. the relatively small array sizes drawn in the unit test suite).

For Generator-using functions, let's tackle the design when you're ready for it.

Labels: array types, enhancement, scipy._lib