ENH: _array_api.Generator: unified RNG interface #20549
base: main
Conversation
# elif is_jax(xp):
#     return super().__new__(_Generator_jax)
uncomment after gh-20085 merges.
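For reference, a minimal self-contained sketch of the kind of __new__-based dispatch the commented-out lines point at (the helper predicate and the NumPy subclass below are simplified stand-ins, not this PR's exact code):

```python
import numpy as np


def _is_numpy(xp):
    # Simplified stand-in for the namespace-inspection helpers in _array_api.
    return xp.__name__ == "numpy"


class Generator:
    """Return a backend-specific subclass based on the array namespace."""

    def __new__(cls, xp, seed=None):
        if _is_numpy(xp):
            return super().__new__(_Generator_numpy)
        # elif is_jax(xp):                     # uncomment after gh-20085 merges
        #     return super().__new__(_Generator_jax)
        raise NotImplementedError(f"no Generator shim for {xp.__name__}")

    def __init__(self, xp, seed=None):
        self._xp = xp


class _Generator_numpy(Generator):
    def __init__(self, xp, seed=None):
        super().__init__(xp, seed)
        self._rng = np.random.default_rng(seed)

    def random(self, shape=None, dtype=np.float64):
        return self._rng.random(shape, dtype=dtype)


rng = Generator(np, seed=123)
print(rng.random((2, 3)).shape)  # (2, 3)
```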
If it's just to internally have some kind of pseudorandomness that outputs the right array type, couldn't we keep generating with NumPy and converting with xp.asarray?
Sure, that is what we do right now for tests, so I considered it. However, the resampling functions need only a handful of RNG methods, and we're trying to avoid conversions between array formats and especially device transfers (whether or not they appear to be the bottleneck), so I thought it would be worth trying this. Fine with me if there is a strong consensus for always using NumPy generators and converting as necessary. That said, I wish that something like this were available somewhere if the array API isn't going to standardize it (or until it does).
Either way seems fine for now - definitely nice to have this work in case we want/need it in the future.
Please add a note that the random streams may differ for different backends, even for the same seed.
I think this is worth doing indeed to avoid conversions. It seems thin/simple enough to maintain, and since it's private it's easy to experiment with.
Can we get a brief (representative, but nonexhaustive) catalog of the places in scipy where you want to use this? You mention that you need a handful of methods, but only one is implemented. I'd also like to see an example of what user code that uses one of the scipy functions that is implemented with this shim would look like, particularly with JAX. Shims like this have been investigated before and integration with JAX has some caveats and constraints. What exactly are the use cases that we are trying to enable?
That's a thing that's not going to happen, for good reasons, so I would encourage you not to think of this effort as a stand-in for that. Those reasons would still apply here. Instead, if we can apply more constraints to address a more focused problem, we might be able to solve that problem, which is why I ask these questions to clarify the use cases.
I do find the simplicity of just coercing what we need from the NumPy implementation appealing, but don't have a sense of how much that would cost us in time vs. maintenance burden of growing the shims from this kind of design.
I ran on some GPU hardware locally and made a suggestion based on observed failures.
def random(self, shape=None, dtype=None):
    shape = () if shape is None else shape
    return self._xp.rand(shape, generator=self._rng, dtype=dtype)
I'm seeing a few new test failures in the array API suite with these changes and torch on the GPU, per SCIPY_DEVICE=cuda python dev.py test -j 32 -b all.
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-None-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-None-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-shape1-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float32-shape1-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-None-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-None-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-shape1-None-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
FAILED scipy/_lib/tests/test_array_api.py::TestArrayAPI::test_Generator[float64-shape1-0-torch] - RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
A small patch seems to return us to baseline:
--- a/scipy/_lib/_array_api.py
+++ b/scipy/_lib/_array_api.py
@@ -393,7 +393,7 @@ class _Generator_jax(Generator):
 class _Generator_torch(Generator):
     def __init__(self, xp, seed=None):
-        rng = xp.Generator()
+        rng = xp.Generator(device=SCIPY_DEVICE)
         seed = rng.seed() if seed is None else seed
         rng.manual_seed(seed)
         self._rng = rng
Yup, that's why I suggested that we might want to add a device argument. This would be relevant for JAX, too.
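For concreteness, a rough sketch of what an explicit device argument might look like for the torch subclass (assuming torch is installed; the default value and the exact plumbing are open questions, not this PR's final design):

```python
import torch


class _Generator_torch:
    # Sketch only: accept an explicit device so the torch.Generator lives on the
    # same device as the arrays it will produce (avoiding the 'cpu' vs 'cuda'
    # generator mismatch seen above).
    def __init__(self, xp, seed=None, device="cpu"):
        self._xp = xp
        self._device = device
        rng = xp.Generator(device=device)
        seed = rng.seed() if seed is None else seed
        rng.manual_seed(seed)
        self._rng = rng

    def random(self, shape=None, dtype=None):
        shape = () if shape is None else shape
        return self._xp.rand(shape, generator=self._rng, dtype=dtype,
                             device=self._device)


rng = _Generator_torch(torch, seed=0)        # device="cuda" on a GPU build
x = rng.random((2, 3), dtype=torch.float32)
```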
rng = numpy.random.default_rng(seed)
x = rng.random(shape, dtype=dtype)
x = xp.asarray(x)

would become

rng = Generator(xp, seed=seed)
x = rng.random(shape, dtype=dtype)

For specific examples, see e.g. tests in https://github.com/scipy/scipy/pull/19005/files. Making functions array API compatible is an active area, so many more tests are being converted like this.
If the only objective is to save device transfers, and if we only care about CPU and GPU, we could just use CuPy or NumPy.
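One possible reading of that suggestion, as a rough sketch (the helper name and signature are hypothetical; CuPy is assumed to be available whenever a CUDA device is requested, and whether xp.asarray can take the CuPy array zero-copy depends on the backend):

```python
import numpy as np


def _draw_uniform(xp, shape, seed=None, device=None):
    # Hypothetical helper: generate with NumPy on CPU or CuPy on CUDA, then
    # hand the result to the target namespace with xp.asarray.
    if device is not None and "cuda" in str(device):
        import cupy as cp                    # assumed installed on CUDA setups
        rng = cp.random.default_rng(seed)
    else:
        rng = np.random.default_rng(seed)
    return xp.asarray(rng.random(shape))


x = _draw_uniform(np, (100,), seed=0)        # NumPy in, NumPy out
```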
For generating test data in test cases, you should generate them with NumPy and convert with xp.asarray. For (2), the stats functions, let's actually list them out. Are all of the distributions in scipy.stats expected to use this? The changes in semantics for seeding are significant between numpy and a shimmed JAX, which will make one of our sore points for documentation even sorer. It's hard to evaluate this implementation on its own. This PR should convert a function or two to see how it works out in practice.
I put this out as-is to test the temperature of the water before diving in. It seems very hot in some places, so I might want to stay out! I'll let others decide whether we want to use this in tests. If not, then I don't think we need this right now; I'll close and we can reconsider as we get closer. Otherwise, I can add the other methods we're using and apply this to the existing test cases.
Well, the thermometer is an actual usage. I have no strong opinion about this implementation until I can see it in its intended use. It might be great! It might be solving the wrong problem. It might not be possible. It might cost too much. We can't predict which from just the bare unused implementation.
TLDR: I'll add this to some tests if that would still be relevant, but I can't use it in any real functions (e.g. the resampling functions) right now. There is lower hanging fruit to pick in array API conversions before I get to those functions.

Details:

For example, generating data with NumPy and transferring to CuPy, the relevant code looks like:

rng_np = Generator(np, 2349832)
x = cp.asarray(rng_np.random(size, dtype=cp.float64))
# result_np.append(float(cp.sum(x)))  # where result_np is a list

and for generating the data directly with CuPy:

rng_cp = Generator(cp, 2349832)
x = rng_cp.random(size, dtype=cp.float64)
# result_cp.append(float(cp.sum(x)))

The conclusions were not affected by whether I included the commented line or not. For CuPy, the cutoff at which CuPy begins to have a speed advantage is ~10^5 elements. The overhead of instantiating the CuPy generator is ~500μs vs 30μs for the NumPy Generator, so overall, this would slow things down slightly over the course of hundreds or thousands of tests. Torch is a bit faster for up to ... For JAX, the story is similar to CuPy under the same conditions. However, with JAX, the thing that takes the time is generating the initial key. We could potentially generate the key only once per file and reuse it. Still, this would complicate things a bit, and it wouldn't save too much time.
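For readers less familiar with JAX's functional PRNG, a rough sketch of how a stateful wrapper can sit on top of it (simplified and hypothetical; not necessarily what the _Generator_jax subclass in this PR does): create the key once, then split it on every draw.

```python
import jax


class _StatefulJaxRNG:
    # Simplified sketch: create the (relatively expensive) key once, then split
    # it on every draw so successive calls return different values, mimicking a
    # stateful generator on top of JAX's functional PRNG.
    def __init__(self, seed=0):
        self._key = jax.random.PRNGKey(seed)

    def random(self, shape=()):
        self._key, subkey = jax.random.split(self._key)
        return jax.random.uniform(subkey, shape=shape)


rng = _StatefulJaxRNG(2349832)
x = rng.random((10,))   # float32 by default unless jax_enable_x64 is set
```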
Cool. For generating test data in unit tests, I think the desideratum of creating the same array data regardless of the backend wins out. For ...
Reference issue
gh-18286
What does this implement/fix?
There is no random extension for the array API, but we need to generate random numbers of each array type in tests and in some stats functions. The PR drafts a unified interface to random number generation for supported backends. I've tested end-to-end with NumPy and PyTorch locally, and I've tested the subclasses for CuPy and JAX on Colab (but it would be great if somebody could run the test suite for them).

Additional information
I'm not sure if this is what we want, but I figured the best way to come to a consensus was to put something concrete out there.
I considered calling it xp_Generator... not sure how to balance the desire to put xp in front of it against wanting a class name to start with a capital letter.

Just as special provides a unified interface to some functions before we have a special function extension of the array API, it would be useful to have a unified interface to RNGs, even if the stateful generator object goes against the paradigm of some libraries. Moreover, we will need a way for users to specify the RNG to use in functions like scipy.stats.bootstrap. This should be private initially, but it's worth considering making it public at some point.

Example:
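A representative sketch of the intended usage (the seed value is arbitrary and the private import path may change; the exact API is still up for discussion):

```python
import numpy as np
from scipy._lib._array_api import Generator   # private shim drafted in this PR

# The same two lines are meant to work with torch, cupy, or jax.numpy in place
# of np, returning an array native to that backend.
rng = Generator(np, seed=84912)
x = rng.random((3, 4), dtype=np.float64)
print(type(x), x.shape)
```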
Of course, we can add additional methods and options (e.g. device, a way to instantiate from an existing backend-specific generator object, etc.) as needed; this would just lay the groundwork.