Added onnx config whisper #19525

Merged
11 commits merged on Nov 1, 2022
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
@@ -99,6 +99,7 @@ Ready-made configurations include the following architectures:
- Table Transformer
- Vision Encoder decoder
- ViT
- Whisper
- XLM
- XLM-RoBERTa
- XLM-RoBERTa-XL
4 changes: 2 additions & 2 deletions src/transformers/models/whisper/__init__.py
@@ -21,7 +21,7 @@


_import_structure = {
"configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig"],
"configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig", "WhisperOnnxConfig"],
"feature_extraction_whisper": ["WhisperFeatureExtractor"],
"processing_whisper": ["WhisperProcessor"],
"tokenization_whisper": ["WhisperTokenizer"],
@@ -55,7 +55,7 @@
]

if TYPE_CHECKING:
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
from .feature_extraction_whisper import WhisperFeatureExtractor
from .processing_whisper import WhisperProcessor
from .tokenization_whisper import WhisperTokenizer
64 changes: 64 additions & 0 deletions src/transformers/models/whisper/configuration_whisper.py
@@ -14,10 +14,18 @@
# limitations under the License.
""" Whisper model configuration"""

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
from ...utils import logging


if TYPE_CHECKING:
from ... import PreTrainedTokenizerBase, TensorType

Collaborator: Let's put the right module here :-)

Contributor Author: Updated

from ...feature_extraction_utils import FeatureExtractionMixin

logger = logging.get_logger(__name__)

WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -214,3 +222,59 @@ def __init__(
begin_suppress_tokens=begin_suppress_tokens,
**kwargs,
)


class WhisperOnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict(
[
("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
]
)
if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch"}
else:
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}

if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")

return common_inputs

def generate_dummy_inputs(
self,
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
sampling_rate: int = 22050,
time_duration: float = 5.0,
frequency: int = 220,
) -> Mapping[str, Any]:
dummy_inputs = OrderedDict()
encoder_inputs = OnnxConfig.generate_dummy_inputs(
self,
preprocessor=preprocessor.feature_extractor,
batch_size=batch_size,
framework=framework,
sampling_rate=sampling_rate,
time_duration=time_duration,
frequency=frequency,
)
decoder_inputs = super().generate_dummy_inputs(
preprocessor.tokenizer, batch_size, seq_length, is_pair, framework
)

dummy_inputs["input_features"] = encoder_inputs.pop("input_features")
dummy_inputs["decoder_input_ids"] = decoder_inputs.pop("decoder_input_ids")

if "past_key_values" in decoder_inputs:
dummy_inputs["past_key_values"] = decoder_inputs.pop("past_key_values")

return dummy_inputs

@property
def atol_for_validation(self) -> float:
return 1e-3
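
For orientation, the new config can be exercised with the programmatic export API touched later in this diff. A minimal sketch, assuming the `openai/whisper-tiny.en` checkpoint used by the new test and the `export`/`validate_model_outputs` helpers from `transformers.onnx` (output path and exact keyword usage are illustrative, not part of this PR):

```python
from pathlib import Path

from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.models.whisper import WhisperOnnxConfig
from transformers.onnx import export, validate_model_outputs

ckpt = "openai/whisper-tiny.en"  # checkpoint also used by the new test below
processor = WhisperProcessor.from_pretrained(ckpt)  # bundles feature extractor + tokenizer
model = WhisperForConditionalGeneration.from_pretrained(ckpt)

# "speech2seq-lm" is one of the features this PR registers for Whisper
onnx_config = WhisperOnnxConfig(model.config, task="speech2seq-lm")

onnx_path = Path("whisper-tiny.onnx")
onnx_inputs, onnx_outputs = export(processor, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

# atol_for_validation is 1e-3, as defined in the config above
validate_model_outputs(onnx_config, processor, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)
```
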
46 changes: 45 additions & 1 deletion src/transformers/onnx/config.py
@@ -104,6 +104,7 @@ class OnnxConfig(ABC):
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}

def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
@@ -262,6 +263,19 @@ def _generate_dummy_images(
images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
return images

def _generate_dummy_audio(
self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220
):
audio_data = []
for _ in range(batch_size):
# time variable
t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)

# generate pure sine wave at `frequency` Hz
audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t))

return audio_data

def generate_dummy_inputs(
self,
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
@@ -273,6 +287,9 @@ def generate_dummy_inputs(
num_channels: int = 3,
image_width: int = 40,
image_height: int = 40,
sampling_rate: int = 22050,
time_duration: float = 5.0,
frequency: int = 220,
tokenizer: "PreTrainedTokenizerBase" = None,
) -> Mapping[str, Any]:
"""
@@ -297,6 +314,12 @@
The width of the generated images.
image_height (`int`, *optional*, defaults to 40):
The height of the generated images.
sampling_rate (`int`, *optional* defaults to 22050)
    The sampling rate for audio data generation.
time_duration (`int`, *optional* defaults to 5 sec)
    Total seconds of sampling for audio data generation.
frequency (`int`, *optional* defaults to 220)
    The desired natural frequency of generated audio.

Collaborator:

Suggested change:
-    time_duration (`int`, *optional* defaults to 5 sec)
+    time_duration (`float`, *optional* defaults to 5.0)

Let's be consistent with the signature!

Contributor Author: Done

Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
@@ -325,7 +348,8 @@ def generate_dummy_inputs(
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
)
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
input_token = preprocessor.unk_token if preprocessor.unk_token else "0"

Collaborator:

Suggested change:
-    input_token = preprocessor.unk_token if preprocessor.unk_token else "0"
+    input_token = preprocessor.unk_token if preprocessor.unk_token is not None else "0"

We don't rely on Python bool magic conversion in the library, so let's test explicitly.

Contributor Author: Done, added explicit tests for None and the empty string.

dummy_input = [" ".join([input_token]) * seq_length] * batch_size
if self.task == "multiple-choice":
# If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
# made by ONNX
@@ -345,11 +369,31 @@
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
return dict(preprocessor(images=dummy_input, return_tensors=framework))
elif (
isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features"
):
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency)
return dict(preprocessor(dummy_input, return_tensors=framework))
else:
raise ValueError(
"Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
)

def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:

Member: If I'm not mistaken, this function doesn't do anything - why do we need it?

Contributor Author: Hi @lewtun, it is true that this function does not do anything for the existing models. However, it can be overridden when the model inputs and the ONNX Runtime inputs differ. This was needed when exporting the decoder of encoder-decoder models using encoder_outputs - see, for example, the DecoderOnnxConfig in optimum, where I use this function. Since we are waiting for the optimum PR to merge before migrating those changes, it is no longer strictly needed for the current merge. But we still need to update the VisionEncoderDecoderConfig to use encoder_outputs, and in that case we would need such a function. Should we merge it here and use it in the new VisionEncoderDecoder PR (which would make that follow-up easier for whoever opens it), or remove it here and add it along with that PR?

Member: I see, thanks for the clarification! Looking at the optimum code, it seems like we use this function to add new fields to the ORT inputs - is there a reason we can't capture that logic in a single generate_dummy_inputs() function associated with the ONNX config? (I'm happy to include this function as is, just trying to understand whether we really need it.)

Contributor Author: The function (at least in this case) updates the existing keys in the input dict; the values remain the same. For exporting these models we need two different input sets, because model.forward() has a different input signature (it is the full model before export) than the generated ONNX model (only the decoder is exported). We therefore need a way to alter the existing model inputs so we can run inference with both models in validate_model_outputs. I think a separate function is a much cleaner way, but I am open to suggestions. We could add the logic to generate_dummy_inputs behind a new argument, say ort_inputs=True/False, and return different input sets for these models, but that would require updating all the OnnxConfigs that implement this function.

Member: OK great, now I understand well why we need this - thanks. IMO it's fine to include this function in this PR if we add a small note in the docstring like "Override to run inference with seq2seq models which have the encoder and decoder exported as separate ONNX files."

Contributor Author: Done

"""
Generate inputs for onnxruntime using the reference model inputs.
mht-sharma marked this conversation as resolved.
Show resolved Hide resolved

Args:
reference_model_inputs: ([`Mapping[str, Tensor]`):
mht-sharma marked this conversation as resolved.
Show resolved Hide resolved
Reference inputs for the model.

Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
mht-sharma marked this conversation as resolved.
Show resolved Hide resolved
"""
return reference_model_inputs

def patch_ops(self):
for spec in self._patching_specs:
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
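The review thread above explains that `generate_dummy_inputs_onnxruntime` is a hook for cases where the exported graph's inputs differ from those of `model.forward()`. A hypothetical override in the spirit of the optimum `DecoderOnnxConfig` mentioned there could look like this (the class and key names are illustrative assumptions, not code from this PR):

```python
from collections import OrderedDict
from typing import Any, Mapping

from transformers.onnx import OnnxSeq2SeqConfigWithPast


class HypotheticalDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        # The full model's forward() consumes `decoder_input_ids` and `encoder_outputs`,
        # while the exported decoder-only graph expects `input_ids` and
        # `encoder_hidden_states`, so the keys are renamed before the ONNX Runtime
        # session is run in validate_model_outputs.
        ort_inputs = OrderedDict(reference_model_inputs)
        ort_inputs["input_ids"] = ort_inputs.pop("decoder_input_ids")
        ort_inputs["encoder_hidden_states"] = ort_inputs.pop("encoder_outputs")[0]
        return ort_inputs
```
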
21 changes: 19 additions & 2 deletions src/transformers/onnx/convert.py
@@ -145,7 +145,21 @@ def export_pytorch(
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
model.to(device)
model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
model_inputs_device = dict()
for k, v in model_inputs.items():
if isinstance(v, Tuple):
model_inputs_device[k] = tuple(
x.to(device) if isinstance(x, torch.Tensor) else None for x in v
)
elif isinstance(v, List):
model_inputs_device[k] = [
tuple(x.to(device) if isinstance(x, torch.Tensor) else None for x in t) for t in v
]
else:
model_inputs_device[k] = v.to(device)

model_inputs = model_inputs_device

inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())

@@ -404,9 +418,12 @@
else:
ref_outputs_dict[name] = value

# Create onnxruntime inputs from the reference model inputs
reference_model_inputs_onnxruntime = config.generate_dummy_inputs_onnxruntime(reference_model_inputs)

# We flatten potential collection of inputs (i.e. past_keys)
onnx_inputs = {}
for name, value in reference_model_inputs.items():
for name, value in reference_model_inputs_onnxruntime.items():
if isinstance(value, (list, tuple)):
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
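The new device-placement loop in `export_pytorch` above is needed because, with `use_past`, the dummy inputs include `past_key_values` as a list of per-layer tensor tuples rather than flat tensors. A rough recursive equivalent of that inline handling, offered as a sketch rather than the code this PR adds:

```python
import torch


def move_to_device(value, device):
    # Recursively move tensors nested inside tuples/lists (e.g. past_key_values,
    # a list with one tuple of torch.Tensor per decoder layer) onto `device`.
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, tuple):
        return tuple(move_to_device(v, device) for v in value)
    if isinstance(value, list):
        return [move_to_device(v, device) for v in value]
    return value


# model_inputs = {k: move_to_device(v, device) for k, v in model_inputs.items()}
```
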
9 changes: 9 additions & 0 deletions src/transformers/onnx/features.py
@@ -29,6 +29,7 @@
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
)
@@ -100,6 +101,7 @@ class FeaturesManager:
"masked-im": AutoModelForMaskedImageModeling,
"semantic-segmentation": AutoModelForSemanticSegmentation,
"vision2seq-lm": AutoModelForVision2Seq,
"speech2seq-lm": AutoModelForSpeechSeq2Seq,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
@@ -489,6 +491,13 @@ class FeaturesManager:
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
),
"whisper": supported_features_mapping(
"default",
"default-with-past",
"speech2seq-lm",
"speech2seq-lm-with-past",
onnx_config_cls="models.whisper.WhisperOnnxConfig",
),
"xlm": supported_features_mapping(
"default",
"masked-lm",
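With the new entry in the features table, the features manager can resolve a Whisper checkpoint to `WhisperOnnxConfig`. A small sketch of how that mapping is typically consumed, following the pattern of the `transformers.onnx` export entry point (the printed values are what the config defined earlier in this diff should yield; treat the exact helper names as assumptions):

```python
from transformers import WhisperForConditionalGeneration
from transformers.onnx.features import FeaturesManager

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# Verify "speech2seq-lm" is registered for this model type and get the config constructor.
model_kind, onnx_config_ctor = FeaturesManager.check_supported_model_or_raise(model, feature="speech2seq-lm")
onnx_config = onnx_config_ctor(model.config)

print(model_kind)                        # "whisper"
print(type(onnx_config).__name__)        # "WhisperOnnxConfig"
print(list(onnx_config.inputs.keys()))   # ["input_features", "decoder_input_ids"]
```
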
3 changes: 2 additions & 1 deletion tests/onnx/test_onnx_v2.py
@@ -217,6 +217,7 @@ def test_values_override(self):
("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
("swin", "microsoft/swin-tiny-patch4-window7-224"),
("whisper", "openai/whisper-tiny.en"),
}

PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
@@ -397,7 +398,7 @@ def _onnx_export_encoder_decoder_models(
preprocessor = AutoTokenizer.from_pretrained(model_name)

with NamedTemporaryFile("w") as decoder_output:
onnx_inputs, onnx_outputs = export(
_, onnx_outputs = export(
preprocessor,
decoder_model,
decoder_onnx_config,