
Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.
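A minimal before/after sketch of the mechanical change (illustrative only; the `params` pytree and variable names below are hypothetical, not taken from this repository):

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}

# Deprecated top-level alias, slated for removal:
#   doubled = jax.tree_map(lambda x: 2 * x, params)

# Preferred replacements:
doubled = jax.tree.map(lambda x: 2 * x, params)             # short alias (JAX >= 0.4.25)
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)   # long-standing jax.tree_util API
```

Both spellings are drop-in replacements: they take the same function and pytree arguments as the deprecated `jax.tree_map`.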

PiperOrigin-RevId: 634019673
vanderplas authored and tensorflower-gardener committed May 15, 2024
1 parent 3e65280 commit 9a84874
Showing 8 changed files with 32 additions and 32 deletions.
16 changes: 8 additions & 8 deletions discussion/adaptive_malt/adaptive_malt.py
@@ -1085,7 +1085,7 @@ def meads_step(meads_state: MeadsState,
def refold(x, perm):
return x.reshape((num_chains,) + x.shape[2:])[perm].reshape(x.shape)

- phmc_state = jax.tree_map(functools.partial(refold, perm=perm), phmc_state)
+ phmc_state = jax.tree.map(functools.partial(refold, perm=perm), phmc_state)

if vector_step_size is None:
vector_step_size = phmc_state.state.std(1, keepdims=True)
@@ -1135,18 +1135,18 @@ def rejoin_folds(updated, original):
], 0), fold_to_skip, 0)

active_fold_state, phmc_extra = fun_mc.prefab.persistent_hamiltonian_monte_carlo_step(
- jax.tree_map(select_folds, phmc_state),
+ jax.tree.map(select_folds, phmc_state),
target_log_prob_fn=target_log_prob_fn,
step_size=select_folds(scalar_step_size[:, jnp.newaxis, jnp.newaxis] *
rolled_vector_step_size),
num_integrator_steps=1,
noise_fraction=select_folds(noise_fraction)[:, jnp.newaxis, jnp.newaxis],
mh_drift=select_folds(mh_drift)[:, jnp.newaxis],
seed=phmc_seed)
- phmc_state = jax.tree_map(rejoin_folds, active_fold_state, phmc_state)
+ phmc_state = jax.tree.map(rejoin_folds, active_fold_state, phmc_state)

# Revert the ordering of the walkers.
- phmc_state = jax.tree_map(functools.partial(refold, perm=unperm), phmc_state)
+ phmc_state = jax.tree.map(functools.partial(refold, perm=unperm), phmc_state)

meads_state = MeadsState(
phmc_state=phmc_state,
@@ -1838,7 +1838,7 @@ def run_grid_element(mean_trajectory_length: jnp.ndarray,
for i in range(num_replicas):
with utils.delete_device_buffers():
res.append(
- jax.tree_map(
+ jax.tree.map(
np.array,
_run_grid_element_impl(
seed=jax.random.fold_in(seed, i),
@@ -1853,7 +1853,7 @@
jitter_style=jitter_style,
target_accept_prob=target_accept_prob,
)))
- res = jax.tree_map(lambda *x: np.stack(x, 0), *res)
+ res = jax.tree.map(lambda *x: np.stack(x, 0), *res)
res['mean_trajectory_length'] = mean_trajectory_length
res['damping'] = damping

@@ -1988,7 +1988,7 @@ def run_trial(
for i in range(num_replicas):
with utils.delete_device_buffers():
res.append(
- jax.tree_map(
+ jax.tree.map(
np.array,
_run_trial_impl(
seed=jax.random.fold_in(seed, i),
@@ -2006,5 +2006,5 @@
trajectory_length_adaptation_rate_decay=trajectory_length_adaptation_rate_decay,
save_warmup=save_warmup,
)))
- res = jax.tree_map(lambda *x: np.stack(x, 0), *res)
+ res = jax.tree.map(lambda *x: np.stack(x, 0), *res)
return res
8 changes: 4 additions & 4 deletions discussion/meads/meads.ipynb
@@ -233,7 +233,7 @@
" unperm = jnp.eye(num_chains)[perm].argmax(0)\n",
" def refold(x, perm):\n",
" return x.reshape((num_chains,) + x.shape[2:])[perm].reshape(x.shape)\n",
" phmc_state = jax.tree_map(functools.partial(refold, perm=perm), phmc_state)\n",
" phmc_state = jax.tree.map(functools.partial(refold, perm=perm), phmc_state)\n",
"\n",
" if diagonal_preconditioning:\n",
" scale_estimates = phmc_state.state.std(1, keepdims=True)\n",
@@ -274,7 +274,7 @@
" fold_to_skip, 0)\n",
"\n",
" active_fold_state, phmc_extra = fun_mc.prefab.persistent_hamiltonian_monte_carlo_step(\n",
" jax.tree_map(select_folds, phmc_state),\n",
" jax.tree.map(select_folds, phmc_state),\n",
" target_log_prob_fn=target_log_prob_fn,\n",
" step_size=select_folds(step_size[:, jnp.newaxis, jnp.newaxis] *\n",
" rolled_scale_estimates),\n",
@@ -285,7 +285,7 @@
" phmc_state = jax.tree_multimap(rejoin_folds, active_fold_state, phmc_state)\n",
"\n",
" # Revert the ordering of the walkers.\n",
" phmc_state = jax.tree_map(functools.partial(refold, perm=unperm), phmc_state)\n",
" phmc_state = jax.tree.map(functools.partial(refold, perm=unperm), phmc_state)\n",
"\n",
" traced = {\n",
" 'z_chain': phmc_state.state,\n",
@@ -315,7 +315,7 @@
" @jit\n",
" def update_step(x, adam_state):\n",
" def g_fn(x):\n",
" return jax.tree_map(lambda x: -x, value_and_grad(target_log_prob_fn)(x))\n",
" return jax.tree.map(lambda x: -x, value_and_grad(target_log_prob_fn)(x))\n",
" tlp, g = g_fn(x)\n",
" updates, adam_state = optimizer.update(g, adam_state)\n",
" return optax.apply_updates(x, updates), adam_state, tlp\n",
2 changes: 1 addition & 1 deletion spinoffs/autobnn/autobnn/kernels_test.py
@@ -45,7 +45,7 @@ def get_bnn_and_params(self):
linear_bnn = kernels.OneLayerBNN(width=50)
seed = jax.random.PRNGKey(0)
init_params = linear_bnn.init(seed, x_train)
- constant_params = jax.tree_map(
+ constant_params = jax.tree.map(
lambda x: jnp.full(x.shape, 0.1), init_params)
constant_params['params']['noise_scale'] = jnp.array([0.005 ** 0.5])
return linear_bnn, constant_params, x_train, y_train
14 changes: 7 additions & 7 deletions spinoffs/autobnn/autobnn/training_util.py
@@ -117,7 +117,7 @@ def _init(rand_seed):
initial_state = jax.vmap(_init)(jax.random.split(seed, num_particles))
# It is okay to reuse the initial_state[0] as the test point, as Bayeux
# only uses it to figure out the treedef.
- test_point = jax.tree_map(lambda t: t[0], initial_state)
+ test_point = jax.tree.map(lambda t: t[0], initial_state)

if for_vi:

@@ -127,7 +127,7 @@ def log_density(params, *, seed=None):
# of size [1] to the start, so we undo all of that.
del seed
return net.log_prob(
- {'params': jax.tree_map(lambda x: x[0, ...], params)},
+ {'params': jax.tree.map(lambda x: x[0, ...], params)},
data=x_train,
observations=y_train)

@@ -189,9 +189,9 @@ def _filter_stuck_chains(params):
halfway_to_zero = -0.5 * stds_mu / stds_scale
unstuck = jnp.where(z_scores > halfway_to_zero)[0]
if unstuck.shape[0] > 2:
- return jax.tree_map(lambda x: x[unstuck], params)
+ return jax.tree.map(lambda x: x[unstuck], params)
best_two = jnp.argsort(stds)[-2:]
- return jax.tree_map(lambda x: x[best_two], params)
+ return jax.tree.map(lambda x: x[best_two], params)


@jax.named_call
@@ -214,7 +214,7 @@ def fit_bnn_vi(
seed=vi_seed, **vi_kwargs)
params = surrogate_dist.sample(seed=draw_seed, sample_shape=num_draws)

- params = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
+ params = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
return params, {'loss': loss}


@@ -241,7 +241,7 @@ def fit_bnn_mcmc(
# is the easiest way to determine where "stuck chains" occur, and it is
# nice to return parameters with a single batch dimension.
params = _filter_stuck_chains(params)
- params = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
+ params = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), params)
return params, {'noise_scale': params['params'].get('noise_scale', None)}


@@ -430,6 +430,6 @@ def debatchify_params(params: PyTree) -> List[Dict[str, Any]]:
"""Nested dict of rank n tensors -> a list of nested dicts of rank n-1's."""
n = get_params_batch_length(params)
def get_item(i):
- return jax.tree_map(lambda x: x[i, ...], params)
+ return jax.tree.map(lambda x: x[i, ...], params)

return [get_item(i) for i in range(n)]
8 changes: 4 additions & 4 deletions spinoffs/autobnn/autobnn/util.py
@@ -27,24 +27,24 @@ def make_transforms(
net: bnn.BNN,
) -> Tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]:
"""Returns unconstraining bijectors for all variables in the BNN."""
- jb = jax.tree_map(
+ jb = jax.tree.map(
lambda x: x.experimental_default_event_space_bijector(),
net.get_all_distributions(),
is_leaf=lambda x: isinstance(x, distribution_lib.Distribution),
)

def transform(params):
- return {'params': jax.tree_map(lambda p, b: b(p), params['params'], jb)}
+ return {'params': jax.tree.map(lambda p, b: b(p), params['params'], jb)}

def inverse_transform(params):
return {
- 'params': jax.tree_map(lambda p, b: b.inverse(p), params['params'], jb)
+ 'params': jax.tree.map(lambda p, b: b.inverse(p), params['params'], jb)
}

def inverse_log_det_jacobian(params):
return jax.tree_util.tree_reduce(
lambda a, b: a + b,
- jax.tree_map(
+ jax.tree.map(
lambda p, b: jnp.sum(b.inverse_log_det_jacobian(p)),
params['params'],
jb,
@@ -184,7 +184,7 @@
" return targets\n",
"\n",
"def get_num_latents(target):\n",
" return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0]))))"
" return int(sum(map(np.prod, list(jax.tree.flatten(target.event_shape)[0]))))"
]
},
{
@@ -1187,7 +1187,7 @@
" x = x.reshape((jax.device_count(), -1, *x.shape[1:]))\n",
" return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices\n",
"\n",
"shard = functools.partial(jax.tree_map, shard_value)"
"shard = functools.partial(jax.tree.map, shard_value)"
]
},
{
@@ -1322,7 +1322,7 @@
"source": [
"%%time\n",
"output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))\n",
"jax.tree_map(lambda x: x.block_until_ready(), output)"
"jax.tree.map(lambda x: x.block_until_ready(), output)"
]
},
{
@@ -1357,7 +1357,7 @@
"source": [
"%%time\n",
"states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))\n",
"jax.tree_map(lambda x: x.block_until_ready(), trace)"
"jax.tree.map(lambda x: x.block_until_ready(), trace)"
]
},
{
@@ -1879,7 +1879,7 @@
"%%time\n",
"run = make_run(axis_name='data')\n",
"output = run(random.PRNGKey(0), sharded_watch_matrix)\n",
"jax.tree_map(lambda x: x.block_until_ready(), output)"
"jax.tree.map(lambda x: x.block_until_ready(), output)"
]
},
{
@@ -1914,7 +1914,7 @@
"source": [
"%%time\n",
"states, trace = run(random.PRNGKey(0), sharded_watch_matrix)\n",
"jax.tree_map(lambda x: x.block_until_ready(), trace)"
"jax.tree.map(lambda x: x.block_until_ready(), trace)"
]
},
{
@@ -2050,7 +2050,7 @@
" already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1])\n",
" for i in range(500):\n",
" for j in range(2):\n",
" sample = jax.tree_map(lambda x: x[i, j], samples)\n",
" sample = jax.tree.map(lambda x: x[i, j], samples)\n",
" ranking = recommend(sample, user_id)\n",
" for movie_id in ranking:\n",
" if int(movie_id) not in already_watched:\n",
2 changes: 1 addition & 1 deletion tensorflow_probability/python/experimental/fastgp/mbcg.py
@@ -155,7 +155,7 @@ def loop_body(carry, _):
new_off_diags = off_diags.at[:, j - 1].set(off_diag_update)

# Only update if we are not within tolerance.
- (preconditioned_errors, search_directions, alpha) = (jax.tree_map(
+ (preconditioned_errors, search_directions, alpha) = (jax.tree.map(
lambda o, n: jnp.where(converged, o, n),
(old_preconditioned_errors, old_search_directions, old_alpha),
(preconditioned_errors, search_directions, safe_alpha)))
