
distilbert-flax #13324

Merged
merged 9 commits into huggingface:master from kamalkraj:distilbert-flax on Aug 30, 2021

Conversation

kamalkraj
Contributor

What does this PR do?

DistilBert Flax

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@VictorSanh @patrickvonplaten
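
For reference, a minimal usage sketch of the Flax DistilBERT model this PR adds (assuming the distilbert-base-uncased checkpoint; pass from_pt=True to convert the PyTorch weights if no Flax weights are on the Hub yet):

from transformers import DistilBertTokenizerFast, FlaxDistilBertModel

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = FlaxDistilBertModel.from_pretrained("distilbert-base-uncased")

# Flax models take NumPy arrays rather than PyTorch tensors
inputs = tokenizer("Hello, Flax DistilBERT!", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)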

@patrickvonplaten
Contributor

Great job @kamalkraj!

I think the only major thing to update is the docs in the modeling file (at the moment it looks like it's the PyTorch docs, but it should be Flax :-))

@kamalkraj
Contributor Author

@patrickvonplaten
Thanks for the review.
I've made the changes according to your review.

@stefan-it
Collaborator

stefan-it commented Aug 30, 2021

Hi @kamalkraj , I'm also really interested in that PR - thanks for adding it 🤗

Do you also plan to add a script for the distillation process (like it is done in the "old" script)? I would like to re-distill some of my previous DistilBERT models, but I don't have access to multi-GPU setups at the moment, only TPUs.

@kamalkraj
Contributor Author

Hi @stefan-it,

I will go through the scripts and ping you.
I have multi-GPU access. Which TPU do you use? A v3-8?

@patrickvonplaten
Contributor

JAX_PLATFORM_NAME=cpu RUN_SLOW=1 pytest tests/test_modeling_flax_distilbert.py::FlaxDistilBertModelIntegrationTest::test_inference_no_head_absolute_embedding

passes and the code looks good :-) Ready to merge IMO 🎉 !

@patil-suraj the slow test doesn't pass on TPU since DistilBERT, like a couple of other models, has pretty extreme activations in the forward pass. We need to think a bit about how to adapt the slow tests in general depending on whether they're run on TPU or not...
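
(As a rough sketch of one possible approach, not necessarily what was settled on: the slow test could check the JAX backend at runtime and use a looser comparison tolerance on TPU, where the default bfloat16 matmuls make the extreme activations less precise. The helper name and tolerance values below are placeholders.)

import jax

# Hypothetical helper: loosen the allclose tolerance when the test runs on TPU.
def integration_test_tolerance(cpu_gpu_atol=1e-4, tpu_atol=1e-2):
    return tpu_atol if jax.default_backend() == "tpu" else cpu_gpu_atol

# e.g. inside the slow integration test:
# self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=integration_test_tolerance()))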

@patrickvonplaten
Contributor

Great work @kamalkraj!

@patrickvonplaten patrickvonplaten merged commit 774760e into huggingface:master Aug 30, 2021
@kamalkraj kamalkraj deleted the distilbert-flax branch September 11, 2021 10:30