Prefer jnp.tile over concatenate
#10221
Conversation
I think we can use…

```diff
@@ -580,7 +580,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):

 def _rbg_seed(seed: int) -> jnp.ndarray:
   halfkey = threefry_seed(seed)
-  return jnp.concatenate([halfkey, halfkey])
+  return jnp.tile(halfkey, 2)
```
concatenate repeats along axis 0, while tile repeats along axis -1. Are you sure halfkey is a 1-D array?
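For 1-D arrays the two agree, but for higher ranks they diverge. A quick sketch using numpy, whose `tile`/`concatenate` semantics match `jnp`'s here (the array values are illustrative):

```python
import numpy as np

k = np.array([1, 2], dtype=np.uint32)
# For a 1-D key, tiling twice equals concatenating the key with itself.
assert np.array_equal(np.tile(k, 2), np.concatenate([k, k]))

m = np.arange(6).reshape(2, 3)
# For 2-D inputs the two diverge: concatenate stacks along axis 0,
# while tile(m, 2) repeats along the last axis.
print(np.concatenate([m, m]).shape)  # (4, 3)
print(np.tile(m, 2).shape)           # (2, 6)
```

So the replacement is only safe if `halfkey` is guaranteed to be 1-D.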
Yes, this seems to be the case:

Lines 246 to 256 in 3bfa6af

```python
def threefry_seed(seed: int) -> jnp.ndarray:
  """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
```
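The bit-casting the docstring describes can be sketched in numpy; this is only an illustration of the shape-(2,) uint32 layout, not jax's actual implementation, and the high/low word ordering here is an assumption:

```python
import numpy as np

seed = 0x0123456789ABCDEF  # hypothetical 64-bit integer seed
hi = np.uint32(seed >> 32)           # high 32 bits
lo = np.uint32(seed & 0xFFFFFFFF)    # low 32 bits
key = np.array([hi, lo], dtype=np.uint32)
print(key.shape, key.dtype)  # (2,) uint32
```

Concatenating (or tiling) such a key with itself then yields a shape-(4,) uint32 array, which is what `_rbg_seed` relies on.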
I think this will probably cause issues with custom PRNG keys, which don't concatenate like regular arrays. (I think jnp.concatenate checks for this, while jnp.tile does not.) I'm going to request a review from @froystig, who would probably have recommendations here.
@lgeiger, just to make sure you noticed: the CI checks show that one unit test (…) is failing.
@froystig Sorry for the delay, but unfortunately I am unable to reproduce the test failure with Numpy 1.19.5 on macOS or even in a Python 3.7 Docker container. I am not sure what could cause this problem, since it doesn't seem to reproduce on all systems.
Thanks for taking a look and rebasing. Seems fine now... |
Prefer jnp.tile over concatenate.

jnp.tile generates a jaxpr like the following:

```
{ lambda ; a:i32[720192]. let
    b:i32[1,720192] = reshape[dimensions=None new_sizes=(1, 720192)] a
    c:i32[720192] = squeeze[dimensions=(0,)] b
    d:i32[2,720192] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(2, 720192)
    ] c
    e:i32[1440384] = reshape[dimensions=None new_sizes=(1440384,)] d
  in (e,) }
```

whereas lax.concatenate generates the following jaxpr:

```
{ lambda ; a:i32[720192]. let
    b:i32[1440384] = concatenate[dimension=0] a a
  in (b,) }
```

It seems the TPU compiler isn't doing as good a job laying out memory for the formulation with `jnp.tile`. `reshape` in particular can be difficult for it to handle well, and it's best to avoid it when possible. Since the benefit was marginal (a simpler jaxpr... but is it? Really?) and the cost is real (a user's model broke), we should revert this change.

PiperOrigin-RevId: 444260283
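The longer jaxpr reflects how `tile` is typically decomposed into reshape and broadcast steps. A numpy sketch of the same reshape → broadcast → reshape sequence (variable names chosen to mirror the jaxpr; the input is a small stand-in array):

```python
import numpy as np

a = np.arange(8)                     # stand-in for the i32[720192] input
b = a.reshape(1, -1)                 # reshape to (1, n)
d = np.broadcast_to(b, (2, a.size))  # broadcast_in_dim to (2, n)
e = d.reshape(-1)                    # final reshape back to 1-D
# The result equals the single-op concatenate formulation.
assert np.array_equal(e, np.concatenate([a, a]))
```

Both paths compute the same values; the difference is that the concatenate formulation gives the compiler one op with an obvious layout, while the tile formulation asks it to see through two reshapes.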
(The same revert message was committed a second time: PiperOrigin-RevId: 444287005.)
[jax v0.3.8 release-notes excerpt; entries relevant to this PR:]
- Lukas Geiger (2): Prefer `jnp.tile` over `concatenate`; Use `lax.iota` also for `jnp.arange(0, stop)`
- Peter Hawkins: Revert: google/jax#10221 (2nd revert)
- jax authors: Temporarily revert fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger; Reapply: fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger
jnp.tile only uses broadcasting, which tends to be preferred, so this PR replaces concatenate with jnp.tile in the cases where this is possible.