[DLight] Check for target in function attributes (#16958)
Prior to this commit, the `dlight` scheduling rules were applied
based solely on the global `tvm.target.Target.current()`.  However, a
TIR PrimFunc may instead be annotated with its own target, rather
than relying on the global `Target.current()`.  In that case, `dlight`
could produce a scheduled PrimFunc that is incompatible with its
target, for example by binding a loop to `"threadIdx.x"` on a CPU
target.

This commit updates `dlight` to check for a TIR PrimFunc's annotations
when scheduling, matching the behavior of `tvm.build`.
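
A minimal sketch of the intended behavior (the module name `Before` is
assumed; it mirrors the new test added below, with an unannotated
`gpu_func` and a `cpu_func` carrying a `"target"` attribute):

# Sketch only: `Before` is an assumed IRModule like the one in the new test,
# containing an unannotated `gpu_func` and a `cpu_func` annotated with
# {"target": T.target("llvm")}.
from tvm import dlight as dl
from tvm.target import Target

with Target("cuda"):
    mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(Before)

# `mod["gpu_func"]` is scheduled for the enclosing CUDA target (thread
# bindings added), while `mod["cpu_func"]` keeps its llvm annotation and is
# left unscheduled by the GPU fallback rule.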
Lunderberg committed May 13, 2024
1 parent 2933744 commit eb242ec
Showing 2 changed files with 88 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/tvm/dlight/base/transform.py
@@ -36,6 +36,14 @@ def _is_scheduled(func: tir.PrimFunc) -> bool:
     return func.attrs["tir.is_scheduled"] == 1
 
 
+def _get_target(func: tir.PrimFunc) -> Target:
+    target = func.attrs.get("target")
+    if target is None:
+        return Target.current(allow_none=False)
+    else:
+        return target
+
+
 @module_pass(opt_level=0, name="ApplyDefaultSchedule")
 class ApplyDefaultSchedule:  # pylint: disable=too-few-public-methods
     """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module."""
@@ -55,10 +63,11 @@ def transform_module(  # pylint: disable=missing-function-docstring
         mod: IRModule,
         _: PassContext,
     ) -> IRModule:
-        target = Target.current(allow_none=False)
         updated_functions = {}
         for g_var, func in mod.functions_items():
             if isinstance(func, tir.PrimFunc) and not _is_scheduled(func):
+                target = _get_target(func)
+
                 sch = _apply_rules(func, target, self.rules, tunable=False)
                 if sch is not None:
                     assert len(sch) == 1
78 changes: 78 additions & 0 deletions tests/python/dlight/test_gpu_fallback.py
@@ -179,5 +179,83 @@ def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_tabl
     assert_structural_equal(mod["main"], expected)
 
 
+def test_gpu_fallback_ignores_non_gpu_functions():
+    @I.ir_module
+    class Before:
+        # This function has no "target" attribute, and is scheduled
+        # using the `Target.current`.
+        @T.prim_func
+        def gpu_func(
+            A: T.Buffer((1, 32, 1, 128), "float16"),
+            C: T.Buffer((1, 1, 4096), "float16"),
+        ):
+            B = T.alloc_buffer((1, 1, 32, 128), "float16")
+            for i, j, k, l in T.grid(1, 1, 32, 128):
+                with T.block("T_transpose"):
+                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
+            for i, j, k in T.grid(1, 1, 4096):
+                with T.block("T_reshape"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
+
+        # This function is identical, except that it is explicitly
+        # annotated with the "target" attribute, and is scheduled
+        # based on the annotation's target.
+        @T.prim_func
+        def cpu_func(
+            A: T.Buffer((1, 32, 1, 128), "float16"),
+            C: T.Buffer((1, 1, 4096), "float16"),
+        ):
+            T.func_attr({"target": T.target("llvm")})
+            B = T.alloc_buffer((1, 1, 32, 128), "float16")
+            for i, j, k, l in T.grid(1, 1, 32, 128):
+                with T.block("T_transpose"):
+                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
+            for i, j, k in T.grid(1, 1, 4096):
+                with T.block("T_reshape"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def gpu_func(
+            A: T.Buffer((1, 32, 1, 128), "float16"),
+            C: T.Buffer((1, 1, 4096), "float16"),
+        ):
+            T.func_attr({"tir.is_scheduled": 1})
+            for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
+                    with T.block("T_reshape"):
+                        v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1)
+                        T.reads(A[0, v0 // 128, 0, v0 % 128])
+                        T.writes(C[0, 0, v0])
+                        C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128]
+
+        @T.prim_func
+        def cpu_func(
+            A: T.Buffer((1, 32, 1, 128), "float16"),
+            C: T.Buffer((1, 1, 4096), "float16"),
+        ):
+            T.func_attr({"target": T.target("llvm")})
+            B = T.alloc_buffer((1, 1, 32, 128), "float16")
+            for i, j, k, l in T.grid(1, 1, 32, 128):
+                with T.block("T_transpose"):
+                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
+            for i, j, k in T.grid(1, 1, 4096):
+                with T.block("T_reshape"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
+
+    with Target("cuda"):
+        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+            dl.gpu.Fallback(),
+        )(Before)
+    assert_structural_equal(mod, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
