[inductor] add triton code to SchedulerNode.debug_str
ghstack-source-id: b447449e94ba4c59a2c6c3d6304bfa4c04a18dcd
Pull Request resolved: pytorch/pytorch#125091
shunting314 committed Apr 29, 2024
1 parent f6ed068 commit 81e4190
Showing 3 changed files with 40 additions and 5 deletions.
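
The commit extends SchedulerNode.debug_str_extra (and the FusedSchedulerNode variant) so that, for nodes on a Triton device, the debug string also contains the generated Triton kernel source. Below is a minimal sketch of how that output can be surfaced, assuming inductor's usual debug-trace workflow; TORCH_COMPILE_DEBUG and torch.compile are existing PyTorch facilities, while the exact dump file name (ir_post_fusion.txt) and the torch_compile_debug/ directory layout are assumptions, not taken from this diff:

import os

# TORCH_COMPILE_DEBUG is read when torch._inductor.config is first imported,
# so set it before importing torch.
os.environ["TORCH_COMPILE_DEBUG"] = "1"

import torch


def f(x):
    return torch.relu(x) + 1.0


# Compiling for a CUDA device routes codegen through the Triton backend.
compiled = torch.compile(f)
compiled(torch.randn(1024, device="cuda"))

# With debug tracing enabled, inductor writes a torch_compile_debug/ run
# directory; its scheduler IR dumps (e.g. ir_post_fusion.txt, an assumed name)
# are built from each node's debug_str(), which after this change also embeds
# the node's generated Triton kernel code.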
torch/_inductor/codegen/cuda_combined_scheduling.py (5 changes: 5 additions, 0 deletions)
@@ -75,3 +75,8 @@ def codegen_foreach(self, *args, **kwargs):
 
     def benchmark_fused_nodes(self, nodes):
         return self._triton_scheduling.benchmark_fused_nodes(nodes)
+
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
+        return self._triton_scheduling.generate_kernel_code_from_nodes(
+            nodes, benchmark_kernel
+        )
torch/_inductor/codegen/triton.py (14 changes: 10 additions, 4 deletions)
@@ -3921,8 +3921,7 @@ def flush(self):
     def ready_to_flush(self) -> bool:
         return False
 
-    @preserve_rng_state()
-    def benchmark_fused_nodes(self, nodes):
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
         @dataclasses.dataclass
         class LastUsageHolder:
             n: Any
@@ -3954,18 +3953,25 @@ def __del__(self):
             )
 
             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-            with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
+            with config.patch(
+                "benchmark_kernel", benchmark_kernel
+            ), V.set_kernel_handler(kernel):
                 src_code = kernel.codegen_kernel()
         else:
             template_node = nodes[0]
             epilogue_nodes = nodes[1:]
 
-            with config.patch("benchmark_kernel", True):
+            with config.patch("benchmark_kernel", benchmark_kernel):
                 src_code = self.codegen_template(
                     template_node, epilogue_nodes, only_gen_src_code=True
                 )
 
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return src_code
+
+    @preserve_rng_state()
+    def benchmark_fused_nodes(self, nodes):
+        src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
         mod = PyCodeCache.load(src_code)
 
         def cache_file_path():
torch/_inductor/scheduler.py (26 changes: 25 additions, 1 deletion)
@@ -743,6 +743,20 @@ def debug_str_extra(self) -> str:
         if isinstance(self._body, ir.LoopBody):
             lines.append(f"class {name}_loop_body:")
             lines.append(textwrap.indent(self._body.debug_str(), "    "))
+
+        if ir.is_triton(self.node.get_device()):
+            backend = self.scheduler.get_backend(self.node.get_device())
+            V.graph.scheduler.current_device = self.node.get_device()
+
+            # Don't increment kernel count when generating debug string.
+            # This will confuse some unit tests that check the number of
+            # generated kernels.
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes((self,)).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, "    "))
         return "\n".join(lines)
 
     def get_ranges(self):
@@ -900,6 +914,16 @@ def debug_str_extra(self) -> str:
             f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
             for i, node in enumerate(self.snodes)
         ]
+        device = self.snodes[0].node.get_device()
+        if ir.is_triton(device):
+            backend = self.scheduler.get_backend(device)
+            V.graph.scheduler.current_device = device
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes(self.snodes).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, "    "))
+
         return textwrap.indent("\n".join(lines).rstrip(), "    ")
 
     def set_last_usage(
@@ -1271,6 +1295,7 @@ class Scheduler:
     @dynamo_timed
     def __init__(self, nodes):
         super().__init__()
+        V.graph.scheduler = self
         self.backends = {}
         self.fuse_cache = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -1734,7 +1759,6 @@ def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]:
         """
         assert len(nodes) > 0
         device = nodes[0].get_device()
-        V.graph.scheduler = self
         self.current_device = device
         backend = self.get_backend(device)
         return backend.benchmark_fused_nodes(nodes)
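
Both debug_str_extra hooks above save and restore metrics.generated_kernel_count around the debug-only code generation so that unit tests counting generated kernels are unaffected. A sketch of that same save/restore pattern written as a reusable context manager; preserve_attr is a hypothetical helper for illustration, not something the commit adds:

import contextlib


@contextlib.contextmanager
def preserve_attr(obj, name):
    # Save obj.<name>, run the block, then restore the saved value, so side
    # effects of debug-only code generation do not leak into global counters.
    saved = getattr(obj, name)
    try:
        yield
    finally:
        setattr(obj, name, saved)


# Hypothetical usage mirroring debug_str_extra:
#     with preserve_attr(metrics, "generated_kernel_count"):
#         triton_code = backend.generate_kernel_code_from_nodes(self.snodes).strip()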
