[inductor] add triton code to SchedulerNode.debug_str #125089

Open
wants to merge 1 commit into base: gh/shunting314/140/base
5 changes: 5 additions & 0 deletions torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -79,3 +79,8 @@ def codegen_foreach(self, *args, **kwargs):
 
     def benchmark_fused_nodes(self, nodes):
         return self._triton_scheduling.benchmark_fused_nodes(nodes)
+
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
+        return self._triton_scheduling.generate_kernel_code_from_nodes(
+            nodes, benchmark_kernel
+        )
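For context, the combined CUDA scheduling backend simply forwards the new hook to its Triton scheduling object, mirroring the existing benchmark_fused_nodes delegation. A minimal standalone sketch of that forwarding pattern (the stub classes below are illustrative stand-ins, not Inductor's real TritonScheduling / CUDACombinedScheduling types):

# Illustrative sketch only; stubs stand in for Inductor's real backends.
class _TritonSchedulingStub:
    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        # The real backend runs Triton codegen here; return a placeholder string.
        return f"# triton src for {len(nodes)} node(s), benchmark={benchmark_kernel}"

class _CombinedSchedulingStub:
    def __init__(self):
        self._triton_scheduling = _TritonSchedulingStub()

    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        # Same one-line forwarding as the diff above.
        return self._triton_scheduling.generate_kernel_code_from_nodes(
            nodes, benchmark_kernel
        )

print(_CombinedSchedulingStub().generate_kernel_code_from_nodes(["node0"]))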
14 changes: 10 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -3929,8 +3929,7 @@ def flush(self):
     def ready_to_flush(self) -> bool:
         return False
 
-    @preserve_rng_state()
-    def benchmark_fused_nodes(self, nodes):
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        @dataclasses.dataclass
        class LastUsageHolder:
            n: Any
@@ -3962,18 +3961,25 @@ def __del__(self):
             )
 
             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-            with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
+            with config.patch(
+                "benchmark_kernel", benchmark_kernel
+            ), V.set_kernel_handler(kernel):
                 src_code = kernel.codegen_kernel()
         else:
             template_node = nodes[0]
             epilogue_nodes = nodes[1:]
 
-            with config.patch("benchmark_kernel", True):
+            with config.patch("benchmark_kernel", benchmark_kernel):
                 src_code = self.codegen_template(
                     template_node, epilogue_nodes, only_gen_src_code=True
                 )
 
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return src_code
+
+    @preserve_rng_state()
+    def benchmark_fused_nodes(self, nodes):
+        src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
         mod = PyCodeCache.load(src_code)
 
         def cache_file_path():
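The net effect of this hunk is that kernel-source generation is factored out of benchmarking: generate_kernel_code_from_nodes produces the Triton source (optionally with the benchmark harness), and benchmark_fused_nodes now just requests that source with benchmark_kernel=True and loads it. A rough, self-contained sketch of that split, where the underscore-prefixed helpers are stand-ins for Inductor's real codegen, the harness emitted under config.benchmark_kernel, and PyCodeCache.load:

# Illustrative sketch only; _codegen, _add_benchmark_harness and _load_module
# are stand-ins, not Inductor APIs.
import types

def _codegen(nodes):
    return "def triton_(args):\n    pass\n"

def _add_benchmark_harness(src):
    return src + "\ndef benchmark_all_configs(args):\n    return 0.0\n"

def _load_module(src):
    mod = types.ModuleType("generated_kernel")
    exec(src, mod.__dict__)
    return mod

def generate_kernel_code_from_nodes(nodes, benchmark_kernel=False):
    # Pure codegen: reusable for debug output, no timing machinery unless asked.
    src = _codegen(nodes)
    if benchmark_kernel:
        src = _add_benchmark_harness(src)
    return src

def benchmark_fused_nodes(nodes):
    # Benchmarking always wants the harness, so it asks for it explicitly.
    src = generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
    mod = _load_module(src)
    return mod.benchmark_all_configs(None)

print(benchmark_fused_nodes(["node0"]))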
17 changes: 16 additions & 1 deletion torch/_inductor/scheduler.py
@@ -737,6 +737,13 @@ def debug_str_extra(self) -> str:
         if isinstance(self._body, ir.LoopBody):
             lines.append(f"class {name}_loop_body:")
             lines.append(textwrap.indent(self._body.debug_str(), " "))
+
+        if ir.is_triton(self.node.get_device()):
+            backend = self.scheduler.get_backend(self.node.get_device())
+            V.graph.scheduler.current_device = self.node.get_device()
+            triton_code = backend.generate_kernel_code_from_nodes((self,)).strip()
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, " "))
         return "\n".join(lines)
 
     def get_ranges(self):
@@ -894,6 +901,14 @@ def debug_str_extra(self) -> str:
             f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
             for i, node in enumerate(self.snodes)
         ]
+        device = self.snodes[0].node.get_device()
+        if ir.is_triton(device):
+            backend = self.scheduler.get_backend(device)
+            V.graph.scheduler.current_device = device
+            triton_code = backend.generate_kernel_code_from_nodes(self.snodes).strip()
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, " "))
+
         return textwrap.indent("\n".join(lines).rstrip(), " ")
 
     def set_last_usage(
@@ -1265,6 +1280,7 @@ class Scheduler:
     @dynamo_timed
     def __init__(self, nodes):
         super().__init__()
+        V.graph.scheduler = self
        self.backends = {}
        self.fuse_cache = {}
        self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -1718,7 +1734,6 @@ def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]:
         """
         assert len(nodes) > 0
         device = nodes[0].get_device()
-        V.graph.scheduler = self
         self.current_device = device
         backend = self.get_backend(device)
         return backend.benchmark_fused_nodes(nodes)
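End to end, these scheduler.py changes mean a SchedulerNode (or fused group) on a Triton device now reports its generated kernel source alongside the existing loop-body dump, and V.graph.scheduler is set eagerly in Scheduler.__init__ so debug_str can reach the backend outside of benchmarking. One way to surface the richer output is Inductor's debug tracing; the environment variable and dump file names below are assumptions about the surrounding tooling, not something this PR adds:

# Hedged usage sketch: enable Inductor's debug trace and compile a tiny
# function on a CUDA device, then inspect the dumped scheduler IR.
import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # set before any torch.compile runs

import torch

@torch.compile
def f(x):
    return torch.sin(x) + torch.cos(x)

if torch.cuda.is_available():
    f(torch.randn(1024, device="cuda"))
    # With this PR, each node entry in the dumped IR (e.g. ir_post_fusion.txt
    # under the torch_compile_debug/ directory) should also contain a
    # "<node name> Triton code:" section with the generated kernel source.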