Skip to content

Commit

Permalink
Add support to pipeline emitter for shapes that don't perfectly divid…
Browse files Browse the repository at this point in the history
…e the block shapes

PiperOrigin-RevId: 631594796
  • Loading branch information
sharadmv authored and jax authors committed May 9, 2024
1 parent a9460f2 commit 98efe9e
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 34 deletions.
136 changes: 122 additions & 14 deletions jax/_src/pallas/mosaic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enum
import functools
import itertools
import math
import operator
from typing import Optional, Union, Any, Sequence

Expand All @@ -44,6 +45,9 @@
PipelineRefs = Union[Sequence[REF], Any]


# TODO(sharadmv): make this a parameter and make it queryable from the Device.
_TILING = (8, 128)

def _broadcast_pytree_to(from_pytree, to_pytree):
"""Broadcast a prefix pytree to a given full tree."""
proxy = object()
Expand All @@ -63,14 +67,56 @@ def add_leaves(i, x):
return tree_util.tree_unflatten(treedef, broadcast_leaves)


def _make_tiling(shape: tuple[int, ...]) -> tuple[int, ...]:
# For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
# and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
# (2, 3, 128, 128) -> (1, 1, 8, 128).
return (*(1,) * (len(shape) - 2), *_TILING[-len(shape) :])


def _mod(a, n):
""""Calculates a mod n for positive and negative a with |a| <= n."""
return lax.rem(a + n, n)


def _make_ds(idx, size):
def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
if s % multiple == 0:
return s
# Subtract off the remainder, then add multiple
return s - s % multiple + multiple


def _make_ds(
idx: jax.Array | int, size: jax.Array | int
) -> pl.Slice:
"""Make a DMA slice with mosaic size hints."""
return pl.ds(pl.multiple_of(idx * size, size), size)
offset = idx * size
if isinstance(size, int):
offset = pl.multiple_of(offset, size)
return pl.ds(offset, size)


def _make_block_slice(
block_index: jax.Array, block_size: int, size: int, tiling: int
) -> pl.Slice:
# Computes a slice given a block index and block size. In the default case,
# we return slice(block_index * block_size, (block_index + 1) * block_size).
# However, if the total size of the ref does not divide block size and we are
# selecting the last block, we need to pick the lowest tiling size multiple
# that contains the block.
if size % block_size == 0:
return _make_ds(block_index, block_size)
assert block_size % tiling == 0
num_blocks = math.ceil(size / block_size)
is_last = block_index == num_blocks - 1
size = jnp.where(
is_last,
_round_up_to_nearest_multiple(size % block_size, tiling),
block_size,
)
size = pl.multiple_of(size, tiling)
offset = pl.multiple_of(block_index * block_size, block_size)
return pl.ds(offset, size)


def _tuples_differ(xs, ys):
Expand Down Expand Up @@ -259,49 +305,111 @@ def swap_slots(self):
if self.memory_space == VMEM: return
self.current_slot[0] = self.next_slot[0]

def get_dma_slice(self, src_shape, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in bounds
# block looks like this (for array shape (600, 600) and block shape
# (256, 256)):
#
# +--------------+------------------|
# | Block (0,0) | |
# | (256, 256) | |
# +--------------+ |
# | A (600, 600) |
# | |
# +---------------------------------+
#
# For in-bounds blocks, we don't need to do anything special.
# An out-of-bounds block looks like this:
#
# +--------------+------------------|
# | |
# | |
# + |
# | A (600, 600) |
# +--------------+ |
# | Block (2,0) | |
# + --------------------------------|
# | XXXXXXXXXX |
# +--------------+
# where the X's indicate where the block is out of bounds.
#
# When we have an out of bounds block like this, we need to truncate it to
# a tile boundary (tiles are (8, 128) along the last two dimensions).
# Specifically, in this case, we'll have a block that is indexing the
# 512:768 elements of A along the first dimension. We need to convert 768
# into 600 (600 % 8 == 0), so our indexing will look like this:

# +--------------+------------------|
# | |
# | |
# + |
# | A (600, 600) |
# +--------------+ |
# | Block (2,0) | |
# + --------------------------------|
# where it is now a (128, 256) sized block.

# Suppose A is now (601, 600), instead of picking a (128, 256)-sized block
# for the last iteration on that dimension, we will pick the next highest
# tile multiple, i.e. (136, 256).
tiling = _make_tiling(src_shape)
block_shape = tuple(1 if b is None else b for b in self.block_shape)
block_indices = self.compute_index(*grid_indices)
return jax.tree.map(
_make_block_slice, block_indices, block_shape, src_shape, tiling
)

def divides_block_shape(self, shape) -> tuple[bool, ...]:
return tuple(True if b is None else s % b == 0 for
s, b in zip(shape, self.block_shape))

def copy_in(self, src_ref, grid_indices):
"""Starts copy of HBM dma slice into the current slot."""
assert self.is_input
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
next_slot = lax.rem(self.current_slot[0] + 1, 2)
self.next_slot[0] = next_slot
src_slice = self.get_dma_slice(src_ref.shape, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
src_ref.at[dma_slice],
self.vmem_ref.at[next_slot],
src_ref.at[src_slice],
self.vmem_ref.at[next_slot].at[dst_slice],
self.sem_recv).start()

def copy_out(self, dst_ref, grid_indices):
"""Starts copy of HBM dma slice from the current slot."""
assert self.is_output
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
slot = self.current_slot[0]
self.next_slot[0] = lax.rem(slot + 1, 2)
dst_slice = self.get_dma_slice(dst_ref.shape, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
self.vmem_ref.at[slot],
dst_ref.at[dma_slice],
self.vmem_ref.at[slot].at[src_slice],
dst_ref.at[dst_slice],
self.sem_send).start()

def wait_in(self, src_ref, grid_indices):
"""Waits for input copy to finish."""
assert self.is_input
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
src_slice = self.get_dma_slice(src_ref.shape, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
src_ref.at[dma_slice], # nb: doesn't matter
self.vmem_ref.at[self.current_slot[0]], # only dst shape is important
src_ref.at[src_slice], # nb: doesn't matter
self.vmem_ref.at[self.current_slot[0]].at[dst_slice], # only dst shape is important
self.sem_recv).wait()

def wait_out(self, dst_ref, grid_indices):
"""Waits for output copy to finish."""
assert self.is_output
if self.memory_space == VMEM: return
dma_slice = self.compute_slice(grid_indices)
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
dst_slice = self.get_dma_slice(dst_ref.shape, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
self.vmem_ref.at[prev_slot], # nb: doesn't matter
dst_ref.at[dma_slice], # only dst shape is important
self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter
dst_ref.at[dst_slice], # only dst shape is important
self.sem_send).wait()

# Accumulator methods
Expand Down
11 changes: 9 additions & 2 deletions jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Pallas utility functions."""

from __future__ import annotations
from typing import overload

import jax
from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
Expand All @@ -32,9 +34,14 @@ def _wrapped(f):
lax.cond(condition, f, lambda: None)
return _wrapped


@overload
def cdiv(a: int, b: int) -> int:
return (a + b - 1) // b
...

def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array:
if isinstance(a, int) and isinstance(b, int):
return (a + b - 1) // b
return lax.div(a + b - 1, b)


def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
Expand Down
3 changes: 2 additions & 1 deletion tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ jax_test(
"gpu",
],
main = "pallas_pipeline_tpu_test.py",
shard_count = 2,
deps = [
"//jax:extend",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
],
] + py_deps("hypothesis"),
)

jax_test(
Expand Down

0 comments on commit 98efe9e

Please sign in to comment.