Support reshape/concatenate/broadcast on PRNGKeyArrays. #8381

Merged
merged 1 commit into from Nov 16, 2021


LenaMartens commented Oct 27, 2021

PRNGKeyArrays (currently behind the enable_custom_prng flag) do not support any jnp operations. That makes sense for operations that act on the value of a key (we shouldn't support something like key0 + key1), but we want to keep supporting operations that are agnostic to the value of a key (e.g. a broadcast).

This PR adds a way to overload jnp.reshape, jnp.concatenate and jnp.broadcast, then overloads them for PRNGKeyArrays. The claim is that overloading these three operations is sufficient to support most jnp functions users care about.
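To make the idea concrete, here is a minimal pure-Python sketch of this kind of overloading. Everything in it is hypothetical: `ToyKeyArray` and the three module-level functions are stand-ins invented for illustration, not JAX's PRNGKeyArray or the mechanism this PR actually adds. It only shows the shape of the approach: value-agnostic operations dispatch to the key type, while arithmetic is rejected.

```python
# Toy illustration only. ToyKeyArray and the dispatch helpers below are
# hypothetical stand-ins, not JAX's PRNGKeyArray or its overload mechanism.
import math


class ToyKeyArray:
    """An array of opaque keys, stored flat in row-major order."""

    def __init__(self, keys, shape):
        assert len(keys) == math.prod(shape)
        self.keys = list(keys)
        self.shape = tuple(shape)

    # Value-agnostic operations: they only rearrange or copy keys.
    def _reshape(self, newshape):
        return ToyKeyArray(self.keys, newshape)

    def _concatenate(self, others):
        # Simplification: concatenation along axis 0 only.
        arrays = [self, *others]
        assert all(a.shape[1:] == self.shape[1:] for a in arrays)
        keys = [k for a in arrays for k in a.keys]
        lead = sum(a.shape[0] for a in arrays)
        return ToyKeyArray(keys, (lead, *self.shape[1:]))

    def _broadcast_to(self, shape):
        # Simplification: only new leading dimensions are supported.
        shape = tuple(shape)
        n_new = len(shape) - len(self.shape)
        assert n_new >= 0 and shape[n_new:] == self.shape
        copies = math.prod(shape[:n_new])
        return ToyKeyArray(self.keys * copies, shape)

    # Value-dependent operations are deliberately rejected.
    def __add__(self, other):
        raise TypeError("arithmetic on keys is not supported")


def reshape(a, newshape):
    if isinstance(a, ToyKeyArray):
        return a._reshape(newshape)
    raise NotImplementedError("plain-array path omitted in this sketch")


def concatenate(arrays):
    if isinstance(arrays[0], ToyKeyArray):
        return arrays[0]._concatenate(arrays[1:])
    raise NotImplementedError("plain-array path omitted in this sketch")


def broadcast_to(a, shape):
    if isinstance(a, ToyKeyArray):
        return a._broadcast_to(shape)
    raise NotImplementedError("plain-array path omitted in this sketch")
```

With this in place, `broadcast_to(k, (3, 4))` on a key array of shape `(4,)` yields a key array of shape `(3, 4)`, while `k + k` raises a `TypeError`.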

Given that we support reshape, concatenate and broadcast, we get some jnp functions for free (e.g. tile), but others are implemented in terms of lax primitives like lax.expand_dims instead. The motivation for this is described in #3217 (preserve the notion of axis identity for reshapes which add/remove singleton dimensions).

I see two options to support these functions:

  • overload expand_dims and squeeze for PRNGKeyArrays
  • rewrite jnp.expand_dims/squeeze to use reshape/broadcast instead
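As a rough sketch of what the second option amounts to, the helpers below compute the target shape that a plain reshape would need in order to emulate expand_dims and squeeze for a single axis. The function names are made up for illustration and are not part of jax.numpy. Note that a reshape built this way no longer records which axis was added or removed, which is the axis-identity concern from #3217.

```python
def expand_dims_shape(shape, axis):
    """Shape obtained by inserting a size-1 dimension at `axis`."""
    ndim = len(shape) + 1
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of bounds for {ndim} dimensions")
    axis %= ndim  # normalize negative axes
    return (*shape[:axis], 1, *shape[axis:])


def squeeze_shape(shape, axis):
    """Shape obtained by removing the size-1 dimension at `axis`."""
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of bounds for {ndim} dimensions")
    axis %= ndim  # normalize negative axes
    if shape[axis] != 1:
        raise ValueError("can only squeeze dimensions of size 1")
    return (*shape[:axis], *shape[axis + 1:])
```

For example, an expand_dims on a `(2, 3)` array at axis 0 becomes a reshape to `expand_dims_shape((2, 3), 0)`, and the matching squeeze becomes a reshape back to `(2, 3)`.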

Happy to discuss what operations should be supported here, and how best to go about it :)


For completeness, here's an enumeration of jnp functions which we could support for PRNGKeyArrays, and what needs to be overloaded to support them.

I see 5 categories of jnp functions, in terms of the base ops they are implemented with:

  1. operations which are implemented in terms of reshape/concatenate/broadcast, and which we get for free (e.g. append)
  2. operations which use expand_dims/squeeze, but could be implemented with reshapes (e.g. stack)
  3. operations which use indexing (which we probably don't want to support?) (e.g. delete)
  4. operations which create/cast to arrays (should we support jnp.array(PRNGKeyArray)?)
  5. transposes (shuffling of elements)

(Incomplete) list of jnp operations which we could support for PRNGKeyArrays and their categories (see above):

| jnp op | implemented with.. | category |
| --- | --- | --- |
| `atleast_1d` | reshape+array | 4 |
| `atleast_2d` | expand_dims+array | 4 |
| `append(arr, values[, axis])` | concatenate | 1 |
| `array(object[, dtype, copy, order, ndmin, …])` | array | 4 |
| `array_split(ary, indices_or_sections[, axis])` | split=slice | 3 |
| `asarray(a[, dtype, order])` | array | 4 |
| `broadcast_arrays(*args)` | broadcast_to | 1 |
| `broadcast_to(arr, shape)` | broadcast | 1 |
| `choose(a, choices[, out, mode])` | broadcast+index | 3 |
| `compress(condition, a[, axis, out])` | index | 3 |
| `concatenate(arrays[, axis])` | concatenate | 1 |
| `delete(arr, obj[, axis])` | ravel+index | 3 |
| `empty_like(a[, dtype, shape])` | zeros_like | 4 |
| `expand_dims(a, axis)` | expand_dims (could be broadcast/reshape?) | 2 |
| `flip(m[, axis])` | lax.rev | 5 |
| `fliplr(m)` | flip=rev | 5 |
| `flipud(m)` | flip=rev | 5 |
| `hsplit(ary, indices_or_sections)` | split=slice | 3 |
| `hstack(tup)` | concatenate+atleast_1dim | 2 |
| `insert(arr, obj, values[, axis])` | indexing | 3 |
| `moveaxis(a, source, destination)` | lax.transpose | 5 |
| `ones_like(a[, dtype, shape])` | lax.full_like | 4 |
| `ravel(a[, order])` | reshape | 1 |
| `repeat(a, repeats[, axis, total_repeat_length])` | take | 3 |
| `reshape(a, newshape[, order])` | reshape | 1 |
| `resize(a, new_shape)` | reshape+tile+indexing | 3 |
| `roll(a, shift[, axis])` | indexing | 3 |
| `rollaxis(a, axis[, start])` | move_axis | 5 |
| `row_stack(tup)` | vstack | 2 |
| `split(ary, indices_or_sections[, axis])` | lax.slice | 3 |
| `squeeze(a[, axis])` | lax.squeeze | 2 |
| `stack(arrays[, axis, out])` | concatenate+expand_dims | 2 |
| `swapaxes(a, axis1, axis2)` | lax.transpose | 5 |
| `take(a, indices[, axis, out, mode])` | index | 3 |
| `take_along_axis(arr, indices, axis)` | index | 3 |
| `transpose(a[, axes])` | lax.transpose | 5 |
| `tile` | broadcast_to+reshape | 1 |
| `vsplit(ary, indices_or_sections)` | split | 3 |
| `vstack(tup)` | concatenate+expand_dims | 2 |
| `zeros_like(a[, dtype, shape])` | lax.full_like | 4 |
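One way to read the table is as a dependency map: a jnp function works on keys once every base operation it is built from has been overloaded. The sketch below encodes a handful of rows that way. The dictionary is a hand-copied, simplified subset (for instance, `broadcast_to` is folded into `broadcast`, and `lax.` prefixes are dropped), and the `supported` helper is hypothetical, so treat it as an illustration of the reasoning rather than an inventory of what JAX supports.

```python
# Simplified subset of the table: jnp function -> base operations it needs.
# Names are normalized by hand (broadcast_to -> broadcast, lax.* prefixes
# dropped); this is an illustrative assumption, not generated from JAX.
IMPLEMENTED_WITH = {
    "append": {"concatenate"},
    "broadcast_to": {"broadcast"},
    "ravel": {"reshape"},
    "tile": {"broadcast", "reshape"},
    "stack": {"concatenate", "expand_dims"},
    "squeeze": {"squeeze"},
    "transpose": {"transpose"},
    "take": {"index"},
}


def supported(overloaded):
    """Functions whose base operations are all in `overloaded`."""
    return sorted(
        name for name, needs in IMPLEMENTED_WITH.items() if needs <= overloaded
    )


# What this PR's three overloads cover in the subset above.
base_ops = {"reshape", "concatenate", "broadcast"}
print(supported(base_ops))

# What additionally overloading expand_dims and squeeze would add.
print(supported(base_ops | {"expand_dims", "squeeze"}))
```

In this subset, the three overloads cover the category 1 rows, and adding expand_dims and squeeze brings in the category 2 rows; indexing and transposes stay out of reach either way.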

LenaMartens commented

Updated with fixes for mypy/docstring testing, PTAL :)

Even though there is still the open question of what set of operations to finally support, I think this is an OK incremental change by itself.

mattjj left a comment

Great work! The code is nice and clean. But even better is the PR message! That taxonomy is a fantastic leap forward for the broader effort here.

Review thread on jax/numpy/__init__.py (outdated, resolved)
google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Nov 8, 2021
froystig left a comment

@mattjj said it well. Really nice work on this, and thanks for laying things out so clearly in the PR description!

shoyer commented Nov 10, 2021

Just a note: if PRNGKeyArray was in NumPy, we could imagine implementing it as an array with a new data type, i.e., with data stored in a vector of two elements like np.dtype('u4,u4').

It may not be worth building in a new high-level dtype abstraction into JAX to handle this at the primitive level, but for what it's worth this is exactly the sort of thing we would need to properly support strings.
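For readers unfamiliar with NumPy structured dtypes, the snippet below shows roughly what that would look like. It is only an illustration of the analogy using plain NumPy (assumed to be installed), with arbitrary example values; it says nothing about how JAX stores keys.

```python
import numpy as np

# A "key" element made of two unsigned 32-bit words, as in the comment above.
key_dtype = np.dtype("u4,u4")

keys = np.zeros((4,), dtype=key_dtype)
keys["f1"] = np.arange(4)  # f0/f1 are NumPy's default field names

# Value-agnostic operations work on the structured array as a whole.
reshaped = keys.reshape(2, 2)
stacked = np.concatenate([keys, keys])
tiled = np.broadcast_to(keys, (3, 4))

# Arithmetic has no meaning for this dtype: NumPy defines no `add` loop for it.
try:
    keys + keys
    arithmetic_rejected = False
except TypeError:
    arithmetic_rejected = True
```

Each element occupies 8 bytes (`key_dtype.itemsize`), and shape-only operations treat it as a single opaque value, which is the behaviour the PR is after for keys.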

shoyer added a commit to shoyer/jax that referenced this pull request Nov 10, 2021
Fixes google#1012

The idea here is a simpler alternative to google#3263, based on a custom
pytree (`tree_math.Vector`) rather than introducing a new final-style
transformation.

Eventually, it should be possible to write something like:
```python
import functools
from jax import lax
import tree_math

@functools.partial(tree_math.wrap, vector_argnums=(1,))
def odeint_midpoint(f, y0, h, steps):
    F = tree_math.unwrap(f, vector_argnums=(0,))
    def step_fun(i, y):
        # Midpoint rule: half step to estimate the slope, then a full step.
        return y + h * F(y + h * F(y, i * h) / 2, (i + 0.5) * h)
    return lax.fori_loop(0, steps, step_fun, y0)
```
Aside from `wrap` and `unwrap`, this is exactly how you would write a
simple ODE solver in JAX without support for PyTrees. We currently
[do something very similar](https://github.com/google/jax-cfd/blob/a92862fb757d122e7c5eee23a3b783997a2aeafe/jax_cfd/spectral/time_stepping.py#L92)
for implementing ODE solvers in JAX-CFD.
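For reference, here is the same explicit midpoint scheme on plain Python floats, with no JAX, `tree_math`, or pytrees involved. It is a standalone sketch written for this note rather than code from the commit, using the `f(y, t)` argument order of the snippet above.

```python
import math


def odeint_midpoint(f, y0, h, steps):
    """Integrate dy/dt = f(y, t) from t = 0 with the explicit midpoint rule."""
    y = y0
    for i in range(steps):
        t = i * h
        y_mid = y + 0.5 * h * f(y, t)      # half step using the slope at t
        y = y + h * f(y_mid, t + 0.5 * h)  # full step using the midpoint slope
    return y


# dy/dt = y with y(0) = 1 has the exact solution y(1) = e.
approx_e = odeint_midpoint(lambda y, t: y, 1.0, 0.01, 100)
error = abs(approx_e - math.e)
```

The wrapped version differs only in that `y` may be an arbitrary pytree, which is why infix arithmetic on `Vector` objects matters for this use case.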

The upside of this approach is that writing a custom pytree is easy and
entirely decoupled from JAX's other transformations. So we get support
for JAX transformations (e.g., `grad`/`jit`/`vmap`) and control flow
primitives (e.g., `while_loop`) for free.

The downside is that it won't be possible to use standard `jax.numpy`
functions on `Vector` objects, unless we make `jax.numpy` aware of
`tree_math` (ala google#8381). Instead, I've added a few specialized helper
functions like `where` and `maximum`. For now, this seems like an
acceptable trade-off, given that the core use-case for tree math is to
make it easy to write new algorithms from scientific computing, for
which infix arithmetic is important enough that it should easily justify
minor deviations from standard `jax.numpy`.

The other change from google#3263 is that I've only written a simple "Vector"
class for now rather than making an attempt to support arbitrary higher
dimensional arrays. This also considerably simplifies the
implementation. In the future, we can imagine adding a "Matrix" class,
for methods that need to keep track of multiple vectors (e.g., ODE
solvers, GMRES, L-BFGS, Lanczos, etc).

LenaMartens commented

Thanks for the review! (and sorry for the delay)

I think there is still an open question about which set of operations to eventually support, but we can do this in a follow-up.

copybara-service bot merged commit 9e09b51 into google:main Nov 16, 2021
LenaMartens added a commit to LenaMartens/jax that referenced this pull request Nov 9, 2022
Add a guard to the nan_error_rule to not call jnp.isnan on keys.

The primitive with a key in the output was slice_p, which I don't think
can ever generate a NaN? I removed a few primitives from the NaN rules
which I think only index or reshape, PTAL if removing those makes sense.
In general, the nan_checked primitive set could use another review.

In fact, I think these kinds of "value polymorphic" primitives are the
only ones we allow for keys right now (see google#8381), so if we remove those
we could remove the guard entirely.
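
The kind of guard described here can be sketched in a few lines of plain Python. The names below (`ToyKey`, `any_nan`) are hypothetical and this is not the code from that commit; it only illustrates skipping key-typed outputs before running a NaN check on the rest.

```python
import math


class ToyKey:
    """Hypothetical stand-in for a PRNG key: opaque data, never NaN-checked."""

    def __init__(self, data):
        self.data = data


def any_nan(outputs):
    """Return True if any non-key output is NaN."""
    for out in outputs:
        if isinstance(out, ToyKey):
            continue  # guard: an isnan check has no meaning for keys
        if math.isnan(out):
            return True
    return False
```

Without the `isinstance` check, `math.isnan` would raise a `TypeError` on the key object, which mirrors the failure mode the guard is meant to avoid.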
Labels
cla: yes pull ready Ready for copybara import and testing