Added Whisper model to exporters
* Added Whisper to exporters
* Removed redundant code
* Added IO binding for ORTModelForSpeechSeq2Seq
mht-sharma committed Nov 9, 2022
1 parent 9e6accf commit 2d0a7ca
Showing 11 changed files with 362 additions and 166 deletions.
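For context, a minimal usage sketch of the export path this commit enables. The export() signature, the opset handling, and the openai/whisper-tiny checkpoint are assumptions for illustration, not part of the diff:

# Sketch only: export a Whisper checkpoint via the new exporters config.
# Assumes the optimum.exporters.onnx export() helper of this era; the
# checkpoint name and opset handling are illustrative.
from pathlib import Path

from transformers import AutoModelForSpeechSeq2Seq

from optimum.exporters.onnx import export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
onnx_config = WhisperOnnxConfig(model.config, task="speech2seq-lm")

# export() builds dummy inputs from the config's DUMMY_INPUT_GENERATOR_CLASSES
# and writes the graph with the dynamic axes from onnx_config.inputs/outputs.
export(model, onnx_config, Path("whisper.onnx"), opset=onnx_config.DEFAULT_ONNX_OPSET)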
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
@@ -120,6 +120,7 @@ class OnnxConfig(ExportConfig, ABC):
"seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}

def __init__(
13 changes: 13 additions & 0 deletions optimum/exporters/onnx/config.py
@@ -17,6 +17,7 @@
from typing import Mapping

from ...utils import (
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
@@ -99,3 +100,15 @@ class VisionOnnxConfig(OnnxConfig):

class TextAndVisionOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator)


class AudioOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator,)


class TextAndAudioOnnxConfig(Seq2SeqOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyAudioInputGenerator,
DummyDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
)
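
As an illustration of the new audio base configs, here is the kind of input_features tensor DummyAudioInputGenerator is expected to produce for a Whisper-style model; the concrete sizes (80 mel bins by 3000 frames, Whisper's standard feature layout) are examples, not taken from the diff:

# Illustrative only: a dummy "input_features" tensor for a Whisper-style model.
# Axis roles mirror WhisperOnnxConfig.inputs further down; sizes are examples.
import torch

batch_size, feature_size, encoder_sequence_length = 2, 80, 3000
input_features = torch.rand(batch_size, feature_size, encoder_sequence_length)
assert input_features.shape == (2, 80, 3000)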
4 changes: 3 additions & 1 deletion optimum/exporters/onnx/convert.py
@@ -112,7 +112,8 @@ def export_pytorch(
model.to(device)
dummy_inputs = tree_map(lambda value: value.to(device), dummy_inputs)
check_dummy_inputs_are_allowed(model, dummy_inputs)
inputs = config.ordered_inputs(model)
# inputs = config.ordered_inputs(model)
inputs = config.inputs
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

@@ -125,6 +126,7 @@
else:
# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.

onnx_export(
model,
(dummy_inputs,),
131 changes: 130 additions & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -24,13 +24,22 @@
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
NormalizedAudioConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedVisionConfig,
)
from .base import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .config import DecoderOnnxConfig, EncoderOnnxConfig, Seq2SeqOnnxConfig, TextAndVisionOnnxConfig, VisionOnnxConfig
from .config import (
AudioOnnxConfig,
DecoderOnnxConfig,
EncoderOnnxConfig,
Seq2SeqOnnxConfig,
TextAndAudioOnnxConfig,
TextAndVisionOnnxConfig,
VisionOnnxConfig,
)


if TYPE_CHECKING:
@@ -624,3 +633,123 @@ def generate_dummy_inputs(self, framework: str = "pt"):
self.is_generating_dummy_inputs = True
dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
return dummy_inputs


class WhisperOnnxConfig(TextAndAudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = {
"input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
}
if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self.use_past:
self.add_past_key_values(common_inputs, direction="inputs")

return common_inputs


class Seq2SeqDecoderTextInputGenerator(DummyTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"input_ids",
"encoder_attention_mask",
"encoder_hidden_states",
)

def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = 2,
sequence_length: int = 16,
num_choices: int = 4,
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
random_num_choices_range: Optional[Tuple[int, int]] = None,
):
super().__init__(
task,
normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
num_choices=num_choices,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
random_num_choices_range=random_num_choices_range,
)

self.hidden_size = normalized_config.hidden_size

def generate(self, input_name: str, framework: str = "pt"):
if input_name != "encoder_hidden_states":
dummy_input = super().generate(input_name, framework=framework)
else:
shape = (self.batch_size, self.sequence_length, self.hidden_size)
dummy_input = (self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework), None, None)

return dummy_input


class SpeechSeq2SeqEncoderOnnxConfig(AudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedAudioConfig

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return {
"input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
}


class SpeechSeq2SeqDecoderOnnxConfig(Seq2SeqOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

DUMMY_INPUT_GENERATOR_CLASSES = (
Seq2SeqDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
)

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
from collections import OrderedDict

common_inputs = {
"input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"},
"encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
}

if self.use_past:
self.add_past_key_values(common_inputs, direction="inputs")

return common_inputs

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super(OnnxConfigWithPast, self).outputs
self.add_past_key_values(common_outputs, direction="outputs")
return common_outputs

def generate_dummy_inputs(self, framework: str = "pt"):
dummy_input = super().generate_dummy_inputs(framework)

common_inputs = {}
common_inputs["decoder_input_ids"] = dummy_input.pop("input_ids")
common_inputs["encoder_outputs"] = dummy_input.pop("encoder_hidden_states")

if "past_key_values" in dummy_input:
common_inputs["past_key_values"] = dummy_input.pop("past_key_values")

return common_inputs

@property
def values_override(self) -> Optional[Mapping[str, Any]]:
if hasattr(self._config, "use_cache"):
return {"use_cache": True}

return None
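
To make the use_past branch of WhisperOnnxConfig.inputs concrete, a sketch of the two dynamic-axis maps it yields; the past_key_values.* entry names follow the usual seq2seq convention of add_past_key_values and are an assumption here:

# Sketch of the axes WhisperOnnxConfig.inputs declares. Without past, the
# decoder sees the full prefix; with past, it consumes one new token per step,
# so only the batch axis of decoder_input_ids stays dynamic.
axes_without_past = {
    "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
    "decoder_input_ids": {0: "batch_size", 1: "decoder_sequence_length"},
}
axes_with_past = {
    "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
    "decoder_input_ids": {0: "batch_size"},
    # add_past_key_values(..., direction="inputs") then appends entries such as
    # "past_key_values.0.decoder.key" / ".value" and
    # "past_key_values.0.encoder.key" / ".value", each with its own dynamic axes.
}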
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
@@ -47,6 +47,7 @@
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
)
if is_tf_available():
@@ -122,6 +123,7 @@ class TasksManager:
"image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling,
"semantic-segmentation": AutoModelForSemanticSegmentation,
"speech2seq-lm": AutoModelForSpeechSeq2Seq,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
@@ -506,6 +508,13 @@ class TasksManager:
onnx="T5OnnxConfig",
),
"vit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="ViTOnnxConfig"),
"whisper": supported_tasks_mapping(
"default",
"default-with-past",
"speech2seq-lm",
"speech2seq-lm-with-past",
onnx="WhisperOnnxConfig",
),
"xlm": supported_tasks_mapping(
"default",
"masked-lm",
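A sketch of how the new registry entries are meant to be resolved; the TasksManager method names and the checkpoint are assumptions based on the API of this period:

# Sketch only: resolve "whisper" + "speech2seq-lm" through the entries added
# above. Method names/signatures are assumed from the TasksManager API of this
# era; "openai/whisper-tiny" is illustrative.
from optimum.exporters.tasks import TasksManager

# "speech2seq-lm" now maps to AutoModelForSpeechSeq2Seq...
model = TasksManager.get_model_from_task("speech2seq-lm", "openai/whisper-tiny")

# ...and model type "whisper" maps to WhisperOnnxConfig for that task.
config_ctor = TasksManager.get_exporter_config_constructor(
    model_type="whisper", exporter="onnx", task="speech2seq-lm"
)
onnx_config = config_ctor(model.config)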
75 changes: 4 additions & 71 deletions optimum/onnx/configuration.py
@@ -290,9 +290,9 @@ class DecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict(
[
("encoder_attention_mask", {0: "batch", 1: "encoder_sequence"}),
("input_ids", {0: "batch", 1: "past_decoder_sequence + sequence"}),
("encoder_hidden_states", {0: "batch", 1: "encoder_sequence"}),
("encoder_attention_mask", {0: "batch", 1: "encoder_sequence"}),
]
)
if self.use_past:
@@ -316,9 +316,9 @@ def generate_dummy_inputs(
)
batch, encoder_seq_length = dummy_input["input_ids"].shape
encoder_hidden_states_shape = (batch, encoder_seq_length, self._config.hidden_size)
common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
common_inputs["decoder_input_ids"] = dummy_input.pop("decoder_input_ids")
common_inputs["encoder_outputs"] = (torch.zeros(encoder_hidden_states_shape), None, None)
common_inputs["input_ids"] = dummy_input.pop("decoder_input_ids")
common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)
common_inputs["encoder_attention_mask"] = dummy_input.pop("attention_mask")

if "past_key_values" in dummy_input:
common_inputs["past_key_values"] = dummy_input.pop("past_key_values")
@@ -339,73 +339,6 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int
for i in range(num_decoder_layers * num_pkv_per_layer):
inputs_or_outputs[f"{name}_key_values_{i}"] = {0: "batch", 2: decoder_sequence}

def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
reference_model_inputs["encoder_attention_mask"] = reference_model_inputs.pop("attention_mask")
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]

return reference_model_inputs

@property
def values_override(self) -> Optional[Mapping[str, Any]]:
if hasattr(self._config, "use_cache"):
return {"use_cache": True}

return None


class SpeechSeq2SeqEncoderOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
]
)

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})

@property
def atol_for_validation(self) -> float:
return 1e-4


class SpeechSeq2SeqDecoderOnnxConfig(DecoderOnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "past_decoder_sequence + sequence"}),
("encoder_hidden_states", {0: "batch", 1: "encoder_sequence"}),
]
)
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")

return common_inputs

def generate_dummy_inputs(
self,
tokenizer: "PreTrainedTokenizerBase",
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
) -> Mapping[str, Any]:
common_inputs = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
common_inputs.pop("attention_mask")

return common_inputs

def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
return reference_model_inputs


class OnnxSeq2SeqConfigWithPastAndLoss(DecoderOnnxConfig):
def __init__(self, config: DecoderOnnxConfig):
30 changes: 9 additions & 21 deletions optimum/onnx/modeling_seq2seq.py
@@ -64,27 +64,15 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
):

if type(encoder_attention_mask) == torch.LongTensor:
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
return_dict=True,
use_cache=True,
)
else:
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
return_dict=True,
use_cache=True,
)

decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
return_dict=True,
use_cache=True,
)
last_hidden_state = decoder_outputs.last_hidden_state

if self.config.model_type == "t5" and self.config.tie_word_embeddings:
