Skip to content

Commit

Permalink
Update serving code to enable saved_model=True (#18153)
Browse files Browse the repository at this point in the history
* Add serving_output and serving methods to some vision models

* Add serving outputs for DeiT

* Don't convert hidden states - differing shapes

* Make saveable

* Fix up

* Make swin saveable

* Add in tests

* Fix funnel tests (can't convert to tensor)

* Fix numpy call

* Tidy up a bit

* Add in hidden states - resnet

* Remove numpy

* Fix failing tests - tensor shape and skipping tests

* Remove duplicated function

* PR comments - formatting and var names

* PR comments
Add suggestions made by Joao Gante:
* Use tf.shape instead of shape_list
* Use @tooslow decorator on tests
* Simplify some of the logic

* PR comments
Address Yih-Dar Sheih's comments - making tensor names consistent and making types float

* Types consistent with docs; disable test on swin (slow)

* CI trigger

* Change input_features to float32

* Add serving_output for segformer

* Fixup

Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
  • Loading branch information
amyeroberts and amyeroberts committed Jul 22, 2022
1 parent 0750535 commit 8e83846
Show file tree
Hide file tree
Showing 30 changed files with 471 additions and 238 deletions.
15 changes: 14 additions & 1 deletion src/transformers/models/convnext/modeling_tf_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,8 @@ def serving(self, inputs):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)


CONVNEXT_START_DOCSTRING = r"""
Expand Down Expand Up @@ -492,6 +493,14 @@ def call(
hidden_states=outputs.hidden_states,
)

def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
    """Shape the outputs of `call` for the SavedModel serving signature.

    `hidden_states` is passed through untouched rather than stacked with
    `tf.convert_to_tensor`, because the per-stage tensors have differing
    dimensions and cannot be combined into a single tensor.
    """
    serving_kwargs = {
        "last_hidden_state": output.last_hidden_state,
        "pooler_output": output.pooler_output,
        "hidden_states": output.hidden_states,
    }
    return TFBaseModelOutputWithPooling(**serving_kwargs)


@add_start_docstrings(
"""
Expand Down Expand Up @@ -584,3 +593,7 @@ def call(
logits=logits,
hidden_states=outputs.hidden_states,
)

def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
    """Shape the outputs of `call` for the SavedModel serving signature.

    `hidden_states` is returned as-is (not stacked with `tf.convert_to_tensor`)
    because its entries have differing dimensions.
    """
    return TFSequenceClassifierOutput(
        logits=output.logits,
        hidden_states=output.hidden_states,
    )
27 changes: 25 additions & 2 deletions src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,8 +801,8 @@ def serving(self, inputs):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""

return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)


DATA2VEC_VISION_START_DOCSTRING = r"""
Expand Down Expand Up @@ -910,6 +910,17 @@ def call(

return outputs

def serving_output(self, output: TFData2VecVisionModelOutputWithPooling) -> TFData2VecVisionModelOutputWithPooling:
    """Shape the outputs of `call` for the SavedModel serving signature.

    The per-layer tuples are stacked into single tensors with
    `tf.convert_to_tensor` only when the matching config flag requested them;
    otherwise the field is emitted as None.
    """
    hs = None
    if self.config.output_hidden_states:
        hs = tf.convert_to_tensor(output.hidden_states)
    attns = None
    if self.config.output_attentions:
        attns = tf.convert_to_tensor(output.attentions)

    return TFData2VecVisionModelOutputWithPooling(
        last_hidden_state=output.last_hidden_state,
        pooler_output=output.pooler_output,
        hidden_states=hs,
        attentions=attns,
    )


@add_start_docstrings(
"""
Expand Down Expand Up @@ -983,6 +994,12 @@ def call(
attentions=outputs.attentions,
)

def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
    """Shape the outputs of `call` for the SavedModel serving signature.

    Per-layer tuples are stacked with `tf.convert_to_tensor` only when the
    matching config flag enabled them; otherwise the field is None.
    """
    hs = None
    if self.config.output_hidden_states:
        hs = tf.convert_to_tensor(output.hidden_states)
    attns = None
    if self.config.output_attentions:
        attns = tf.convert_to_tensor(output.attentions)

    return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


class TFData2VecVisionConvModule(tf.keras.layers.Layer):
"""
Expand Down Expand Up @@ -1443,3 +1460,9 @@ def reshape_features(x):
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)

def serving_output(self, output: TFSemanticSegmenterOutput) -> TFSemanticSegmenterOutput:
    """Shape the outputs of `call` for the SavedModel serving signature.

    Per-layer tuples are stacked with `tf.convert_to_tensor` only when the
    matching config flag enabled them; otherwise the field is None.
    """
    hs = None
    if self.config.output_hidden_states:
        hs = tf.convert_to_tensor(output.hidden_states)
    attns = None
    if self.config.output_attentions:
        attns = tf.convert_to_tensor(output.attentions)

    return TFSemanticSegmenterOutput(logits=output.logits, hidden_states=hs, attentions=attns)
36 changes: 31 additions & 5 deletions src/transformers/models/deit/modeling_tf_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
Expand Down Expand Up @@ -680,14 +680,14 @@ def call(
return outputs

def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
hidden_states=hidden_states,
attentions=attentions,
)


Expand Down Expand Up @@ -864,6 +864,12 @@ def call(
attentions=outputs.attentions,
)

def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
    """Shape the outputs of `call` for the SavedModel serving signature.

    Per-layer tuples are stacked with `tf.convert_to_tensor` only when the
    matching config flag enabled them; otherwise the field is None.
    """
    hs = None
    if self.config.output_hidden_states:
        hs = tf.convert_to_tensor(output.hidden_states)
    attns = None
    if self.config.output_attentions:
        attns = tf.convert_to_tensor(output.attentions)

    return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
Expand Down Expand Up @@ -961,6 +967,12 @@ def call(
attentions=outputs.attentions,
)

def serving_output(self, output: TFImageClassifierOutput) -> TFImageClassifierOutput:
    """Shape the outputs of `call` for the SavedModel serving signature.

    Per-layer tuples are stacked with `tf.convert_to_tensor` only when the
    matching config flag enabled them; otherwise the field is None.
    """
    hs = None
    if self.config.output_hidden_states:
        hs = tf.convert_to_tensor(output.hidden_states)
    attns = None
    if self.config.output_attentions:
        attns = tf.convert_to_tensor(output.attentions)

    return TFImageClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
Expand Down Expand Up @@ -1041,3 +1053,17 @@ def call(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def serving_output(
    self, output: TFDeiTForImageClassificationWithTeacherOutput
) -> TFDeiTForImageClassificationWithTeacherOutput:
    """Shape the outputs of `call` for the SavedModel serving signature.

    The three logits tensors (combined, class-token head, distillation head)
    are forwarded unchanged; per-layer tuples are stacked with
    `tf.convert_to_tensor` only when the matching config flag enabled them.
    """
    hs = None
    if self.config.output_hidden_states:
        hs = tf.convert_to_tensor(output.hidden_states)
    attns = None
    if self.config.output_attentions:
        attns = tf.convert_to_tensor(output.attentions)

    return TFDeiTForImageClassificationWithTeacherOutput(
        logits=output.logits,
        cls_logits=output.cls_logits,
        distillation_logits=output.distillation_logits,
        hidden_states=hs,
        attentions=attns,
    )
88 changes: 46 additions & 42 deletions src/transformers/models/funnel/modeling_tf_funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,12 +1127,14 @@ def call(
training=training,
)

# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state,
hidden_states=output.hidden_states,
attentions=output.attentions,
)


@add_start_docstrings(
Expand Down Expand Up @@ -1175,12 +1177,14 @@ def call(
training=training,
)

# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state,
hidden_states=output.hidden_states,
attentions=output.attentions,
)


@add_start_docstrings(
Expand Down Expand Up @@ -1249,10 +1253,11 @@ def call(
)

def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFFunnelForPreTrainingOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFFunnelForPreTrainingOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)


@add_start_docstrings("""Funnel Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
Expand Down Expand Up @@ -1322,12 +1327,10 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFMaskedLMOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions)


@add_start_docstrings(
Expand Down Expand Up @@ -1398,12 +1401,12 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFSequenceClassifierOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)


@add_start_docstrings(
Expand Down Expand Up @@ -1503,9 +1506,9 @@ def call(
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
Expand All @@ -1514,12 +1517,12 @@ def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:

return self.serving_output(output=output)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFMultipleChoiceModelOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)


@add_start_docstrings(
Expand Down Expand Up @@ -1592,12 +1595,12 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFTokenClassifierOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)


@add_start_docstrings(
Expand Down Expand Up @@ -1683,11 +1686,12 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
start_logits=output.start_logits,
end_logits=output.end_logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
)
36 changes: 23 additions & 13 deletions src/transformers/models/hubert/modeling_tf_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,13 @@ def _compute_mask_indices(
f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
num_masked_spans = max(num_masked_spans, min_masks)
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)

# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
num_masked_spans = tf.squeeze(num_masked_spans)

# SpecAugment mask to fill
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
Expand All @@ -256,7 +257,7 @@ def _compute_mask_indices(

# scatter indices to mask
spec_aug_mask = _scatter_values_on_batch_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
)

return spec_aug_mask
Expand Down Expand Up @@ -1319,7 +1320,15 @@ def __init__(self, config, *inputs, **kwargs):
"to train/fine-tine this model, you need a GPU or a TPU"
)

@tf.function
@tf.function(
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
output = self.call(input_values=inputs, training=False)

Expand Down Expand Up @@ -1511,10 +1520,11 @@ def call(
return outputs

def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state, hidden_states=hidden_states, attentions=attentions
)


@add_start_docstrings(
Expand Down Expand Up @@ -1685,6 +1695,6 @@ def call(
)

def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)

0 comments on commit 8e83846

Please sign in to comment.