

[Inductor] Generate triton block pointers for discontiguous strided tensors #125077

Open
blaine-rister opened this issue Apr 26, 2024 · 7 comments
Labels
module: inductor oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@blaine-rister
Contributor

blaine-rister commented Apr 26, 2024

🚀 The feature, motivation and pitch

I ran the following program to test what triton code is generated from a discontiguous tensor:

import os

# Enable debug logging. TORCH_COMPILE_DEBUG is read when torch._inductor.config
# is imported, so it has to be set before importing torch.
os.environ["TORCH_COMPILE_DEBUG"] = "1"

import sys
import logging
import torch
from torch._inductor import config as inductor_config

torch._logging.set_logs(inductor=logging.DEBUG)

# Log to stdout
handler = logging.StreamHandler(sys.stdout)
for logger in torch._dynamo.logging.get_loggers():
    logger.addHandler(handler)

inductor_config.triton.use_block_ptr = True

def foo(x, y):
    return x + y

device = torch.device('cuda')
orig_size = (32, 32)
view_size = (32, 8)
orig = torch.randn(orig_size).to(device)
view = torch.as_strided(orig, view_size, orig.stride())

compiled_foo = torch.compile(foo, backend="inductor")
compiled_foo(view, view)

The generated kernel was:

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x1)), xmask)
    tmp1 = tmp0 + tmp0
    tl.store(tl.make_block_ptr(out_ptr0, shape=[256], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp1, [XBLOCK]).to(tl.float32), boundary_check=[0])

It seems like Inductor generates a block pointer for the output but falls back to standard pointers for the input. By contrast, if I don't call torch.as_strided on the input, I see block pointers for both.
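The address math in the generated kernel can be checked without a GPU. This plain-Python sketch reproduces the `x0 = xindex % 8`, `x1 = xindex // 8`, `x0 + 32*x1` computation and shows that consecutive rows of the (32, 8) view start 32 elements apart, so 24 padding elements are skipped per row. It assumes the view from the program above (size (32, 8), strides (32, 1)); `flat_kernel_offsets` is a hypothetical helper.

```python
# Plain-Python model of the address math in the generated 1D kernel.
# Assumption: the view has size (32, 8) and strides (32, 1), as produced by
# torch.as_strided(orig, (32, 8), orig.stride()) with orig of size (32, 32).
ROWS, COLS, ROW_STRIDE = 32, 8, 32


def flat_kernel_offsets(xnumel=ROWS * COLS):
    """Return the element offsets the 1D kernel loads, in iteration order."""
    offsets = []
    for xindex in range(xnumel):
        x0 = xindex % COLS   # column within the view
        x1 = xindex // COLS  # row within the view
        offsets.append(x0 + ROW_STRIDE * x1)
    return offsets


offsets = flat_kernel_offsets()
print(offsets[6:10])  # [6, 7, 32, 33]: row 1 starts 32 elements after row 0
print(max(offsets))   # 999 == 7 + 32 * 31
```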

I am wondering if it's possible for inductor to generate something like this instead:

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[32, 8], strides=[32, 1], block_shape=[32, XBLOCK], order=[1, 0], offsets=[0, xoffset]), boundary_check=[1]).to(tl.float32)
    tmp1 = tmp0 + tmp0
    tl.store(tl.make_block_ptr(out_ptr0, shape=[32, 8], strides=[8, 1], block_shape=[32, XBLOCK], order=[1, 0], offsets=[0, xoffset]), tl.broadcast_to(tmp1, [32, XBLOCK]).to(tl.float32), boundary_check=[1])

This would use the strides argument to tl.make_block_ptr to express that the input tensor is discontiguous. On GPUs, this could avoid the address calculation using division and modulo, which might yield some performance benefit. There is probably a much bigger win for accelerators like MTIA with simpler memory systems, where this code maps very naturally to DMA engines. Without this, simpler accelerators might have a tough time handling padding between the rows of a tensor.

Is this feature feasible? The main change I see is that here XBLOCK would refer to the columns of the input matrix, as opposed to the linear index. It would also be possible to block on rows.
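The following sketch models, in plain Python, which elements a 2D block pointer with strides [32, 1] would cover when XBLOCK blocks on the columns. `block_tile` is a hypothetical helper, not part of Triton or Inductor; it assumes row-major `order=[1, 0]` and a boundary check on the column dimension (dimension 1).

```python
# Plain-Python emulation (not the Triton API) of the tile addressed by
# make_block_ptr(shape=[32, 8], strides=[32, 1], block_shape=[32, XBLOCK],
# offsets=[0, xoffset]) for one program, with XBLOCK tiling the columns.
def block_tile(pid, xblock, shape=(32, 8), strides=(32, 1)):
    """Return (offsets, mask) as nested lists of size [shape[0]][xblock].

    mask[r][c] is False where the column falls outside shape[1], i.e. the
    elements a boundary check on dimension 1 would guard against.
    """
    rows, cols = shape
    row_stride, col_stride = strides
    xoffset = pid * xblock
    offsets, mask = [], []
    for r in range(rows):
        cs = range(xoffset, xoffset + xblock)
        # Address = row * row_stride + col * col_stride: no div or mod needed.
        offsets.append([r * row_stride + c * col_stride for c in cs])
        mask.append([c < cols for c in cs])
    return offsets, mask


offs, mask = block_tile(0, 4)
print(offs[1])  # [32, 33, 34, 35]
```

Each address is a row index times the row stride plus a column index, which is the form the proposal hopes a DMA engine or the block-pointer lowering can consume directly.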

Alternatives

In principle, it's possible for the triton compiler to recognize this pattern under the hood. But it seems like that would require reading a whole number of rows, i.e. XBLOCK must be a multiple of the row length. Also, the analysis could get complex when division and modulo are involved. I'm wondering if it makes more sense to handle this in Inductor.

Instead of block pointers, it's also possible to simplify the address calculation for standard pointers, such as

x0 = tl.broadcast_to(tl.expand_dims(xoffset + tl.arange(0, XBLOCK), axis=0), [32, XBLOCK])  # column index
x1 = tl.broadcast_to(tl.expand_dims(tl.arange(0, 32), axis=1), [32, XBLOCK])  # row index
tl.load(in_ptr0 + x0 + x1 * 32, mask=x0 < 8)

which could more easily be converted to a block representation inside the triton compiler.

Additional context

No response

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

cc @shunting314 based on offline conversations. We were hoping for input from @jansel .

@blaine-rister blaine-rister changed the title [Inductor] generate triton block pointers for discontiguous strided tensors [Inductor] Generate triton block pointers for discontiguous strided tensors Apr 26, 2024
@shunting314
Contributor

I don't think every non-contiguous access causes Inductor to skip block_ptr.

For example, for 'a + b.t()', Inductor generates the following code, which uses block_ptr for all three memory accesses:

def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 1024
    xnumel = 2048
    yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[2048, 1024], strides=[1, 2048], block_shape=[XBLOCK, YBLOCK], order=[0, 1], offsets=[xoffset, yoffset]), eviction_policy='evict_last')
    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[2048, 1024], strides=[1024, 1], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[2048, 1024], strides=[1, 2048], block_shape=[XBLOCK, YBLOCK], order=[0, 1], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32))
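To see why both layouts fit a 2D block pointer, here is a small sketch that derives `order` from the strides in the kernel above and computes the offset of the same logical element under each layout. It assumes `order` lists dimensions from smallest to largest stride, which is consistent with the generated code; `order_from_strides` and `offset` are hypothetical helpers.

```python
# Sketch: derive a block pointer's `order` argument from its strides.
# Assumption: `order` lists dimensions from fastest-varying (smallest stride)
# to slowest, matching how the generated kernel pairs strides and order.
def order_from_strides(strides):
    return sorted(range(len(strides)), key=lambda d: strides[d])


def offset(index, strides):
    """Element offset of a logical index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


print(order_from_strides([1, 2048]))  # in_ptr0 / out_ptr0: [0, 1]
print(order_from_strides([1024, 1]))  # in_ptr1: [1, 0]
print(offset((3, 5), [1, 2048]), offset((3, 5), [1024, 1]))  # 10243 3077
```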

@blaine-rister
Contributor Author

Thanks, this is good context. So it seems like 2D block pointers are already possible, it's just that inductor might not take advantage of them in the case of padded rows coming from torch.as_strided.

@jansel
Contributor

jansel commented Apr 28, 2024

There is nothing special about as_strided. In that case inductor decided to generate a 1D kernel (since both dimensions had the same contiguity), but required a 2D load. Similarly, if you have a 2D kernel, but a 3D/4D load -- then block ptr won't be used.
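The mismatch between a 1D kernel and a 2D load can be stated as a condition on sizes and strides. The sketch below is a hypothetical helper, not Inductor's `select_tiling`: it checks whether a view can be addressed with a single linear stride. When it cannot, a 1D kernel needs a 2D load and a 1D block pointer no longer fits. It ignores size-1 dimensions for simplicity.

```python
# Simplified contiguity test: adjacent dimensions i and i+1 can be merged into
# one linear dimension only if strides[i] == strides[i + 1] * sizes[i + 1].
# Illustration only; not Inductor's actual tiling code.
def can_flatten(sizes, strides):
    """True if the (sizes, strides) view can be addressed with one stride."""
    for i in range(len(sizes) - 1):
        if strides[i] != strides[i + 1] * sizes[i + 1]:
            return False
    return True


print(can_flatten((32, 8), (8, 1)))   # contiguous output: True
print(can_flatten((32, 8), (32, 1)))  # as_strided input: False, needs a 2D load
```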

Option 1

Change the tiling algorithm here:

def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
    """
    Heuristics to decide how to tile kernels.
    Currently, we tile based on stride-1 dimensions.

If you trigger a 2D tiled kernel, then block_ptr should get used.

Option 2

Generate a 2D load, then call tl.reshape. Something like:

tl.reshape(tl.load(tl.make_block_ptr(block_shape=[XBLOCK//8, 8], ...)), [XBLOCK])

This would require some multiple_of guards to ensure correctness.

This would be a bit more flexible.
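The sketch below emulates Option 2 in plain Python to show why a multiple_of-style guard matters: flattening a [XBLOCK // 8, 8] tile reproduces the 1D kernel's iteration order only when XBLOCK is a multiple of the row length 8. It assumes the (32, 8) view with strides (32, 1) from the original example; the helper names are hypothetical.

```python
# Plain-Python emulation of Option 2: load a [XBLOCK // 8, 8] tile through a
# 2D block pointer, flatten it row-major (the role of tl.reshape), and compare
# with the offsets the 1D div/mod kernel would load. Helpers are hypothetical.
ROWS, COLS, ROW_STRIDE = 32, 8, 32


def flat_offsets(pid, xblock):
    """Offsets a 1D kernel loads for program `pid` (div/mod indexing)."""
    return [
        (x % COLS) + ROW_STRIDE * (x // COLS)
        for x in range(pid * xblock, (pid + 1) * xblock)
    ]


def reshaped_tile_offsets(pid, xblock):
    """Offsets from a [xblock // COLS, COLS] tile, flattened row-major."""
    tile_rows = xblock // COLS  # 0 when xblock < COLS: the tile is empty
    first_row = pid * tile_rows
    return [
        (first_row + r) * ROW_STRIDE + c
        for r in range(tile_rows)
        for c in range(COLS)
    ]


# Prints True for 8, 16 and 32, and False for 4 (not a multiple of 8).
for xblock in (8, 16, 32, 4):
    same = all(
        flat_offsets(pid, xblock) == reshaped_tile_offsets(pid, xblock)
        for pid in range(ROWS * COLS // xblock)
    )
    print(xblock, same)
```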

@jbschlosser jbschlosser added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 29, 2024
@blaine-rister
Contributor Author

blaine-rister commented Apr 29, 2024

Thanks @jansel for the suggestions. I can take a shot at this. Would option 2 break the requirement that tiling dims == block pointer dims? That seems preferable, but I might attempt option 1 first just to get things working.

@blaine-rister blaine-rister self-assigned this Apr 29, 2024
@jansel
Contributor

jansel commented Apr 30, 2024

Would option 2 break the requirement that tiling dims == block pointer dims?

Yes, that is what I meant by "This would be a bit more flexible."

@blaine-rister
Contributor Author

blaine-rister commented May 1, 2024

I think I have a reasonable draft of option 2. It pattern matches on the div/modulo indexing expression to extract the strides and offset. I'm struggling with the statically_known_multiple_of guards, though. To preserve the iteration order, it seems like we need to know that XBLOCK is a multiple of our slice size. But at least in the examples I can find, those guards seem to apply to TRITON_MAX_BLOCK["X"]. We could know that the maximum block is safe to use, but what about the minimum block?

Instead of shape guards, would it be possible to use cdiv? For example,

tl.load(tl.make_block_ptr(block_shape=[tl.cdiv(XBLOCK,8), 8], ...)).reshape([XBLOCK])

I think this could work if we check that the iteration ranges are all powers of 2. (Is this always true?) If dim = 2 ** n, and we know XBLOCK = 2 ** m, then CeilDiv(XBLOCK,dim) == 2 ** (m - n) if m > n else 1.
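The power-of-two claim can be checked exhaustively for small exponents. This is a plain-Python sketch; `cdiv` here is a local ceiling-division helper standing in for tl.cdiv. The last lines also show the element count of the [cdiv(XBLOCK, dim), dim] tile when XBLOCK is smaller than dim.

```python
# Numerical check: for dim = 2**n and XBLOCK = 2**m,
# cdiv(XBLOCK, dim) == 2**(m - n) if m > n else 1.
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


for n in range(8):
    for m in range(12):
        expected = 2 ** (m - n) if m > n else 1
        assert cdiv(2 ** m, 2 ** n) == expected

# When XBLOCK < dim, the [cdiv(XBLOCK, dim), dim] tile holds dim elements,
# more than XBLOCK, so a reshape to [XBLOCK] would not preserve the count.
xblock, dim = 4, 8
tile_elems = cdiv(xblock, dim) * dim
print(tile_elems, xblock)  # 8 4
```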

@jansel
Contributor

jansel commented May 2, 2024

I think there are some correctness issues with that, because the iteration order must match exactly between all loads/stores in the kernel.

The guards I was talking about would need to be on the shape of the tensor being loaded.
