Support reshape/concatenate/broadcast on PRNGKeyArrays. #8381
Conversation
Force-pushed from c9ed671 to 4b275a1
Updated with fixes for mypy/docstring testing, PTAL :) Even though the question of which set of operations to finally support is still open, I think this is an OK incremental change by itself.
Great work! The code is nice and clean. But even better is the PR message! That taxonomy is a fantastic leap forward for the broader effort here.
@mattjj said it well. Really nice work on this, and thanks for laying things out so clearly in the PR description!
Just a note: it may not be worth building a new high-level dtype abstraction into JAX to handle this at the primitive level, but for what it's worth, this is exactly the sort of thing we would need to properly support strings.
Fixes google#1012

The idea here is a simpler alternative to google#3263, based on a custom pytree (`tree_math.Vector`) rather than introducing a new final-style transformation. Eventually, it should be possible to write something like:

```python
@functools.partial(tree_math.wrap, vector_argnums=(1,))
def odeint_midpoint(f, y0, h, num_steps):
  F = tree_math.unwrap(f, vector_argnums=(0,))

  def step_fun(i, y):
    return y + h * F(y + F(y, i * h) / 2, (i + 0.5) * h)

  return lax.fori_loop(0, num_steps, step_fun, y0)
```

Aside from `wrap` and `unwrap`, this is exactly how you would write a simple ODE solver in JAX without support for pytrees. We currently [do something very similar](https://github.com/google/jax-cfd/blob/a92862fb757d122e7c5eee23a3b783997a2aeafe/jax_cfd/spectral/time_stepping.py#L92) for implementing ODE solvers in JAX-CFD.

The upside of this approach is that writing a custom pytree is easy and entirely decoupled from JAX's other transformations, so we get support for JAX transformations (e.g., `grad`/`jit`/`vmap`) and control-flow primitives (e.g., `while_loop`) for free. The downside is that it won't be possible to use standard `jax.numpy` functions on `Vector` objects unless we make `jax.numpy` aware of `tree_math` (à la google#8381). Instead, I've added a few specialized helper functions like `where` and `maximum`. For now, this seems like an acceptable trade-off, given that the core use case for tree math is making it easy to write new algorithms from scientific computing, for which infix arithmetic is important enough to justify minor deviations from standard `jax.numpy`.

The other change from google#3263 is that I've only written a simple `Vector` class for now, rather than attempting to support arbitrary higher-dimensional arrays. This also considerably simplifies the implementation. In the future, we can imagine adding a `Matrix` class for methods that need to keep track of multiple vectors (e.g., ODE solvers, GMRES, L-BFGS, Lanczos, etc.).
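To make the pytree-wrapper idea above concrete, here is a minimal, hypothetical sketch of a `Vector`-like class in plain Python (dicts and lists standing in for pytrees; names and structure are illustrative only, not the actual `tree_math` implementation):

```python
# Sketch: a thin wrapper that forwards elementwise arithmetic to the
# leaves of a nested dict/list structure ("pytree"). Illustrative only.
import operator

def _tree_map(fn, a, b=None):
    # Recursively apply fn over matching dict/list structures; leaves
    # get fn(leaf) for unary ops or fn(leaf_a, leaf_b) for binary ops.
    if isinstance(a, dict):
        return {k: _tree_map(fn, a[k], None if b is None else b[k]) for k in a}
    if isinstance(a, (list, tuple)):
        if b is None:
            return type(a)(_tree_map(fn, x) for x in a)
        return type(a)(_tree_map(fn, x, y) for x, y in zip(a, b))
    return fn(a) if b is None else fn(a, b)

class Vector:
    def __init__(self, tree):
        self.tree = tree

    def _binop(self, op, other):
        if isinstance(other, Vector):
            return Vector(_tree_map(op, self.tree, other.tree))
        # Scalars broadcast against every leaf.
        return Vector(_tree_map(lambda x: op(x, other), self.tree))

    def __add__(self, other): return self._binop(operator.add, other)
    def __mul__(self, other): return self._binop(operator.mul, other)
    __rmul__ = __mul__

v = Vector({"a": 1.0, "b": [2.0, 3.0]})
w = (v + v) * 0.5   # arithmetic acts leafwise on the pytree
print(w.tree)       # {'a': 1.0, 'b': [2.0, 3.0]}
```

Because the wrapper never inspects leaf values, the same pattern composes with any tracing machinery that understands the container structure, which is the decoupling argument made above.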
Force-pushed from 4b275a1 to e14fea3
Thanks for the review! (And sorry for the delay.) I think there is still an open question about which set of operations to eventually support, but we can address that in a follow-up.
Add a guard to the nan_error_rule so it does not call jnp.isnan on keys. The primitive with a key in its output was slice_p, which I don't think can ever generate a NaN. I removed a few primitives from the NaN rules which I believe only index or reshape; PTAL at whether removing those makes sense. In general, the nan-checked primitive set could use another review. In fact, these kinds of "value polymorphic" primitives are the only ones we allow for keys right now (see google#8381), so if we remove those we could remove the guard entirely.
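The shape of the guard described above can be sketched with plain NumPy: only call `isnan` when the dtype can actually represent a NaN, so opaque key-like or integer dtypes are skipped entirely. The function name is hypothetical, not JAX's actual rule:

```python
import numpy as np

def contains_nan(x):
    # Only inexact (float/complex) dtypes can hold NaN; for anything
    # else (ints, bools, or an opaque key dtype) skip the check.
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.inexact):
        return bool(np.isnan(x).any())
    return False

print(contains_nan(np.array([1.0, np.nan])))  # True
print(contains_nan(np.array([1, 2, 3])))      # False: isnan never called
```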
`PRNGKeyArray`s (which are currently behind the `enable_custom_prng` flag) do not support any `jnp` operations. This makes sense for operations which act on the value of a key (we shouldn't support operations like `key0 + key1`), but we want to keep supporting operations which are agnostic about the value of a key (eg. a `broadcast`).

This PR adds a way to overload `jnp.reshape`, `jnp.concatenate` and `jnp.broadcast`, then overloads them for `PRNGKeyArray`s. The claim is that overloading these three operations is sufficient to support most `jnp` functions users care about.

Given we support `reshape`, `concatenate` and `broadcast`, there are some `jnp` functions we get for free (eg. `tile`), but there are some which are implemented in terms of `lax` primitives like `lax.expand_dims` instead. The motivation for this is described in #3217 (preserve the notion of axis identity for reshapes which add/remove singleton dimensions). I see two options to support these functions:

1. Overload `expand_dims` and `squeeze` for `PRNGKeyArray`s
2. Change `jnp.expand_dims`/`squeeze` to use `reshape`/`broadcast` instead

Happy to discuss what operations should be supported here, and how best to go about it :)
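The second option amounts to rewriting `expand_dims`/`squeeze` as pure shape arithmetic on top of `reshape`, so an array type only has to implement `reshape` to get both. A rough NumPy sketch of that rewrite (function names are hypothetical, not JAX's actual implementation):

```python
import numpy as np

def expand_dims_via_reshape(x, axis):
    # Insert a singleton dimension at `axis` using only reshape.
    axis = axis % (x.ndim + 1)          # support negative axes
    return x.reshape(x.shape[:axis] + (1,) + x.shape[axis:])

def squeeze_via_reshape(x, axis):
    # Drop a singleton dimension at `axis` using only reshape.
    assert x.shape[axis] == 1, "can only squeeze a singleton dimension"
    return x.reshape(x.shape[:axis] + x.shape[axis + 1:])

x = np.zeros((3, 4))
print(expand_dims_via_reshape(x, 1).shape)  # (3, 1, 4)
print(squeeze_via_reshape(expand_dims_via_reshape(x, 1), 1).shape)  # (3, 4)
```

Note the trade-off raised above: this loses the axis-identity information that `lax.expand_dims` preserves per #3217, which is exactly why the choice between the two options is not obvious.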
For completeness, here's an enumeration of `jnp` functions which we could support for `PRNGKeyArray`s, and what needs to be overloaded to support them. I see 5 categories of `jnp` functions, in terms of the base ops they are implemented with:

1. … (eg. `append`)
2. … (eg. `stack`)
3. … (eg. `delete`)
4. … (`jnp.array(PRNGKeyArray)`?)

(Incomplete) list of `jnp` operations which we could support for `PRNGKeyArray`s and their categories (see above):
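As an illustration of the "for free" claim made earlier: `append` reduces entirely to `reshape` plus `concatenate`, so any array type that supports those two base ops gets it without further overloads. A hedged NumPy sketch (function name hypothetical, NumPy standing in for `jnp`):

```python
import numpy as np

def append_via_concatenate(arr, values, axis=None):
    # Mirrors the shape of jnp.append: with axis=None both inputs are
    # flattened (a reshape) and then joined (a concatenate).
    values = np.asarray(values)
    if axis is None:
        arr = arr.reshape(-1)
        values = values.reshape(-1)
        axis = 0
    return np.concatenate([arr, values], axis=axis)

out = append_via_concatenate(np.array([[1, 2], [3, 4]]), [5, 6])
print(out)  # [1 2 3 4 5 6]
```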