
Commit

Remove question comments
ydshieh committed Aug 13, 2021
1 parent b6f5f4f commit a36d7c3
Showing 1 changed file with 0 additions and 27 deletions.
@@ -305,20 +305,12 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
**kwargs
):
# Q: unlike `modeling_flax_bart.FlaxBartPreTrainedModel.__init__`, we use a more general, nested input shape here
# (supporting vision models' `pixel_values` argument is left as a next step; see the sketch at the end of this hunk).
if input_shape is None:
input_shape = ((1, 1), (1, 1))

module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

# Copied from modeling_encoder_decoder.EncoderDecoderModel
# Q: should we keep this (modified) check, given that we don't have `get_output_embeddings` as in PyTorch models?
assert not hasattr(
self.module._get_encoder_module(), "lm_head"
), f"The encoder {self.module._get_encoder_module()} should not have an LM head. Please use a model without an LM head."

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape

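A minimal usage sketch of how the nested `input_shape` is meant to be consumed; the config pair, the `FlaxEncoderDecoderModel` name, and the vision-model shape are assumptions for illustration, not part of this diff:

from transformers import BertConfig, GPT2Config, EncoderDecoderConfig, FlaxEncoderDecoderModel

# Build a small config pair just for illustration; any supported encoder/decoder pair works.
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), GPT2Config())

# Text-to-text: both elements are (batch_size, sequence_length) dummies for input_ids.
# `init_weights` unpacks the pair as (encoder_input_shape, decoder_input_shape) above,
# so each sub-model can build its own dummy inputs during initialization.
model = FlaxEncoderDecoderModel(config, input_shape=((1, 1), (1, 1)))

# Hypothetical vision-to-text case (not implemented by this file yet): the first element
# would have to match `pixel_values` instead, e.g. ((1, 224, 224, 3), (1, 1)).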
@@ -332,7 +324,6 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
# Q: Is it ok to keep this sanity check?
assert (
decoder_batch_size == batch_size
), f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
@@ -389,15 +380,13 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
# Q: are we sure all accepted decoder modules have the `init_cache` argument?
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])

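For the `init_cache=True` question above, a hedged usage sketch of how this cache initialization is typically driven from the outside before autoregressive decoding; the method names come from the surrounding code, while the argument values are assumed:

# Encode once, then build an empty decoder cache sized for the generation length.
encoder_outputs = model.encode(input_ids, attention_mask=attention_mask)
past_key_values = model.init_cache(
    batch_size=input_ids.shape[0],
    max_length=32,  # maximum number of tokens to be generated
    encoder_outputs=encoder_outputs,
)
# Each subsequent decode step can then read and update `past_key_values`.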
@add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
# Q: how to deal with kwargs, like `token_type_ids`?
def encode(
self,
input_ids: jnp.ndarray,
@@ -637,15 +626,6 @@ def __call__(
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

# prepare decoder inputs
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
# Q: the config might not have these attributes for some decoders.
# Q: if the encoder is a vision model, we might need something like `jnp.ones((batch_size, 1))` instead (see the sketch below).
input_ids,
self.config.pad_token_id,
decoder_start_token_id=self.config.decoder_start_token_id,
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
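A rough sketch of what the removed `shift_tokens_right` call produces, plus the hypothetical vision-encoder fallback the removed comment alludes to; the fallback is an assumption, not something this commit implements, and it relies on this file's existing `jnp` import:

# Text case: put the decoder start token first and shift everything right by one, e.g.
#   input_ids         = [[5, 6, 7, 2]]
#   decoder_input_ids = [[decoder_start_token_id, 5, 6, 7]]
decoder_input_ids = shift_tokens_right(
    input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
)

# Hypothetical vision case: there are no text `input_ids` to shift, so one could start
# from a single decoder start token per example instead.
decoder_input_ids = jnp.full((batch_size, 1), self.config.decoder_start_token_id, dtype="i4")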
@@ -681,13 +661,6 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs
):

# Q: In `modeling_encoder_decoder.EncoderDecoderModel.prepare_inputs_for_generation()`, we delegate to the
# decoder model's `prepare_inputs_for_generation()`. In the Flax version, however, we only have the decoder
# module, not a decoder model. Is the implementation here OK?
# Original version in `modeling_encoder_decoder.EncoderDecoderModel`:
# decoder_inputs = self.decoder.prepare_inputs_for_generation(decoder_input_ids, ...)

# initializing the cache
batch_size, seq_length = decoder_input_ids.shape

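Since the rest of this function is collapsed in the diff, here is a hedged sketch of the usual Flax pattern the removed comment asks about: instead of delegating to a decoder model's `prepare_inputs_for_generation()`, the cache is created right here and the decoder attention mask is pre-extended to `max_length`. Variable names, the `lax` import, and the exact return keys are assumptions based on other Flax models, not on this file:

past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# A static mask of ones keeps the shape fixed for XLA; positions past the current
# length are masked by the decoder's causal mask anyway.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
    decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
    extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
    decoder_position_ids = jnp.broadcast_to(jnp.arange(seq_length)[None, :], (batch_size, seq_length))

return {
    "past_key_values": past_key_values,
    "encoder_outputs": encoder_outputs,
    "encoder_attention_mask": attention_mask,
    "decoder_attention_mask": extended_attention_mask,
    "decoder_position_ids": decoder_position_ids,
}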
