CSE pass should work on symbols with keyword arguments #397

Open
IvanYashchuk opened this issue May 10, 2024 · 0 comments
Labels
bug Something isn't working optimization passes


🚀 Feature

Motivation

Currently, if either bound symbol has a non-empty kwargs, the CSE equality check treats the two symbols as not equal, even when their arguments and keyword arguments are identical:

if len(self.parent.kwargs) > 0 or len(other.parent.kwargs) > 0:
    return False

This needs to be changed for CSE to work properly.
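One possible direction is to make keyword arguments part of the comparison instead of bailing out early. The sketch below is only an illustration under stated assumptions: `FakeBoundSymbol`, `rhs_key`, and `rhs_equal` are hypothetical names, not Thunder's actual classes or functions, and the kwarg values are assumed to be hashable.

```python
from dataclasses import dataclass, field


@dataclass
class FakeBoundSymbol:
    """Hypothetical stand-in for a bound symbol; not Thunder's real class."""
    name: str
    args: tuple
    kwargs: dict = field(default_factory=dict)


def rhs_key(bsym: FakeBoundSymbol) -> tuple:
    """Build a hashable key from the operation name, args, and kwargs."""
    # Sort the kwargs so that keyword order does not affect the comparison.
    return (bsym.name, bsym.args, tuple(sorted(bsym.kwargs.items())))


def rhs_equal(a: FakeBoundSymbol, b: FakeBoundSymbol) -> bool:
    """Two right-hand sides match only if name, args, and kwargs all match."""
    return rhs_key(a) == rhs_key(b)


first = FakeBoundSymbol("var", ("x", (0, 1)), {"correction": 1})
second = FakeBoundSymbol("var", ("x", (0, 1)), {"correction": 1})
other = FakeBoundSymbol("var", ("x", (0, 1)), {"correction": 0})

print(rhs_equal(first, second))  # same op, args, and kwargs
print(rhs_equal(first, other))   # differs only in the keyword argument
```

With a key like this, two calls that differ only in a keyword argument stay distinct, while identical calls can be deduplicated.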

Here's an example where CSE fails to eliminate computation because there's a keyword argument:

import torch
import thunder

@thunder.jit
def func(x):
    t1 = thunder.prims.var(x, (0, 1), correction=1)
    t2 = thunder.prims.var(x, (0, 1), correction=1)
    t3 = thunder.prims.add(t1, t2)
    return t3

x = torch.randn(512, 512, device="cuda")
out = func(x)
print(thunder.last_traces(func)[-1])

Execution trace (both prims.var calls remain inside the fusion):

@torch.no_grad()
@no_autocast()
def func(x):
  # x: "cuda:0 f32[512, 512]" 
  [t4] = nvFusion0(x)
    # t1 = prims.var(x, (0, 1), correction=1)  # t1: "cuda:0 f32[]"
    # t3 = prims.var(x, (0, 1), correction=1)  # t3: "cuda:0 f32[]"
    # t4 = prims.add(t1, t3)  # t4: "cuda:0 f32[]"
  del x
  return t4

One key data structure is a "frozen dict", a dictionary that is hashable and immutable; it was added in 388a1a8.
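To illustrate why a hashable, immutable dictionary helps here, the following is a minimal sketch of the idea. `FrozenDict` below is a hypothetical illustration written for this issue, not the class added in 388a1a8, and it assumes all values are hashable.

```python
from collections.abc import Mapping


class FrozenDict(Mapping):
    """Hypothetical read-only, hashable mapping (illustration only)."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        # The hash is order-independent and cached; values must be hashable.
        if self._hash is None:
            self._hash = hash(frozenset(self._data.items()))
        return self._hash

    def __repr__(self):
        return f"FrozenDict({self._data!r})"


kwargs_a = FrozenDict(correction=1)
kwargs_b = FrozenDict({"correction": 1})

# Equal contents compare equal and hash equally, so they can be used
# as dictionary keys or set members when looking up redundant computations.
seen = {kwargs_a: "t1"}
print(kwargs_b in seen)
```

Because `Mapping` supplies equality by contents and the class defines no `__setitem__`, instances cannot be modified through item assignment after construction, which keeps the hash stable.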

cc @apaz-cli
