Add Switch transformers #19323
Conversation
- Implemented `ExpertsChooseMaskedRouter` - added tests - 2 more routers to implement
- completed the docstring in `router.py` - added more args in the config
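A hedged sketch of the "experts choose" routing idea behind `ExpertsChooseMaskedRouter`: instead of each token picking an expert, each expert picks its highest-scoring tokens up to a fixed capacity. The function and variable names below are illustrative, not the PR's implementation.

```python
import torch
import torch.nn.functional as F


def experts_choose_routing(router_probs: torch.Tensor, expert_capacity: int) -> torch.Tensor:
    """Toy "experts choose" routing.

    router_probs: (batch, seq_len, num_experts) softmax scores from the router.
    Returns a dispatch mask of shape (batch, seq_len, num_experts, expert_capacity).
    """
    seq_len = router_probs.shape[1]
    # Score every token from each expert's point of view: (batch, num_experts, seq_len)
    probs_per_expert = router_probs.transpose(1, 2)
    # Each expert keeps its `expert_capacity` best tokens.
    _, token_indices = torch.topk(probs_per_expert, k=expert_capacity, dim=-1)
    # One-hot over the sequence dimension: (batch, num_experts, expert_capacity, seq_len)
    dispatch = F.one_hot(token_indices, num_classes=seq_len)
    # Rearrange to (batch, seq_len, num_experts, expert_capacity)
    return dispatch.permute(0, 3, 1, 2).to(router_probs.dtype)
```

With this scheme no token can overflow an expert's buffer, at the cost of some tokens possibly being selected by no expert at all.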
…to add_switch_transformers
force-pushed from 7e4ff1f to 1397231
force-pushed from 1397231 to 6ede608
…lkada/transformers into add_switch_transformers
- add better casting for `Linear8bitLt` - remove `torchscript` tests
…lkada/transformers into add_switch_transformers
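On the "better casting for `Linear8bitLt`" commit above: when the huge checkpoints are loaded in 8-bit, a common pattern is to cast everything except the bitsandbytes 8-bit linear layers. The helper below is a sketch under that assumption, not the fix that actually landed.

```python
import torch
import bitsandbytes as bnb


def cast_non_8bit_params(model: torch.nn.Module, dtype: torch.dtype = torch.float16) -> torch.nn.Module:
    # Hypothetical helper: leave `Linear8bitLt` modules untouched (their weights are
    # already quantized) and cast every other parameter to the requested dtype.
    for module in model.modules():
        if isinstance(module, bnb.nn.Linear8bitLt):
            continue
        for param in module.parameters(recurse=False):
            param.data = param.data.to(dtype)
    return model
```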
Thanks a lot @sgugger for your comments!
Looking great!
src/transformers/models/switch_transformers/configuration_switch_transformers.py
""" | ||
router_probs, router_logits = self._compute_router_probabilities(hidden_states) | ||
|
||
# Flax code for reference TODO check what happens with padded inputs here |
Flagging this just in case!
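For context, a minimal sketch of what `_compute_router_probabilities` typically does in this kind of MoE layer; the float32 cast and the names are assumptions, and the flagged TODO concerns whether padding tokens should be masked out before routing.

```python
import torch


def compute_router_probabilities_sketch(hidden_states, classifier, router_dtype=torch.float32):
    # `classifier` is assumed to be an nn.Linear(d_model, num_experts) created in
    # `router_dtype`. Logits are computed in float32 for numerical stability,
    # then softmaxed over the expert dimension.
    router_logits = classifier(hidden_states.to(router_dtype))
    router_probs = torch.softmax(router_logits, dim=-1)
    return router_probs, router_logits
```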
src/transformers/models/switch_transformers/modeling_switch_transformers.py
    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
        r"decoder.embed_tokens.weight",
        r"lm_head.weight",
    ]
Nit: can take less vertical space.
Managed to change it for the attributes above in 16e7ff5, but not for this one (not sure why the decorator above is not affected by the vertical formatting, though 🤔).
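One likely explanation for the attribute that refuses to collapse: black's "magic trailing comma" keeps a collection exploded across lines whenever its last element ends with a comma. Dropping that comma should allow the one-liner the reviewer is after, roughly:

```python
# Hypothetical compact form once the trailing comma is removed:
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight", r"decoder.embed_tokens.weight", r"lm_head.weight"]
```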
Impressive contribution! Great to have so many pretrained checkpoints at release
src/transformers/models/switch_transformers/configuration_switch_transformers.py
        }
        return dummy_inputs

    def _init_weights(self, module):
Impressive method!
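For readers without the diff open, a rough, hypothetical sketch of what a T5-style `_init_weights` can look like; the method in this PR covers many more module types, including the experts and the router classifier.

```python
import torch.nn as nn


def _init_weights(self, module):
    # Illustrative T5-style scaled-normal initialization, assuming the config
    # exposes `initializer_factor` and `d_model`. Not the PR's full method.
    factor = self.config.initializer_factor
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=factor)
```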
src/transformers/models/switch_transformers/configuration_switch_transformers.py
…ch_transformers.py
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Failing tests seem to be unrelated to this PR, merging!
* first commit * add more comments * add router v1 * clean up - remove `tf` modeling files * clean up - remove `tf` modeling files * clean up * v0 routers * added more router - Implemented `ExpertsChooseMaskedRouter` - added tests - 2 more routers to implement * last router * improved docstring - completed the docstring in `router.py` - added more args in the config * v0 sparse mlp * replace wrong naming * forward pass run * update MOE layer * small router update * fixup * consistency * remove scatter router * remove abstract layer * update test and model for integration testing * v1 conversion * update * hardcode hack * all keys match * add gin conversion, without additional libraries * update conversion sctipy * delete router file * update tests wrt router deletion * fix router issues * update expert code * update, logits match, code needsREFACTORING * Refactor code Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com> * add generate tests Co-authored-by: younesbelkada <younesbelkada@gmail.com> * add support for router loss Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com> * fix forward error * refactor a bit * remove `FlaxSwitchTransformers` modules * more tests pass * Update code Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com> * fixup * fix tests * fix doc * fix doc + tokenization * fix tokenizer test * fix test * fix loss output * update code for backward pass * add loss support * update documentation * fix documentation, clean tokenizer * more doc fix, cleanup example_switch * fix failing test * fix test * fix test * fix loss issue * move layer * update doc and fix router capacity usage * fixup * add sparse mlp index for documentation on hub * fixup * test sparse mix architecture * Apply suggestions from code review * Update docs/source/en/model_doc/switch_transformers.mdx * fixup on update * fix tests * fix another test * attempt fix * Update src/transformers/models/switch_transformers/configuration_switch_transformers.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * try * all tests pass * fix jitter noise * Apply suggestions from code review * doc tests pass * Update src/transformers/models/switch_transformers/modeling_switch_transformers.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/switch_transformers/modeling_switch_transformers.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * remove assert * change config order * fix readme japanese * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove parallelizable tests + add one liners * remove ONNX config * fix nits - add `T5Tokenizer` in auto mapping - remove `Switch Transformers` from ONNX supported models * remove `_get_router` * remove asserts * add check in test for `router_dtype` * add `SwitchTransformersConfig` in `run_pipeline_test` * Update tests/pipelines/test_pipelines_summarization.py * add huge model conversion script * fix slow tests - add better casting for `Linear8bitLt` - remove `torchscript` tests * add make dir * style on new script * fix nits - doctest - remove `_keys_to_ignore_on_load_unexpected` * Update src/transformers/models/switch_transformers/configuration_switch_transformers.py * 
add google as authors * fix year * remove last `assert` statements * standardize vertical spaces * fix failing import * fix another failing test * Remove strange `authorized_keys` * removing todo and padding that is never used Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: ybelkada <younes@huggingface.co> Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Arthur Zucker <arthur@huggingface.co>
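Among the items above, "add support for router loss" refers to the auxiliary load-balancing loss from the Switch Transformers paper, which pushes the router toward a uniform token distribution across experts. A rough sketch of that loss (not necessarily how it is wired up in this PR):

```python
import torch
import torch.nn.functional as F


def load_balancing_loss_sketch(router_probs: torch.Tensor, expert_index: torch.Tensor) -> torch.Tensor:
    """Auxiliary loss ~ num_experts * sum_i f_i * P_i (Fedus et al., 2021).

    router_probs: (batch, seq_len, num_experts) softmax router probabilities.
    expert_index: (batch, seq_len) index of the expert each token was routed to.
    """
    num_experts = router_probs.shape[-1]
    # f_i: fraction of tokens dispatched to each expert.
    expert_mask = F.one_hot(expert_index, num_classes=num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=(0, 1))
    # P_i: mean router probability assigned to each expert.
    router_prob_per_expert = router_probs.mean(dim=(0, 1))
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```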
What does this PR do?
This PR adds Switch Transformers from t5x, with @ArthurZucker & @thomwolf.
The architecture is close to T5 (the modeling code is copied from T5), with some feed-forward layers replaced by sparse Mixture of Experts (MoE) layers, making this the first MoE architecture in the `transformers` library; a short sketch of the idea follows the links below.
paper: https://arxiv.org/abs/2101.03961
weights: https://github.com/google-research/t5x/blob/eb42c2524bf65c8a46624f1a9b9e034d9bc65b14/docs/models.md#converted-mesh-tensorflow-checkpoints
original modeling code: https://github.com/google/flaxformer/tree/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe
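As mentioned in the description, a minimal sketch of the top-1 ("switch") MoE feed-forward idea. It is illustrative only: expert capacity, jitter noise, efficient dispatching and the auxiliary loss handled by the real implementation are all omitted.

```python
import torch
import torch.nn as nn


class SparseMLPSketch(nn.Module):
    """Toy top-1 Mixture-of-Experts feed-forward layer (not the PR's implementation)."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Each token is routed to its single best expert (top-1 routing).
        router_probs = torch.softmax(self.router(hidden_states), dim=-1)  # (batch, seq, num_experts)
        expert_gate, expert_index = router_probs.max(dim=-1)              # (batch, seq)

        output = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(self.experts):
            token_mask = expert_index == idx
            if token_mask.any():
                output[token_mask] = expert(hidden_states[token_mask])
        # Scale by the router probability so gradients flow back into the router.
        return output * expert_gate.unsqueeze(-1)
```

In the full model, sparse layers typically replace the dense feed-forward only every few blocks, and each expert processes at most `expert_capacity` tokens.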
TODOs: