
Add MoE layer example #303

Draft
wants to merge 13 commits into main

Conversation

IvanYashchuk
Collaborator

@IvanYashchuk IvanYashchuk commented Apr 30, 2024

DRAFT MODE TO PREVENT MERGES. The approach and code are ready for experimentation and review.

The main result of this PR is that Thunder can run a variant of the MoE layer from LitGPT. There are three modifications:

  1. zip is replaced with a for-loop with explicit indexing into lists (blocked by #284, "implement zip lookaside in Python interpreter (enables e.g. thunder.jit with zip from LitGPT LLaMAMoE)").
  2. Thunder doesn't support advanced indexing with None (an issue still needs to be filed). The workaround is to use unsqueeze instead of None when indexing.
  3. In-place addition (+=) is replaced with index_add.
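A minimal sketch of how the three workarounds might look in plain PyTorch (the shapes, `expert_mask`, and two-expert setup are illustrative, not taken from the PR):

```python
import torch

tokens = torch.randn(8, 4)                       # (num_tokens, hidden) - made-up sizes
expert_mask = torch.tensor([0, 1, 0, 1, 1, 0, 1, 0])
experts = [torch.nn.Linear(4, 4) for _ in range(2)]

out = torch.zeros_like(tokens)
# 1. Explicit indexing with a for-loop instead of zip(experts, masks):
for i in range(len(experts)):
    idx = (expert_mask == i).nonzero(as_tuple=True)[0]
    selected = torch.index_select(tokens, 0, idx)
    # 2. unsqueeze instead of advanced indexing with None (i.e. instead of w[:, None]):
    weight = torch.ones(idx.shape[0]).unsqueeze(1)
    # 3. out-of-place index_add instead of out[idx] += ...:
    out = out.index_add(0, idx, experts[i](selected) * weight)
```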

The main missing operator is nonzero(x, as_tuple=True). The problem with this operator is that its output shape is unknown at compile time and dynamic at runtime. I tried a NumberProxy with None, a NumberProxy with a custom int subclass as its value, and a custom int subclass directly, but a simple -1 in the shape worked best.
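For context, the data dependence is easy to see in plain PyTorch: two inputs with identical shapes but different data produce nonzero outputs of different sizes, so the shape cannot be fixed at trace time.

```python
import torch

# Same input shape, different data -> different output shapes from nonzero:
a = torch.tensor([1, 0, 2, 0])
b = torch.tensor([1, 1, 1, 1])
idx_a = a.nonzero(as_tuple=True)[0]  # 2 nonzero elements
idx_b = b.nonzero(as_tuple=True)[0]  # 4 nonzero elements
```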

The forward pass worked with just 14ce097. The backward pass required more special handling of -1.

Currently, index_add, index_select, topk are not fused with any of Thunder's fusing executors.
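For reference, torch.topk returns both values and indices; in an MoE router it is typically what selects the k highest-scoring experts per token before index_select/index_add do the dispatch (the logits below are made up for illustration):

```python
import torch

# Router logits for 2 tokens over 3 hypothetical experts:
router_logits = torch.tensor([[0.1, 0.9, 0.3],
                              [0.8, 0.2, 0.7]])
# Pick the top-2 experts per token; idx holds expert ids per token.
vals, idx = torch.topk(router_logits, k=2, dim=-1)
# idx -> [[1, 2], [0, 2]]
```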

@mruberry
Collaborator

mruberry commented May 1, 2024

Super exciting! Really looking forward to discussing this in more detail at a design review!

@mruberry
Collaborator

mruberry commented May 1, 2024

@t-vi and @carmocca, I think you'll be interested in this

@apaz-cli
Collaborator

apaz-cli commented May 2, 2024

Do we have any broader ideas for how this fits into the strategy for handling dynamic and data dependent shapes? I was under the impression that this was just something we were completely incapable of doing with the way that we're modeling traces.

@@ -1048,6 +1053,7 @@ def find_producer_symbols(trace: TraceCtx, proxies: Sequence[Proxy], stop_proxie
# __b = ltorch.sub(x, y)
# __b = prims.sub(x, y)
"""
stop_proxies = filter(lambda x: isinstance(x, Proxy), stop_proxies)
Collaborator Author


Suggested change:
-stop_proxies = filter(lambda x: isinstance(x, Proxy), stop_proxies)
+stop_proxies = tuple(filter(lambda x: isinstance(x, Proxy), stop_proxies))
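The tuple(...) wrapper matters because filter returns a lazy, single-use iterator in Python 3; traversing it a second time silently yields nothing, whereas a tuple can be traversed repeatedly:

```python
# filter() gives a one-shot iterator:
stop = filter(lambda x: x > 0, [1, -2, 3])
assert list(stop) == [1, 3]
assert list(stop) == []        # already exhausted on the second pass

# Materializing with tuple() makes it safely reusable:
stop = tuple(filter(lambda x: x > 0, [1, -2, 3]))
assert list(stop) == [1, 3]
assert list(stop) == [1, 3]
```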

@jjsjann123
Collaborator

> Thunder doesn't support advanced indexing with None (need to create an issue). The workaround is to use unsqueeze instead of None when indexing.

I vaguely remember running into None in indexing, but I think I was only seeing it with basic indexing...
