Skip to content

Commit

Permalink
Updated iobinding
Browse files Browse the repository at this point in the history
  • Loading branch information
mht-sharma committed Nov 15, 2022
1 parent 512c781 commit 14358a0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
7 changes: 3 additions & 4 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def is_torch_support_available(self) -> bool:
def torch_to_onnx_input_map(self) -> Mapping[str, str]:
"""
Dictionary of keys to update the ONNX input name for export. Override the function when
the torch input names and the ONNX input names need to be different.
the dummy input names and the exported ONNX input names need to be different.
Returns:
`Mapping[str, str]`: A dictionary specifying the torch to ONNX input name map.
`Mapping[str, str]`: A dictionary specifying the dummy input name to exported ONNX input name map.
"""
return {}

Expand All @@ -230,7 +230,6 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
`Mapping[str, Mapping[int, str]]`: The properly ordered inputs.
"""
inputs = self.inputs
torch_to_onnx_input_map = self.torch_to_onnx_input_map

ordered_inputs = {}
sig = inspect.signature(model.forward)
Expand All @@ -243,7 +242,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
# TODO: figure out a smart way of re-ordering potential nested structures.
# to_insert = sorted(to_insert, key=lambda t: t[0])
for name, dynamic_axes in to_insert:
name = torch_to_onnx_input_map[name] if name in torch_to_onnx_input_map else name
name = self.torch_to_onnx_input_map.get(name, name)
ordered_inputs[name] = dynamic_axes
return ordered_inputs

Expand Down
24 changes: 17 additions & 7 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
`(batch_size, encoder_sequence_length)`. Mask values selected in `[0, 1]`.
"""

SPEECH_SEQ2SEQ_ENCODER_INPUTS_DOCSTRING = r"""
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
Arguments:
input_features (`torch.FloatTensor`):
Mel features extracted from the raw speech waveform. `(batch_size, feature_size, encoder_sequence_length)`.
Expand Down Expand Up @@ -761,9 +761,9 @@ def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


class ORTEncoderForSpeechSeq2Seq(ORTEncoder):
class ORTEncoderForWhisper(ORTEncoder):
"""
Encoder model for ONNX Runtime inference for SpeechSeq2Seq models.
Encoder model for ONNX Runtime inference for the Whisper model.
Arguments:
session (`onnxruntime.InferenceSession`):
Expand Down Expand Up @@ -804,7 +804,7 @@ def prepare_io_binding(

return io_binding, output_shapes, output_buffers

@add_start_docstrings_to_model_forward(SPEECH_SEQ2SEQ_ENCODER_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
def forward(
self,
input_features: torch.FloatTensor,
Expand Down Expand Up @@ -1273,6 +1273,10 @@ class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin
auto_model_class = AutoModelForSpeechSeq2Seq
main_input_name = "input_features"

_MODEL_TYPE_TO_ORTENCODER = {
"whisper": ORTEncoderForWhisper,
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand All @@ -1282,8 +1286,14 @@ def _initialize_encoder(
config: transformers.PretrainedConfig,
device: torch.device,
use_io_binding: bool = True,
) -> ORTEncoderForSpeechSeq2Seq:
return ORTEncoderForSpeechSeq2Seq(
) -> ORTEncoder:
if config.model_type not in self._MODEL_TYPE_TO_ORTENCODER:
raise KeyError(
f"{config.model_type} is not supported yet. "
f"Only {list(self._MODEL_TYPE_TO_ORTENCODER.keys())} are supported. "
f"If you want to support {config.model_type} please propose a PR or open up an issue."
)
return self._MODEL_TYPE_TO_ORTENCODER[config.model_type](
session=session,
config=config,
device=device,
Expand Down Expand Up @@ -1364,7 +1374,7 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
}

def get_encoder(self) -> ORTEncoderForSpeechSeq2Seq:
def get_encoder(self) -> ORTEncoder:
return self.encoder

# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
Expand Down

0 comments on commit 14358a0

Please sign in to comment.