Skip to content

Commit

Permalink
Add ORT support for Whisper (#420)
Browse files Browse the repository at this point in the history
* added support onnxruntime whisper

* Updated decoder export model

* Updated docstring

* updated tests for whisper

* add whisper onnx configs

* Added Whisper model to exporters

* added whisper to exporters
* Removed redundant code
* Added io binding for ORTModelForSpeechSeq2Seq

* Removed unused imports

* Added tests for exporters and iobinding

* Removed redundant line

* Updated input generator and config

* Updated tests

* added sample audio input

* Removed redundant code to fix test

* Updated iobinding

* Fix tests
  • Loading branch information
mht-sharma committed Nov 15, 2022
1 parent 71ed646 commit a29647e
Show file tree
Hide file tree
Showing 13 changed files with 751 additions and 54 deletions.
14 changes: 14 additions & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class OnnxConfig(ExportConfig, ABC):
"seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}

def __init__(
Expand Down Expand Up @@ -206,6 +207,17 @@ def is_torch_support_available(self) -> bool:
return TORCH_VERSION >= self.MIN_TORCH_VERSION
return False

@property
def torch_to_onnx_input_map(self) -> Mapping[str, str]:
    """
    Mapping used to rename dummy inputs to their exported ONNX input names.

    Subclasses should override this property whenever the dummy input names and the exported ONNX
    input names need to be different. This default implementation renames nothing.

    Returns:
        `Mapping[str, str]`: A dictionary mapping a dummy input name to the exported ONNX input name.
    """
    return {}

def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int, str]]:
"""
Re-orders the inputs using the model forward pass signature.
Expand All @@ -218,6 +230,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
`Mapping[str, Mapping[int, str]]`: The properly ordered inputs.
"""
inputs = self.inputs

ordered_inputs = {}
sig = inspect.signature(model.forward)
for param in sig.parameters:
Expand All @@ -229,6 +242,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
# TODO: figure out a smart way of re-ordering potential nested structures.
# to_insert = sorted(to_insert, key=lambda t: t[0])
for name, dynamic_axes in to_insert:
name = self.torch_to_onnx_input_map.get(name, name)
ordered_inputs[name] = dynamic_axes
return ordered_inputs

Expand Down
13 changes: 13 additions & 0 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Mapping

from ...utils import (
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
Expand Down Expand Up @@ -99,3 +100,15 @@ class VisionOnnxConfig(OnnxConfig):

class TextAndVisionOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator)


class AudioOnnxConfig(OnnxConfig):
    """ONNX config whose dummy inputs are produced by `DummyAudioInputGenerator` only."""

    # Generator classes used to build the dummy inputs for this config.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator,)


class TextAndAudioOnnxConfig(Seq2SeqOnnxConfig):
    """
    Seq2seq ONNX config whose dummy inputs combine an audio generator, a decoder text generator
    and a seq2seq past key values generator.
    """

    # Generator classes used to build the dummy inputs for this config.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyAudioInputGenerator,
        DummyDecoderTextInputGenerator,
        DummySeq2SeqPastKeyValuesGenerator,
    )
