[XLA:TPU] Support output streaming and refactor TryOutputStreaming into a bottoms-up approach.

Previously, output streaming took a top-down approach that indiscriminately checked whether a MoveToHost (MTH) custom call would trace down to an output marked with host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes the problem by handling output streaming before other MTH calls, and it improves efficiency with the bottoms-up approach, which traces only a single path in the graph.

PiperOrigin-RevId: 630885979
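
The bottoms-up trace described in the commit message can be illustrated with a toy graph. This is only a minimal sketch in plain Python, not the XLA implementation: the `Node` class, the `FOLLOW_OPERAND` table, and `trace_up_to_move_to_host` are hypothetical names invented for illustration, and the choice to follow the update operand of a dynamic-update-slice is an assumption. It starts at an output placed in host memory and walks up one operand at a time until it reaches a MoveToHost node or hits an op it does not know how to look through.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for an HLO instruction (not an XLA type)."""
    name: str
    opcode: str
    operands: List["Node"] = field(default_factory=list)


# Assumption for this sketch: ops the trace may look through, mapped to the
# single operand index to follow. For dynamic-update-slice we follow the
# update operand (index 1), which is an illustrative choice.
FOLLOW_OPERAND = {"dynamic-update-slice": 1, "bitcast": 0}


def trace_up_to_move_to_host(host_output: Node) -> Optional[Node]:
    """Walk up a single path from a host-memory output to a MoveToHost node.

    Returns the MoveToHost node, or None if the path ends at another op.
    """
    node = host_output
    while node.opcode != "MoveToHost":
        index = FOLLOW_OPERAND.get(node.opcode)
        if index is None or index >= len(node.operands):
            return None  # Not a pass-through op: nothing to stream.
        node = node.operands[index]
    return node


# Toy graph: param -> add -> MoveToHost -> dynamic-update-slice (output).
param = Node("param", "parameter")
add = Node("add", "add", [param])
mth = Node("mth", "MoveToHost", [add])
buffer = Node("buffer", "broadcast")
index = Node("index", "constant")
dus = Node("dus", "dynamic-update-slice", [buffer, mth, index])

found = trace_up_to_move_to_host(dus)        # reaches the MoveToHost node
device_only = trace_up_to_move_to_host(add)  # no MoveToHost on this path
```

Because the walk starts at the output and follows one operand per step, it visits only the nodes on that path, whereas a top-down search from every MoveToHost call would have to explore each call's users.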
jvstokes authored and jax authors committed May 10, 2024
1 parent a9460f2 commit ade474c
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/memories_test.py
@@ -1195,6 +1195,22 @@ def body(carry, x):
    out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device")
    self.assertEqual(out_hbm.sharding, out_s)

  def test_output_streaming(self):
    mesh = jtu.create_global_mesh((1, 1), ("x", "y"))
    np_inp = np.arange(16.0).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(xs):
      out_tpu = xs + 1.0
      return out_tpu

    out_host = f(arr_hbm)
    self.assertArraysEqual(out_host, np_inp + 1.0)
    self.assertEqual(out_host.sharding, s_host)


class ActivationOffloadingTest(jtu.JaxTestCase):
