
Prefer jnp.tile over concatenate #10221

Merged
merged 1 commit into from Apr 18, 2022
Conversation

lgeiger
Contributor

@lgeiger lgeiger commented Apr 11, 2022

jnp.tile only uses broadcasting, which tends to be preferred, so this PR replaces concatenate with jnp.tile in the cases where that is possible.
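For a 1-D array the two spellings are equivalent; a minimal NumPy sketch (jnp.tile and jnp.concatenate follow the same NumPy semantics):

```python
import numpy as np

# For a 1-D array, tiling by 2 and concatenating the array with itself
# produce the same result; tile expresses it via broadcasting alone.
key = np.array([1, 2], dtype=np.uint32)

tiled = np.tile(key, 2)                    # broadcasting-based
concatenated = np.concatenate([key, key])  # explicit concatenation

assert np.array_equal(tiled, concatenated)  # both are [1, 2, 1, 2]
```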

@YouJiacheng
Contributor

I think we can use jax.lax.broadcast_in_dim + jax.lax.reshape for faster tracing time. There is a lot of Python code in jnp.tile and jnp.broadcast_to (which jnp.tile uses).
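A rough sketch of that suggestion, emulated here with NumPy's broadcast_to and reshape, which mirror what lax.broadcast_in_dim with broadcast_dimensions=(1,) followed by lax.reshape would compute for this case (the helper name duplicate_key is hypothetical):

```python
import numpy as np

def duplicate_key(halfkey):
    # Emulates lax.broadcast_in_dim(halfkey, (2, n), broadcast_dimensions=(1,))
    # followed by lax.reshape to (2 * n,), sidestepping the Python-level
    # dispatch inside jnp.tile / jnp.broadcast_to.
    n = halfkey.shape[0]
    stacked = np.broadcast_to(halfkey, (2, n))  # two broadcast copies of halfkey
    return stacked.reshape(2 * n)

duplicate_key(np.array([7, 9], dtype=np.uint32))  # -> [7, 9, 7, 9]
```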

```
@@ -580,7 +580,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):

 def _rbg_seed(seed: int) -> jnp.ndarray:
   halfkey = threefry_seed(seed)
-  return jnp.concatenate([halfkey, halfkey])
+  return jnp.tile(halfkey, 2)
```
Contributor

concatenate repeats along axis 0, while tile repeats along axis -1. Are you sure halfkey is a 1-D array?
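The distinction matters for arrays with more than one dimension; a quick NumPy illustration (jnp follows the same semantics):

```python
import numpy as np

m = np.arange(4).reshape(2, 2)
# concatenate repeats along axis 0; tile (with a scalar rep count) repeats
# along the last axis, so the two diverge for 2-D inputs.
assert np.concatenate([m, m]).shape == (4, 2)
assert np.tile(m, 2).shape == (2, 4)

# For a 1-D array the results coincide:
v = np.array([1, 2])
assert np.array_equal(np.tile(v, 2), np.concatenate([v, v]))
```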

Contributor Author

Yes, this seems to be the case:

jax/jax/_src/prng.py, lines 246 to 256 in 3bfa6af:

```
def threefry_seed(seed: int) -> jnp.ndarray:
  """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
```
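The bit-casting that the docstring describes can be sketched as a standalone NumPy function (seed_to_key is a hypothetical illustration, not JAX's actual implementation):

```python
import numpy as np

def seed_to_key(seed: int) -> np.ndarray:
    # Split a 64-bit seed into two uint32 words, high word first;
    # a 32-bit seed ends up zero-padded in the high word.
    hi = np.uint32((seed >> 32) & 0xFFFFFFFF)
    lo = np.uint32(seed & 0xFFFFFFFF)
    return np.array([hi, lo], dtype=np.uint32)

key = seed_to_key(42)
assert key.shape == (2,)        # matches the documented shape
assert key.tolist() == [0, 42]  # 32-bit seed padded with zeros
```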

Collaborator

@jakevdp jakevdp Apr 11, 2022

I think this will probably cause issues with custom PRNG keys, which don't concatenate like regular arrays. (I think jnp.concatenate checks for this, while jnp.tile does not). I'm going to request a review from @froystig, who would probably have recommendations here.

Member

Good thinking @jakevdp. Actually, in this case, tile is a valid use and is already supported! For more on how and why, see #8381.

@jakevdp jakevdp requested a review from froystig April 11, 2022 21:11
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 12, 2022
@froystig
Member

@lgeiger, just to make sure you noticed: the CI checks show that one unit test (LaxBackedNumpyTests.testNonScalarRepeats_fixed_size=False) fails on Python 3.7 with our oldest supported version of NumPy (1.19.5).

@froystig froystig self-assigned this Apr 17, 2022
@lgeiger
Contributor Author

lgeiger commented Apr 18, 2022

just to make sure you noticed: the CI checks show that one unit test (LaxBackedNumpyTests.testNonScalarRepeats_fixed_size=False) fails on python 3.7 with our oldest supported version of numpy (1.19.5)

@froystig Sorry for the delay, but unfortunately I'm unable to reproduce the test failure with NumPy 1.19.5 on macOS, or even in a Python 3.7 Docker container. I'm not sure what could cause this problem, since it doesn't seem to reproduce on all systems.
For now I've rebased my PR onto the latest main; maybe that will fix it? But this could also be indicative of an underlying platform-specific problem.

@froystig
Member

Thanks for taking a look and rebasing. Seems fine now...

@copybara-service copybara-service bot merged commit f6705fc into google:main Apr 18, 2022
@lgeiger lgeiger deleted the concat-tile branch April 18, 2022 19:26
copybara-service bot pushed a commit that referenced this pull request Apr 25, 2022
Prefer jnp.tile over concatenate.

jnp.tile generates a jaxpr like the following:
```
{ lambda ; a:i32[720192]. let
    b:i32[1,720192] = reshape[dimensions=None new_sizes=(1, 720192)] a
    c:i32[720192] = squeeze[dimensions=(0,)] b
    d:i32[2,720192] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(2, 720192)
    ] c
    e:i32[1440384] = reshape[dimensions=None new_sizes=(1440384,)] d
  in (e,) }
```

whereas lax.concatenate generates the following jaxpr:
```
{ lambda ; a:i32[720192]. let
    b:i32[1440384] = concatenate[dimension=0] a a
  in (b,) }
```

It seems the TPU compiler isn't doing as good a job with laying out memory for the formulation with `jnp.tile`. `reshape` in particular can be difficult for it to handle well, and it's best to avoid it when possible.

Since the benefit was marginal (a simpler jaxpr... but is it? Really?) and the cost is real (a user's model broke), we should revert this change.

PiperOrigin-RevId: 444260283
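The jaxprs quoted in the commit message can be reproduced with jax.make_jaxpr (a sketch; the exact printed form may vary across JAX versions):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)

# lax/jnp concatenate lowers to a single concatenate equation...
concat_jaxpr = jax.make_jaxpr(lambda a: jnp.concatenate([a, a]))(x)
print(concat_jaxpr)

# ...while jnp.tile goes through broadcasting and reshapes.
tile_jaxpr = jax.make_jaxpr(lambda a: jnp.tile(a, 2))(x)
print(tile_jaxpr)
```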
copybara-service bot pushed a commit that referenced this pull request Apr 25, 2022
Prefer jnp.tile over concatenate.

(Same commit message as the revert above.)

PiperOrigin-RevId: 444287005
clrpackages pushed a commit to clearlinux-pkgs/pypi-jax that referenced this pull request May 3, 2022
…0.3.8

Aart Bik (2):
      Adds a wrapper to sparse tensor dialect, as part of an
      Adds a ability to pass computation directly as module to backend

Anudhyan Boral (1):
      Add unary xeinsum and allow named axis reductions for unary and binary xeinsums

Carlos Martin (1):
      Added random.orthogonal.

Dean Biskup (1):
      Add quotes to install commands in README

Eugene Burmako (4):
      [MHLO] Add MHLO lowering for cosh
      [MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
      [MHLO] Switch tan to use CHLO lowering
      [MHLO] Add MHLO lowering for erf and erfc

Gain Hagenau (1):
      Remove flags set for all v4 TPUs. Topology flags will now be set in libTPU.

Jake VanderPlas (23):
      DOC: add section on input validation in custom pytrees
      update mypy & related package versions
      [sparse] make bcoo_sort_indices a primitive
      jnp.take: add documentation for mode parameter default
      [sparse] fix bug in bcoo_sort_indices batching rule
      [sparse] make bcoo_sum_duplicates a primitive
      Add comment explaining implementation in promote_types
      Skip normalization of unsigned indices
      Deprecate remaining functionality in jax.test_util
      CHANGELOG: update test_util deprecation discussion
      CI: fix flake8 ignore declarations
      Apply flake8 checks to xmap_test.py
      jnp.unwrap: add support for period argument
      DOC: link to install instructions in HTML docs
      Add missing __init__ file in jax._src.scipy.cluster
      [sparse] improve error messages for unimplemented primitives
      lax.linalg.qr: allow jvp when m == n and full_matrices=True
      jax.scipy.qr: fix return type for mode='r'
      [sparse] bcoo_broadcast_in_dim: default to adding leading batch dimensions
      [sparse] implement sparse rule for lax.concatenate_p
      Make lax.linalg.qr robust to zero-dimensional inputs
      dtypes.result_type: add optional return_weak_type_flag argument
      docs: pin myst-nb to 0.13.2

James Bradbury (2):
      [mesh_utils] Support creating device meshes for hybrid networks
      [mesh_utils] Add device/slice count checks

Jean-Baptiste (1):
      Improve the random module documentation.

Jeppe Klitgaard (1):
      fix: explicit reexport

Kehang Han (1):
      updates maps.mesh() context manager in jax_101/pjit documentation

Lena Martens (2):
      Checkify: support checks on data-independent values.
      Apply suggestions from code review

Lukas Geiger (2):
      Prefer `jnp.tile` over `concatenate`
      Use `lax.iota` also for `jnp.arange(0, stop)`

Matthew Johnson (34):
      fix issue #10366
      remove lax._device_put_raw
      remove old cond todos
      add partial_eval_jaxpr_nounits
      add docstring
      [remove-units] prevent cond partial eval from introducing units
      refine const folding
      [remove-units] remove units from while partial eval
      [remove-units] prevent scan partial eval from introducing units
      [remove-units] remove units from api_util.py
      [remove-units] prevent ad.py from introducing units
      fix jax2tf test
      Copybara import of the project:
      [remove-units] roll forward #10448, fix dce bug
      don't bind scan on jaxpr_known if no outputs
      Weaken some newly-added assertions, which are catching some weak type
      [remove-units] prevent remat's partial eval from introducing units
      [remove-units] prevent pjit partial eval from dealing with units
      add scan dce rule
      fix redundant (harmless) axis env extension in pmap partial eval
      [remove-units] avoid making xmap partial eval deal with units
      [remove-units] don't use abstract_unit for dropvar avals
      reviewer comments
      broken remat test!
      fix scan dce rule
      disable for now
      skip tests
      [remove-units] remove partial_eval_jaxpr (no callers!)
      [remove-units] remove units from custom_jvp/vjp
      [remove-units] avoid units in new remat
      [remove-units] remove now-dead flax helper function
      [remove-units] avoid unit-generating function in lax.reduce
      [remove-units] avoid unit-generating function in jax.linear_transpose
      Trivial change for backward compatibility stub.

Parker Schuh (1):
      Switch gpu, cpu and jax2tf to use the new OptimizationBarrier op.

Peter Hawkins (27):
      Add cross-reference to lax.GatherScatterMode from jax.numpy.ndarray.at documentation.
      [MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings that rely on standard_primitive registering a translation rule.
      Remove CpuDevice from documentation index.
      Copybara import of the project:
      Increase minimum jaxlib version to 0.3.7.
      Fix typo in changelog.
      [MHLO] Remove most XLA translation rules.
      [JAX] Remove xla.call_translations from JAX.
      [JAX] Delete last references to conv/dot translation rules.
      [XLA] Call translation rule directly in xla.primitive_subcomputation.
      [MHLO] Switch call_tf to use an MHLO lowering.
      Change default jnp.take_along_axis gather mode to "fill".
      Prevent negative output shapes in shape inference for reduce_window.
      Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
      Add an optional mode= argument to jnp.take_along_axis.
      [MHLO] Switch call_tf to use an MHLO lowering (attempt 2).
      Use lax.broadcasted_iota in jax.nn.one_hot.
      Change jnp.take_along_axis to require that its indices are of integer type.
      [JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
      Revert: google/jax#10221 (2nd revert)
      Change the default scatter mode to FILL_OR_DROP.
      Change the default jnp.take mode to "fill".
      Reexport jaxlib.__version as jax.lib.__version__.
      Lock down the default Bazel visibility of //jaxlib targets.
      Drop Bazel visibility of //jaxlib.
      Allow sharded_jit on CPU.
      [JAX] Validate that platforms passed to MHLO lowering are known to exist.

Rohit Santhanam (1):
      ROCm specific fixes.

Roy Frostig (1):
      always lower/compile computations on the AOT jit path

Sergei Lebedev (1):
      compile_or_get_cached() is now more type-checker friendly

Sharad Vikram (4):
      Fix name stack bugs
      add in mlir lowering for tokens
      Add CustomCall MLIR lowering for HCB `outside_call` primitive
      Add an `emit_python_callback` helper function

Tianjian Lu (7):
      [linalg] Adds `compute_uv` to TPU SVD.
      [linalg] Adds `full_matrices` option to TPU SVD.
      [sparse] Change the outer call to ir.RankedTensorType to make it call .get().
      [sparse] Add BCOO attribute `_indices_sorted`.
      [linalg] Update svd test on reconstructed operands and unitary singular vectors.
      [signal] Update signal detrend test.
      [linalg] Add tpu svd lowering rule.

Xin Zhou (1):
      [mhlo] Add result type inference for mhlo.broadcast.

Yash Katariya (5):
      Add another example of using `Mesh`.
      Delete the `mesh` context manager. The replacement for it is `Mesh`.
      Update the changelog to say sharded_jit is deprecated.
      Add a dtypes option to cast host arrays when reloading from TS.
      Raise a better error when assert fails in mesh_sharding_specs

YouJiacheng (8):
      simplify _IndexGrid
      DOC: lax.linalg.eigh
      remove numpy.linalg._promote_arg_dtypes
      Fix typo of #10381
      Fix typo in _scatter_add_lower_gpu
      replace int with operator.index part2
      implement scipy.cluster.vq.vq
      try to improve docs for scipy.linalg with unused parameters

jax authors (3):
      Temporarily revert fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
      Reapply: fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
      [MHLO] Switch call_tf to use an MHLO lowering.

lipracer (1):
      Compatible with RngBitGeneratorOp builder modifications

yashkatariya (4):
      Add a docstring for maps.Mesh
      fix build
      fix build
      fix build
Labels
pull ready Ready for copybara import and testing

5 participants