1 change: 1 addition & 0 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def export_pytorch(
else:
# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.

onnx_export(
model,
(dummy_inputs,),
Expand Down
81 changes: 80 additions & 1 deletion optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,26 @@
from ...utils import (
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
NormalizedConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedVisionConfig,
)
from .base import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .config import DecoderOnnxConfig, EncoderOnnxConfig, Seq2SeqOnnxConfig, TextAndVisionOnnxConfig, VisionOnnxConfig
from .config import (
AudioOnnxConfig,
DecoderOnnxConfig,
EncoderOnnxConfig,
Seq2SeqOnnxConfig,
TextAndAudioOnnxConfig,
TextAndVisionOnnxConfig,
VisionOnnxConfig,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -626,3 +636,72 @@ def generate_dummy_inputs(self, framework: str = "pt"):
self.is_generating_dummy_inputs = True
dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
return dummy_inputs


class WhisperOnnxConfig(TextAndAudioOnnxConfig):
    """ONNX export configuration for Whisper models."""

    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Names of the model inputs together with their dynamic axes.

        Returns:
            `Mapping[str, Mapping[int, str]]`: The input names mapped to `{axis index: axis name}`. When
            `use_past` is set, `decoder_input_ids` only has a dynamic batch axis and the past key values
            inputs are appended.
        """
        # With past key values, only the batch axis of the decoder input ids is declared dynamic.
        decoder_axes = {0: "batch_size"}
        if not self.use_past:
            decoder_axes[1] = "decoder_sequence_length"

        input_axes = {
            "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
            "decoder_input_ids": decoder_axes,
        }
        if self.use_past:
            self.add_past_key_values(input_axes, direction="inputs")

        return input_axes


class SpeechSeq2SeqEncoderOnnxConfig(AudioOnnxConfig):
    """Audio ONNX config exposing a single `input_features` input."""

    NORMALIZED_CONFIG_CLASS = NormalizedConfig

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Names of the model inputs together with their dynamic axes.

        Returns:
            `Mapping[str, Mapping[int, str]]`: `input_features` mapped to its `{axis index: axis name}` dictionary.
        """
        feature_axes = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
        return {"input_features": feature_axes}


class SpeechSeq2SeqDecoderOnnxConfig(Seq2SeqOnnxConfig):
    """
    Seq2seq ONNX config taking `decoder_input_ids` and `encoder_outputs` as inputs, which are renamed to
    `input_ids` and `encoder_hidden_states` in the exported model.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

    # Generator classes used to build the dummy inputs for this config.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummySeq2SeqDecoderTextInputGenerator,
        DummyDecoderTextInputGenerator,
        DummySeq2SeqPastKeyValuesGenerator,
    )

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Names of the model inputs together with their dynamic axes.

        Returns:
            `Mapping[str, Mapping[int, str]]`: The input names mapped to `{axis index: axis name}`. The past
            key values inputs are appended when `use_past` is set.
        """
        input_axes = {
            "decoder_input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"},
            "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"},
        }
        if not self.use_past:
            return input_axes

        self.add_past_key_values(input_axes, direction="inputs")
        return input_axes

    @property
    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
        """
        Mapping from the dummy input names to the exported ONNX input names.

        Returns:
            `Mapping[str, str]`: A dictionary mapping a dummy input name to the exported ONNX input name.
        """
        renamed_inputs = {
            "decoder_input_ids": "input_ids",
            "encoder_outputs": "encoder_hidden_states",
        }
        return renamed_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Names of the model outputs together with their dynamic axes.

        Returns:
            `Mapping[str, Mapping[int, str]]`: The parent class outputs, with the past key values outputs
            added regardless of `use_past`.
        """
        output_axes = super().outputs
        self.add_past_key_values(output_axes, direction="outputs")
        return output_axes

    @property
    def values_override(self) -> Optional[Mapping[str, Any]]:
        """
        Configuration values to override for the export.

        Returns:
            `Optional[Mapping[str, Any]]`: `{"use_cache": True}` when the underlying config has a `use_cache`
            attribute, `None` otherwise.
        """
        return {"use_cache": True} if hasattr(self._config, "use_cache") else None
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
)
if is_tf_available():
Expand Down Expand Up @@ -122,6 +123,7 @@ class TasksManager:
"image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling,
"semantic-segmentation": AutoModelForSemanticSegmentation,
"speech2seq-lm": AutoModelForSpeechSeq2Seq,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
Expand Down Expand Up @@ -506,6 +508,13 @@ class TasksManager:
onnx="T5OnnxConfig",
),
"vit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="ViTOnnxConfig"),
"whisper": supported_tasks_mapping(
"default",
"default-with-past",
"speech2seq-lm",
"speech2seq-lm-with-past",
onnx="WhisperOnnxConfig",
),
"xlm": supported_tasks_mapping(
"default",
"masked-lm",
Expand Down
4 changes: 2 additions & 2 deletions optimum/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"ORTModelForSequenceClassification",
"ORTModelForTokenClassification",
],
"modeling_seq2seq": ["ORTModelForSeq2SeqLM"],
"modeling_seq2seq": ["ORTModelForSeq2SeqLM", "ORTModelForSpeechSeq2Seq"],
"optimization": ["ORTOptimizer"],
"quantization": ["ORTQuantizer"],
"trainer": ["ORTTrainer"],
Expand Down Expand Up @@ -68,7 +68,7 @@
ORTModelForSequenceClassification,
ORTModelForTokenClassification,
)
from .modeling_seq2seq import ORTModelForSeq2SeqLM
from .modeling_seq2seq import ORTModelForSeq2SeqLM, ORTModelForSpeechSeq2Seq
from .optimization import ORTOptimizer
from .quantization import ORTQuantizer
from .trainer import ORTTrainer
Expand Down

0 comments on commit a29647e

Please sign in to comment.