
Add Switch transformers #19323

Merged: 112 commits merged into huggingface:main on Nov 15, 2022

Conversation

@younesbelkada (Contributor) commented on Oct 4, 2022

What does this PR do?

This PR adds Switch Transformers from t5x, in collaboration with @ArthurZucker & @thomwolf.

The architecture is very similar to T5 (the modeling code is copied from T5), with the feed-forward (FF) layer replaced by a sparse Mixture of Experts (MoE) layer, making this the first MoE architecture in the transformers library (see the sketch after the links below).

paper: https://arxiv.org/abs/2101.03961
weights: https://github.com/google-research/t5x/blob/eb42c2524bf65c8a46624f1a9b9e034d9bc65b14/docs/models.md#converted-mesh-tensorflow-checkpoints
original modeling code: https://github.com/google/flaxformer/tree/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe
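
For readers new to the design, here is a minimal, illustrative sketch of the top-1 (Switch) routing idea in PyTorch. It is not the merged implementation; all class and variable names below are made up for the example.

```python
# A toy sketch of a Switch (top-1) MoE feed-forward block. Names are illustrative,
# not the classes added in this PR.
import torch
import torch.nn as nn


class ToySwitchFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts, bias=False)  # produces router logits
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, d_model)
        router_logits = self.router(hidden_states)           # (batch, seq_len, num_experts)
        router_probs = torch.softmax(router_logits, dim=-1)
        expert_index = router_probs.argmax(dim=-1)            # top-1 routing decision per token
        output = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(self.experts):
            token_mask = expert_index == idx                   # tokens routed to this expert
            if token_mask.any():
                # scale by the router probability so the routing decision stays differentiable
                gate = router_probs[..., idx][token_mask].unsqueeze(-1)
                output[token_mask] = expert(hidden_states[token_mask]) * gate
        return output
```

In the real model, the sparse layer replaces the dense FF layer at a configurable interval of T5 blocks, and the paper adds an auxiliary load-balancing loss so that tokens are spread roughly evenly across experts.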

TODOs:

  • Make the forward pass run
  • Convert the weights to PyTorch format and upload them to the Hub
  • Match the logits between the original implementation and ours

@younesbelkada force-pushed the add_switch_transformers branch 3 times, most recently from 7e4ff1f to 1397231 on October 20, 2022 at 15:43
@younesbelkada (Contributor, Author) commented:

Thanks a lot @sgugger for your comments!
I would love another round of review, as I added some modifications for accelerate and bitsandbytes (bnb) compatibility 🙏
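
For context, this is roughly the usage those changes are meant to enable. This is a hedged sketch: it assumes a converted checkpoint such as `google/switch-base-8` is available on the Hub and that `accelerate` and `bitsandbytes` are installed.

```python
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

# assumption: a converted checkpoint like "google/switch-base-8" exists on the Hub
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained(
    "google/switch-base-8",
    device_map="auto",   # accelerate dispatches the weights across available devices
    load_in_8bit=True,   # bitsandbytes Linear8bitLt int8 weights
)

inputs = tokenizer("A <extra_id_0> walks into a bar.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```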

@sgugger (Collaborator) left a comment:

Looking great!

"""
router_probs, router_logits = self._compute_router_probabilities(hidden_states)

# Flax code for reference TODO check what happens with padded inputs here
A Collaborator commented:

Flagging this just in case!
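
To make the concern concrete, one common way to handle it (a sketch of the general idea, not necessarily what this PR does) is to zero out router probabilities at padded positions so padding tokens never count against expert capacity:

```python
import torch


def mask_padded_router_probs(router_probs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: zero the router probabilities of padded positions.

    router_probs:   (batch, seq_len, num_experts)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    return router_probs * attention_mask.unsqueeze(-1).to(router_probs.dtype)
```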

Comment on lines +1512 to +1516
_keys_to_ignore_on_load_missing = [
r"encoder.embed_tokens.weight",
r"decoder.embed_tokens.weight",
r"lm_head.weight",
]
A Collaborator commented:

Nit: can take less vertical space.
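
For illustration, the compact layout the nit is hinting at could look like this (a formatting sketch, not the merged code; whether black keeps it on one line depends on the configured line length):

```python
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight", r"decoder.embed_tokens.weight", r"lm_head.weight"]
```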

@younesbelkada (Contributor, Author) replied:

Managed to change it for the attributes above in 16e7ff5, but not for this one (not sure why the decorator above is not affected by the vertical formatting, though 🤔).

@LysandreJik (Member) left a comment:

Impressive contribution! Great to have so many pretrained checkpoints at release.

        }
        return dummy_inputs

    def _init_weights(self, module):
A Member commented:

Impressive method!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada (Contributor, Author) commented:

The failing tests seem to be unrelated to this PR, merging!

@younesbelkada merged commit 163ac3d into huggingface:main on Nov 15, 2022
mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
* first commit

* add more comments

* add router v1

* clean up

- remove `tf` modeling files

* clean up

- remove `tf` modeling files

* clean up

* v0 routers

* added more routers

- Implemented `ExpertsChooseMaskedRouter`
- added tests
- 2 more routers to implement

* last router

* improved docstring

- completed the docstring in `router.py`
- added more args in the config

* v0 sparse mlp

* replace wrong naming

* forward pass run

* update MOE layer

* small router update

* fixup

* consistency

* remove scatter router

* remove abstract layer

* update test and model for integration testing

* v1 conversion

* update

* hardcode hack

* all keys match

* add gin conversion, without additional libraries

* update conversion script

* delete router file

* update tests wrt router deletion

* fix router issues

* update expert code

* update, logits match, code needs refactoring

* Refactor code

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* add generate tests

Co-authored-by: younesbelkada <younesbelkada@gmail.com>

* add support for router loss

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix forward error

* refactor a bit

* remove `FlaxSwitchTransformers` modules

* more tests pass

* Update code

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fixup

* fix tests

* fix doc

* fix doc + tokenization

* fix tokenizer test

* fix test

* fix loss output

* update code for backward pass

* add loss support

* update documentation

* fix documentation, clean tokenizer

* more doc fix, cleanup example_switch

* fix failing test

* fix test

* fix test

* fix loss issue

* move layer

* update doc and fix router capacity usage

* fixup

* add sparse mlp index for documentation on hub

* fixup

* test sparse mix architecture

* Apply suggestions from code review

* Update docs/source/en/model_doc/switch_transformers.mdx

* fixup on update

* fix tests

* fix another test

* attempt fix

* Update src/transformers/models/switch_transformers/configuration_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* try

* all tests pass

* fix jitter noise

* Apply suggestions from code review

* doc tests pass

* Update src/transformers/models/switch_transformers/modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/switch_transformers/modeling_switch_transformers.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove assert

* change config order

* fix readme japanese

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove parallelizable tests + add one liners

* remove ONNX config

* fix nits

- add `T5Tokenizer` in auto mapping
- remove `Switch Transformers` from ONNX supported models

* remove `_get_router`

* remove asserts

* add check in test for `router_dtype`

* add `SwitchTransformersConfig` in `run_pipeline_test`

* Update tests/pipelines/test_pipelines_summarization.py

* add huge model conversion script

* fix slow tests

- add better casting for `Linear8bitLt`
- remove `torchscript` tests

* add make dir

* style on new script

* fix nits

- doctest
- remove `_keys_to_ignore_on_load_unexpected`

* Update src/transformers/models/switch_transformers/configuration_switch_transformers.py

* add google as authors

* fix year

* remove last `assert` statements

* standardize vertical spaces

* fix failing import

* fix another failing test

* Remove strange `authorized_keys`

* removing todo and padding that is never used

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: ybelkada <younes@huggingface.co>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Arthur Zucker <arthur@huggingface.co>