New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added onnx config whisper #19525
Added onnx config whisper #19525
Changes from 8 commits
3c7c712
da52a43
f61ae96
79ad032
8bd740b
ba50fa3
f9557dd
85fdb68
b1a2e4a
59a5965
0d1904c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -104,6 +104,7 @@ class OnnxConfig(ABC): | |||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}), | ||||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), | ||||||
"vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), | ||||||
"speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), | ||||||
} | ||||||
|
||||||
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None): | ||||||
|
@@ -262,6 +263,19 @@ def _generate_dummy_images( | |||||
images.append(Image.fromarray(data.astype("uint8")).convert("RGB")) | ||||||
return images | ||||||
|
||||||
def _generate_dummy_audio( | ||||||
self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220 | ||||||
): | ||||||
audio_data = [] | ||||||
for _ in range(batch_size): | ||||||
# time variable | ||||||
t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False) | ||||||
|
||||||
# generate pure sine wave at `frequency` Hz | ||||||
audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t)) | ||||||
|
||||||
return audio_data | ||||||
|
||||||
def generate_dummy_inputs( | ||||||
self, | ||||||
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], | ||||||
|
@@ -273,6 +287,9 @@ def generate_dummy_inputs( | |||||
num_channels: int = 3, | ||||||
image_width: int = 40, | ||||||
image_height: int = 40, | ||||||
sampling_rate: int = 22050, | ||||||
time_duration: float = 5.0, | ||||||
frequency: int = 220, | ||||||
tokenizer: "PreTrainedTokenizerBase" = None, | ||||||
) -> Mapping[str, Any]: | ||||||
""" | ||||||
|
@@ -297,6 +314,12 @@ def generate_dummy_inputs( | |||||
The width of the generated images. | ||||||
image_height (`int`, *optional*, defaults to 40): | ||||||
The height of the generated images. | ||||||
sampling_rate (`int`, *optional*, defaults to 22050): | ||||||
The sampling rate for audio data generation. | ||||||
time_duration (`float`, *optional*, defaults to 5.0): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Let's be consistent with the signature! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||||||
Total seconds of sampling for audio data generation. | ||||||
frequency (`int`, *optional*, defaults to 220): | ||||||
The desired natural frequency of generated audio. | ||||||
|
||||||
Returns: | ||||||
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function | ||||||
|
@@ -325,7 +348,8 @@ def generate_dummy_inputs( | |||||
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add | ||||||
) | ||||||
# Generate dummy inputs according to compute batch and sequence | ||||||
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size | ||||||
input_token = preprocessor.unk_token if preprocessor.unk_token else "0" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
We don't rely on Python bool magic conversion in the library, so let's test explicitly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, Added explicit tests for None and empty string |
||||||
dummy_input = [" ".join([input_token]) * seq_length] * batch_size | ||||||
if self.task == "multiple-choice": | ||||||
# If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations | ||||||
# made by ONNX | ||||||
|
@@ -345,11 +369,31 @@ def generate_dummy_inputs( | |||||
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) | ||||||
dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) | ||||||
return dict(preprocessor(images=dummy_input, return_tensors=framework)) | ||||||
elif ( | ||||||
isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features" | ||||||
): | ||||||
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX | ||||||
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch) | ||||||
dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency) | ||||||
return dict(preprocessor(dummy_input, return_tensors=framework)) | ||||||
else: | ||||||
raise ValueError( | ||||||
"Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor." | ||||||
) | ||||||
|
||||||
def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I'm not mistaken, this function doesn't do anything - why do we need it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @lewtun , this is true that this function does not do anything for existing models. However, this function can be overridden in some cases where the model inputs and ONNX Runtime inputs are different. This was needed when exporting the decoder in encoder-decoder models using encoder_outputs. For example: Check the following Decoder config in optimum where I am using the function DecoderOnnxConfig. Since we are waiting for the optimum PR to merge and migrating the changes this is no longer needed in the current merge. But we still need to update the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, thanks for the clarification! Looking at the (I'm happy to include this function as is, just trying to understand if we really need it or not) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function (at least in this case) update the existing keys in the input dict, but the values remains the same. For exporting these models we would require 2 different input sets. This is because the I think creating a separate function is a much cleaner way. But I am open to any suggestions. We can add the logic in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK great, now I understand well why we need this - thanks. 
IMO it's fine to include this function in this PR if we add a small note in the docstring like "Override to run inference with seq2seq models which have the encoder and decoder exported as separate ONNX files." There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||||||
""" | ||||||
Generate inputs for onnxruntime using the reference model inputs. | ||||||
mht-sharma marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
Args: | ||||||
reference_model_inputs ([`Mapping[str, Tensor]`]): | ||||||
mht-sharma marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
Reference inputs for the model. | ||||||
|
||||||
Returns: | ||||||
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function | ||||||
mht-sharma marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
return reference_model_inputs | ||||||
|
||||||
def patch_ops(self): | ||||||
for spec in self._patching_specs: | ||||||
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's put the right module here :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated