
TF port of ESM #19587

Merged · 23 commits merged into main from esm_tf_port on Oct 17, 2022
Conversation

Rocketknight1 (Member) commented on Oct 13, 2022:

Working out the last few issues now! Models with <3B parameters have already been ported; larger models will need to wait for #19124.

This PR also includes fixes for a couple of issues in the original PyTorch ESM.

HuggingFaceDocBuilderDev commented on Oct 13, 2022:

The documentation is not available anymore as the PR was closed or merged.

Rocketknight1 (Member, Author) commented:

Pipeline tests are failing because the model has no SEP token and doesn't work with multiple sequences. Working on it!
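For context, the fix that eventually landed (see the commit list at the bottom of this thread, "Use EOS as SEP token because ESM vocab lacks SEP") was to reuse EOS where a SEP would normally go. A minimal sketch of that pattern, using illustrative token IDs rather than the real ESM vocabulary:

# Sketch: pairing two sequences when the vocab has CLS and EOS but no SEP.
# The token IDs below are placeholders, not the actual ESM vocabulary.
def build_inputs_with_special_tokens(token_ids_0, token_ids_1=None, cls_id=0, eos_id=2):
    if token_ids_1 is None:
        # Single sequence: <cls> seq <eos>
        return [cls_id] + token_ids_0 + [eos_id]
    # Pair: reuse <eos> as the separator: <cls> seq_0 <eos> seq_1 <eos>
    return [cls_id] + token_ids_0 + [eos_id] + token_ids_1 + [eos_id]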

Rocketknight1 (Member, Author) commented:

There's one final test remaining that's failing because of some arcane issue in the code that generates data batches for the pipeline. I'm trying to figure it out!
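The commit list below attributes this to the pipeline's batch/unbatch step not tolerating `None` values ("Fixing the batch/unbatcher of pipelines to accommodate the `None` being passed around"). A hypothetical sketch of the kind of guard involved, not the actual pipeline code:

# Illustrative unbatcher: split batched model outputs back into per-example
# dicts, passing None entries through instead of trying to index them.
def unbatch(outputs, batch_size):
    return [
        {key: (value[i] if value is not None else None) for key, value in outputs.items()}
        for i in range(batch_size)
    ]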

sgugger (Collaborator) left a comment:

Looks very clean, thanks a lot for porting this model to TensorFlow!

@@ -42,12 +42,14 @@

logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "facebook/esm-1b"
+_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D"
sgugger (Collaborator):

Will need an update :-)

Rocketknight1 (Member, Author):

Yep, all of these will be moved to facebook before the next release!

Comment on lines +861 to +864
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
sgugger (Collaborator):

Those long comments make review very hard in GitHub.

Rocketknight1 (Member, Author):

That one's copied from BERT!

sgugger (Collaborator):

Might be worth fixing on a followup PR then!
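For reference, the trick the inherited comment describes, comparing sequence ids so attention stays within each packed segment, looks roughly like this (illustrative only, not part of this PR):

import tensorflow as tf

# seq_ids marks which packed segment each position belongs to, e.g. two
# sequences of lengths 3 and 2 packed into a single row of length 5.
seq_ids = tf.constant([[0, 0, 0, 1, 1]])  # [batch, seq_len]
# mask[b, i, j] is True only when positions i and j share a segment id,
# giving the block-diagonal attention pattern the comment refers to.
mask = tf.math.equal(seq_ids[:, None, :], seq_ids[:, :, None])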

Rocketknight1 (Member, Author) commented:

Tests are green, and #19124 has been merged! Going to use it to upload the remaining checkpoints and then merge this.

gante (Member) left a comment:

🔥🔥🔥

(Now that I've reviewed this PR, does it mean I can get a job in the biotech industry? :P )

Comment on lines +83 to +88
# Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation
# and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at
# all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the
# original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that
# the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our
# models give different outputs from the original.
gante (Member):

If I got it right: we want to load inv_freq as a weight when it exists, because it was stored in float16. If we were to use the float32 values, we would get different outputs. Correct?

gante (Member):

Also - does XLA automatically create constant caches when appropriate? 😱

Rocketknight1 (Member, Author):

I believe it does! And if not, it can compute this during the 'downtime' of other small tasks once it's compiled - it's a really small tensor!

Rocketknight1 (Member, Author):

Also, you're correct about the float16/float32 issue. I was getting divergent outputs in my port at first because I recomputed the value rather than loading it from the checkpoint.
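In other words, the layer keeps inv_freq as a loadable, non-trainable weight so the checkpoint's fp16-derived values take precedence over a freshly computed float32 version. A rough sketch of that pattern, not the exact transformers code:

import tensorflow as tf

class RotaryEmbeddingSketch(tf.keras.layers.Layer):
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

    def build(self, input_shape):
        def recompute_inv_freq(shape, dtype=None):
            # Recomputed default in float32; loading a checkpoint will
            # overwrite it with the slightly different values saved in fp16.
            return 1.0 / (10000.0 ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))

        self.inv_freq = self.add_weight(
            name="inv_freq", shape=(self.dim // 2,), trainable=False,
            initializer=recompute_inv_freq,
        )
        super().build(input_shape)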

Three further review comments on src/transformers/models/esm/modeling_tf_esm.py were marked outdated and resolved.
Comment on lines 751 to 753
def set_input_embeddings(self, value: tf.Variable):
self.embeddings.weight = value
self.embeddings.vocab_size = shape_list(value)[0]
gante (Member):

Given that get_input_embeddings returns self.embeddings.word_embeddings, I'm assuming that this function should overwrite self.embeddings.word_embeddings and value is of type Embedding - right?

(like set_output_embeddings below)

Rocketknight1 (Member, Author):

Correct, good catch!
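For the record, the corrected setter is just the mirror image of the getter, assuming value is an Embedding-type layer like the one get_input_embeddings returns (a sketch of the fix, not the exact final code):

def set_input_embeddings(self, value):
    # Overwrite the word-embedding layer itself, mirroring
    # get_input_embeddings, which returns self.embeddings.word_embeddings.
    self.embeddings.word_embeddings = value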

@@ -0,0 +1,287 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
gante (Member):

Needs an update :D

Rocketknight1 (Member, Author):

Fixed!

Rocketknight1 and others added 3 commits on October 17, 2022 at 13:33, each co-authored by Joao Gante <joaofranciscocardosogante@gmail.com>.
Rocketknight1 merged commit 3b3024d into main on Oct 17, 2022.
Rocketknight1 deleted the esm_tf_port branch on October 17, 2022 at 13:16.
kashif pushed a commit to kashif/transformers that referenced this pull request on Oct 21, 2022:
* Partial TF port for ESM model

* Add ESM-TF tests

* Add the various imports for TF-ESM

* TF weight conversion almost ready

* Stop ignoring the decoder weights in PT

* Add tests and lots of fixes

* fix-copies

* Fix imports, add model docs

* Add get_vocab() to tokenizer

* Fix vocab links for pretrained files

* Allow multiple inputs with a sep

* Use EOS as SEP token because ESM vocab lacks SEP

* Correctly return special tokens mask from ESM tokenizer

* make fixup

* Stop testing unsupported embedding resizing

* Handle TF bias correctly

* Skip all models with slow tokenizers in the token classification test

* Fixing the batch/unbatcher of pipelines to accommodate the `None` being passed around

* Fixing pipeline bug caused by slow tokenizer being different

* Update src/transformers/models/esm/modeling_tf_esm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/esm/modeling_tf_esm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/esm/modeling_tf_esm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update set_input_embeddings and the copyright notices

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>