TF: correct TFBart embeddings weights name when load_weight_prefix is passed (huggingface#18993)
gante authored and oneraghavan committed Sep 26, 2022
1 parent cd98710 commit 87aa93d
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -16,6 +16,7 @@
 
 
 import random
+from contextlib import nullcontext
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -748,7 +749,15 @@ def call(
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
-            with tf.name_scope(self.embed_tokens.name + "/"):
+            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
+            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
+            # is used with a name ending in `/`, that name replaces the current name scope.
+            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
+            if hasattr(self.embed_tokens, "load_weight_prefix"):
+                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
+            else:
+                context_manager = nullcontext()
+            with context_manager:
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         embed_pos = self.embed_positions(input_shape)
@@ -936,7 +945,15 @@ def call(
         positions = self.embed_positions(input_shape, position_ids=position_ids)
 
         if inputs_embeds is None:
-            with tf.name_scope(self.embed_tokens.name + "/"):
+            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
+            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
+            # is used with a name ending in `/`, that name replaces the current name scope.
+            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
+            if hasattr(self.embed_tokens, "load_weight_prefix"):
+                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
+            else:
+                context_manager = nullcontext()
+            with context_manager:
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         hidden_states = inputs_embeds
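
Both hunks above rely on the trailing-slash behaviour of tf.name_scope that the new comment describes: a scope name ending in "/" is treated as absolute and replaces whatever scope is currently active, so the embedding weights are registered under the chosen prefix no matter where the encoder/decoder call happens. A minimal standalone sketch of that behaviour (not part of the commit; the prefix string and the exact printed name are illustrative and depend on the TF/Keras version):

import tensorflow as tf

# Illustrative only: a shared embedding layer whose weights should be registered
# under a fixed prefix, regardless of the name scope the caller is already inside.
shared = tf.keras.layers.Embedding(input_dim=8, output_dim=4, name="model.shared")

with tf.name_scope("some_outer_scope"):
    # The trailing "/" makes the scope absolute, replacing "some_outer_scope"
    # instead of nesting under it.
    with tf.name_scope("my_load_weight_prefix/"):
        shared(tf.constant([[1, 2, 3]]))

# Expected to start with "my_load_weight_prefix/" and contain no "some_outer_scope",
# e.g. something like "my_load_weight_prefix/model.shared/embeddings:0".
print(shared.weights[0].name)
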
@@ -1032,8 +1049,9 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
-        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
+        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
+        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+        self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
 
         self.encoder = TFBartEncoder(config, self.shared, name="encoder")
         self.decoder = TFBartDecoder(config, self.shared, name="decoder")
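
The main-layer hunk separates the Keras layer name, which is now always "model.shared", from the prefix used when naming its weights: the prefix travels as a plain Python attribute and is only consulted by the encoder/decoder when present. A standalone sketch of that pattern (the class and prefix names here are made up for illustration, not taken from the library):

from contextlib import nullcontext

import tensorflow as tf


class TinyMainLayer(tf.keras.layers.Layer):
    # Illustrative only: fixed Keras name, configurable weight-name prefix.
    def __init__(self, vocab_size=8, d_model=4, load_weight_prefix=None, **kwargs):
        super().__init__(**kwargs)
        self.shared = tf.keras.layers.Embedding(vocab_size, d_model, name="model.shared")
        # The desired checkpoint prefix rides along as an attribute on the layer.
        self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix

    def call(self, input_ids):
        # Consumers check for the attribute and scope the embedding call accordingly.
        if hasattr(self.shared, "load_weight_prefix"):
            context_manager = tf.name_scope(self.shared.load_weight_prefix + "/")
        else:
            context_manager = nullcontext()
        with context_manager:
            return self.shared(input_ids)


layer = TinyMainLayer(load_weight_prefix="custom_prefix")
layer(tf.constant([[1, 2, 3]]))
print(layer.shared.weights[0].name)  # expected to start with "custom_prefix/"
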
