[inductor] add triton code to SchedulerNode.debug_str #125091

Closed · wants to merge 2 commits
5 changes: 5 additions & 0 deletions torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -75,3 +75,8 @@ def codegen_foreach(self, *args, **kwargs):
 
     def benchmark_fused_nodes(self, nodes):
         return self._triton_scheduling.benchmark_fused_nodes(nodes)
+
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
+        return self._triton_scheduling.generate_kernel_code_from_nodes(
+            nodes, benchmark_kernel
+        )
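
For context, CUDACombinedScheduling acts as a thin dispatcher that forwards Triton-specific work to the underlying Triton scheduling backend, and the new method follows the same delegation pattern as benchmark_fused_nodes directly above it. A minimal, self-contained sketch of that pattern (TritonBackend and CombinedScheduling are illustrative stand-ins, not the real inductor classes):

class TritonBackend:
    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        # Stand-in for TritonScheduling.generate_kernel_code_from_nodes.
        return f"# kernel source for {len(nodes)} node(s), benchmark={benchmark_kernel}"

class CombinedScheduling:
    def __init__(self):
        self._triton_scheduling = TritonBackend()

    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        # Forward the call unchanged, exactly as the diff above does.
        return self._triton_scheduling.generate_kernel_code_from_nodes(
            nodes, benchmark_kernel
        )

print(CombinedScheduling().generate_kernel_code_from_nodes(["node0"]))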
14 changes: 10 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -3921,8 +3921,7 @@ def flush(self):
     def ready_to_flush(self) -> bool:
         return False
 
-    @preserve_rng_state()
-    def benchmark_fused_nodes(self, nodes):
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
         @dataclasses.dataclass
         class LastUsageHolder:
             n: Any
@@ -3954,18 +3953,25 @@ def __del__(self):
             )
 
             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-            with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
+            with config.patch(
+                "benchmark_kernel", benchmark_kernel
+            ), V.set_kernel_handler(kernel):
                 src_code = kernel.codegen_kernel()
         else:
             template_node = nodes[0]
             epilogue_nodes = nodes[1:]
 
-            with config.patch("benchmark_kernel", True):
+            with config.patch("benchmark_kernel", benchmark_kernel):
                 src_code = self.codegen_template(
                     template_node, epilogue_nodes, only_gen_src_code=True
                 )
 
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return src_code
+
+    @preserve_rng_state()
+    def benchmark_fused_nodes(self, nodes):
+        src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
         mod = PyCodeCache.load(src_code)
 
         def cache_file_path():
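
The net effect of this diff is a pure extraction: the codegen half of benchmark_fused_nodes becomes generate_kernel_code_from_nodes, parameterized by benchmark_kernel so callers can obtain the kernel source without the benchmarking harness, while benchmark_fused_nodes becomes a thin wrapper passing benchmark_kernel=True. A runnable sketch of that shape (patch_config and the returned strings are simplified stand-ins for the inductor config.patch machinery and real kernel source):

from contextlib import contextmanager

_config = {"benchmark_kernel": False}

@contextmanager
def patch_config(key, value):
    # Simplified stand-in for torch._inductor.config.patch(...).
    old = _config[key]
    _config[key] = value
    try:
        yield
    finally:
        _config[key] = old

def generate_kernel_code_from_nodes(nodes, benchmark_kernel=False):
    # Codegen only; the benchmark harness is emitted only when requested.
    with patch_config("benchmark_kernel", benchmark_kernel):
        harness = " + benchmark harness" if _config["benchmark_kernel"] else ""
        return f"def triton_(...): ...  # {len(nodes)} node(s){harness}"

def benchmark_fused_nodes(nodes):
    # Thin wrapper: reuse the helper, then compile and time the result.
    src_code = generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
    return src_code  # the real method loads this via PyCodeCache and benchmarks it

print(generate_kernel_code_from_nodes(["a"]))   # no harness
print(benchmark_fused_nodes(["a", "b"]))        # with harness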
26 changes: 25 additions & 1 deletion torch/_inductor/scheduler.py
@@ -743,6 +743,20 @@ def debug_str_extra(self) -> str:
         if isinstance(self._body, ir.LoopBody):
             lines.append(f"class {name}_loop_body:")
             lines.append(textwrap.indent(self._body.debug_str(), "    "))
+
+        if ir.is_triton(self.node.get_device()):
+            backend = self.scheduler.get_backend(self.node.get_device())
+            V.graph.scheduler.current_device = self.node.get_device()
+
+            # Don't increment kernel count when generating debug string.
+            # This will confuse some unit tests that check the number of
+            # generated kernels.
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes((self,)).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, "    "))
         return "\n".join(lines)
 
     def get_ranges(self):
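
The snapshot/restore of metrics.generated_kernel_count is the subtle part of this hunk: producing the debug string runs the same codegen path as real compilation, which normally bumps the kernel counter that some unit tests assert on. A small sketch of the pattern with a stand-in metrics module:

class metrics:
    # Stand-in for torch._inductor.metrics.
    generated_kernel_count = 0

def generate_kernel_code_from_nodes(nodes):
    # Codegen normally increments the counter as a side effect.
    metrics.generated_kernel_count += 1
    return "def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK): ..."

def debug_codegen(nodes):
    # Same save/restore as the diff above: generate code for display only,
    # then roll the counter back so kernel-count assertions are unaffected.
    old_generated_kernel_count = metrics.generated_kernel_count
    triton_code = generate_kernel_code_from_nodes(nodes).strip()
    metrics.generated_kernel_count = old_generated_kernel_count
    return triton_code

debug_codegen(["node"])
assert metrics.generated_kernel_count == 0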
@@ -900,6 +914,16 @@ def debug_str_extra(self) -> str:
             f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
             for i, node in enumerate(self.snodes)
         ]
+        device = self.snodes[0].node.get_device()
+        if ir.is_triton(device):
+            backend = self.scheduler.get_backend(device)
+            V.graph.scheduler.current_device = device
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes(self.snodes).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, "    "))
+
         return textwrap.indent("\n".join(lines).rstrip(), "    ")
 
     def set_last_usage(
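
For the fused case the same pattern is applied to all of self.snodes at once, and textwrap.indent (standard library) nests the generated source under the node header so the Triton code reads as a sub-block of the debug string. A toy example of that nesting (the node name and kernel text are made up for illustration):

import textwrap

lines = ["buf0_buf1 Triton code:"]  # hypothetical fused-node name
triton_code = "@triton.jit\ndef triton_(in_ptr0, out_ptr0, xnumel, XBLOCK): ..."
lines.append(textwrap.indent(triton_code, "    "))
print("\n".join(lines))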
@@ -1271,6 +1295,7 @@ class Scheduler:
     @dynamo_timed
     def __init__(self, nodes):
         super().__init__()
+        V.graph.scheduler = self
         self.backends = {}
         self.fuse_cache = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -1734,7 +1759,6 @@ def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]:
         """
         assert len(nodes) > 0
         device = nodes[0].get_device()
-        V.graph.scheduler = self
         self.current_device = device
         backend = self.get_backend(device)
         return backend.benchmark_fused_nodes(nodes)
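
The last two hunks move the V.graph.scheduler registration from benchmark_fused_nodes into Scheduler.__init__: debug_str can now trigger kernel codegen before any benchmarking happens, so the graph needs a scheduler reference as soon as the scheduler is constructed. A minimal sketch of that ordering concern (Graph and V are simplified stand-ins for the inductor globals):

class Graph:
    scheduler = None

class V:
    graph = Graph()

class Scheduler:
    def __init__(self, nodes):
        # Registered eagerly so debug_str / generate_kernel_code_from_nodes
        # can rely on V.graph.scheduler before any benchmarking runs.
        V.graph.scheduler = self
        self.nodes = nodes

    def benchmark_fused_nodes(self, nodes):
        # No longer needs to (re)register the scheduler here.
        return 0.0, "path/to/kernel"

s = Scheduler([])
assert V.graph.scheduler is s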