Add ort export in exporters for encoder-decoder models #497

Merged
7 changes: 6 additions & 1 deletion optimum/exporters/onnx/__init__.py
@@ -15,4 +15,9 @@

from .base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast # noqa
from .config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig # noqa
from .convert import export, validate_model_outputs # noqa
from .convert import (  # noqa
    export,
    export_encoder_decoder_model,
    validate_encoder_decoder_model_outputs,
    validate_model_outputs,
)
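
With this change, the split-export helpers are re-exported at the package level, so downstream code can import them directly:

from optimum.exporters.onnx import (
    export_encoder_decoder_model,
    validate_encoder_decoder_model_outputs,
)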
49 changes: 39 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -22,7 +22,12 @@
from ...utils import logging
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
from .convert import export, validate_model_outputs
from .convert import (
    export,
    export_encoder_decoder_model,
    validate_encoder_decoder_model_outputs,
    validate_model_outputs,
)


logger = logging.get_logger() # pylint: disable=invalid-name
@@ -64,6 +69,14 @@ def main():
        ),
    )
    parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
    parser.add_argument(
        "--for-ort",
        action="store_true",
        help=(
            "This exports models ready to be run with optimum.onnxruntime. Useful for encoder-decoder models for "
            "conditional generation. If enabled, the encoder and decoder of the model are exported separately."
        ),
    )
    parser.add_argument("output", type=Path, help="Path indicating the directory where to store generated ONNX model.")

    # Retrieve CLI arguments
@@ -115,12 +128,17 @@ def main():
f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
)

onnx_inputs, onnx_outputs = export(
model,
onnx_config,
args.opset,
args.output,
)
if model.config.is_encoder_decoder and args.for_ort:
onnx_inputs, onnx_outputs = export_encoder_decoder_model(
model,
onnx_config,
args.opset,
args.output.parent.joinpath("encoder_model.onnx"),
args.output.parent.joinpath("decoder_model.onnx"),
args.output.parent.joinpath("decoder_with_past_model.onnx"),
)
else:
onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output)

# Saving the model config as this is needed sometimes.
model.config.save_pretrained(args.output.parent)
@@ -144,11 +162,22 @@ def main():
        args.atol = args.atol[task.replace("-with-past", "")]

    try:
        validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
        if model.config.is_encoder_decoder and args.for_ort:
            validate_encoder_decoder_model_outputs(
                onnx_config,
                model,
                onnx_outputs,
                args.atol,
                args.output.parent.joinpath("encoder_model.onnx"),
                args.output.parent.joinpath("decoder_model.onnx"),
                args.output.parent.joinpath("decoder_with_past_model.onnx"),
            )
        else:
            validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
    except ValueError:
        logger.error(f"An error occurred, but the model was saved at: {args.output.as_posix()}")
        logger.error(f"An error occurred, but the model was saved at: {args.output.parent.as_posix()}")
        return
    logger.info(f"All good, model saved at: {args.output.as_posix()}")
    logger.info(f"All good, model saved at: {args.output.parent.as_posix()}")


if __name__ == "__main__":
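For reference, the new flag is exercised on the command line with something like `python -m optimum.exporters.onnx --model t5-small --for-ort t5_onnx/model.onnx`. The sketch below mirrors what the `--for-ort` branch of `main()` does programmatically; the checkpoint name, the `T5OnnxConfig` import path, and its constructor arguments are assumptions for illustration, not something this PR prescribes.

# Minimal sketch of the --for-ort code path, assuming a T5 checkpoint and that
# T5OnnxConfig lives in optimum.exporters.onnx.model_configs (an assumption).
from pathlib import Path

from transformers import AutoModelForSeq2SeqLM

from optimum.exporters.onnx import export_encoder_decoder_model
from optimum.exporters.onnx.model_configs import T5OnnxConfig

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
onnx_config = T5OnnxConfig(model.config, task="seq2seq-lm", use_past=True)

out = Path("t5_onnx")
out.mkdir(exist_ok=True)
onnx_inputs, onnx_outputs = export_encoder_decoder_model(
    model,
    onnx_config,
    onnx_config.DEFAULT_ONNX_OPSET,
    out / "encoder_model.onnx",
    out / "decoder_model.onnx",
    out / "decoder_with_past_model.onnx",
)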
52 changes: 52 additions & 0 deletions optimum/exporters/onnx/base.py
@@ -303,6 +303,18 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) ->
"""
return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))}

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
models which have the encoder and decoder exported as separate ONNX files.
Args:
reference_model_inputs ([`Mapping[str, Tensor]`):
Reference inputs for the model.
Returns:
`Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
"""
return reference_model_inputs


class OnnxConfigWithPast(OnnxConfig, ABC):
    PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True
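
The `generate_dummy_inputs_for_validation` hook added above exists so that a seq2seq config can remap inputs for a standalone encoder or decoder graph. A hypothetical override might look like the following; the class name and the input names are illustrative, not defined in this PR:

# Hypothetical decoder-side config: remaps reference PyTorch inputs to the
# names the separately exported decoder ONNX graph expects.
from typing import Any, Mapping

from optimum.exporters.onnx import OnnxSeq2SeqConfigWithPast


class MyDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def generate_dummy_inputs_for_validation(
        self, reference_model_inputs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        # The standalone decoder takes decoder token ids as plain `input_ids`
        # and consumes the encoder's last hidden state as an explicit input.
        reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
        reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
        return reference_model_inputs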
@@ -454,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
        flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
        flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
        flattened_output[f"{name}.{idx}.encoder.value"] = t[3]

    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
        """
        Returns the ONNX encoder config for `Seq2Seq` models. Implement this method to export the encoder of
        the model separately.

        Args:
            config (`PretrainedConfig`):
                The encoder model's configuration to use when exporting to ONNX.

        Returns:
            `OnnxConfig`: An instance of the ONNX configuration object.
        """
        raise NotImplementedError(
            f"{config.model_type} encoder export is not supported yet. "
            f"If you want to support {config.model_type} please propose a PR or open up an issue."
        )

(A reviewer suggested marking these `get_*_onnx_config` methods as `@abstractmethod`.)
    def get_decoder_onnx_config(
        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        """
        Returns the ONNX decoder config for `Seq2Seq` models. Implement this method to export the decoder of
        the model separately.

        Args:
            config (`PretrainedConfig`):
                The decoder model's configuration to use when exporting to ONNX.
            task (`str`, defaults to `"default"`):
                The task the model should be exported for.
            use_past (`bool`, defaults to `False`):
                Whether to export the model with past_key_values.

        Returns:
            `OnnxConfig`: An instance of the ONNX configuration object.
        """
        raise NotImplementedError(
            f"{config.model_type} decoder export is not supported yet. "
            f"If you want to support {config.model_type} please propose a PR or open up an issue."
        )
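
Taken together, a model-specific config would override both hooks to return encoder- and decoder-only configurations. A sketch, where `MyModelEncoderOnnxConfig` and `MyModelDecoderOnnxConfig` are hypothetical companion classes rather than anything defined in this PR:

# Hypothetical wiring of the two hooks for a concrete seq2seq model.
from optimum.exporters.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast


class MyModelOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
        # The encoder runs once per input and never uses past key values.
        return MyModelEncoderOnnxConfig(config, task="default")

    def get_decoder_onnx_config(
        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        # The decoder is exported twice: once without and once with past_key_values.
        return MyModelDecoderOnnxConfig(config, task=task, use_past=use_past)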
119 changes: 115 additions & 4 deletions optimum/exporters/onnx/convert.py
@@ -17,14 +17,14 @@
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
from transformers.utils import is_tf_available, is_torch_available

from ...utils import logging
from .base import OnnxConfig
from .utils import MIN_TORCH_VERSION, is_torch_onnx_support_available
from .utils import MIN_TORCH_VERSION, get_encoder_decoder_models_for_export, is_torch_onnx_support_available


if is_torch_available():
@@ -61,6 +61,58 @@ def check_dummy_inputs_are_allowed(
        )


def validate_encoder_decoder_model_outputs(
    config: OnnxConfig,
    reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
    onnx_named_outputs: List[str],
    atol: float,
    encoder_onnx_model: Path,
    decoder_onnx_model: Path,
    decoder_with_past_onnx_model: Optional[Path] = None,
):
    """
    Validates the export by checking that the outputs from both the reference and the exported model match.
    This method validates the ONNX models exported using the `export_encoder_decoder_model` method.

    Args:
        config ([`~OnnxConfig`]):
            The configuration used to export the model.
        reference_model ([`~PreTrainedModel`] or [`~TFPreTrainedModel`]):
            The model used for the export.
        onnx_named_outputs (`List[str]`):
            The names of the outputs to check.
        atol (`float`):
            The absolute tolerance in terms of outputs difference between the reference and the exported model.
        encoder_onnx_model (`Path`):
            The path to the exported encoder ONNX model.
        decoder_onnx_model (`Path`):
            The path to the exported decoder ONNX model.
        decoder_with_past_onnx_model (`Optional[Path]`, defaults to `None`):
            The path to the exported decoder with past ONNX model. Required when `past_key_values` are exported.

    Raises:
        ValueError: If the output shapes or values do not match between the reference and the exported model.
    """
    models_for_validation = get_encoder_decoder_models_for_export(reference_model, config)

    if len(onnx_named_outputs) != len(models_for_validation.keys()):
        raise ValueError(
            f"Invalid number of ONNX named outputs. Required {len(models_for_validation.keys())}, Provided {len(onnx_named_outputs)}"
        )

    # Validate encoder
    model, onnx_config = models_for_validation["encoder"]
    validate_model_outputs(onnx_config, model, encoder_onnx_model, onnx_named_outputs[0], atol)

    # Validate decoder
    model, onnx_config = models_for_validation["decoder"]
    validate_model_outputs(onnx_config, model, decoder_onnx_model, onnx_named_outputs[1], atol)

    if config.use_past:
        # Validate decoder with past
        model, onnx_config = models_for_validation["decoder_with_past"]
        validate_model_outputs(onnx_config, model, decoder_with_past_onnx_model, onnx_named_outputs[2], atol)

def validate_model_outputs(
    config: OnnxConfig,
    reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
@@ -115,9 +167,12 @@ def validate_model_outputs(
        else:
            ref_outputs_dict[name] = value

    # Create onnxruntime inputs from the reference model inputs
    reference_model_inputs_for_validation = config.generate_dummy_inputs_for_validation(reference_model_inputs)

    # We flatten potential collection of inputs (i.e. past_keys)
    onnx_inputs = {}
    for name, value in reference_model_inputs.items():
    for name, value in reference_model_inputs_for_validation.items():
        if isinstance(value, (list, tuple)):
            value = config.flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
@@ -223,7 +278,9 @@ def export_pytorch(
    device = torch.device(device)
    if device.type == "cuda" and torch.cuda.is_available():
        model.to(device)
        dummy_inputs = tree_map(lambda value: value.to(device), dummy_inputs)
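        # tree_map visits every leaf of the dummy inputs; with separate
        # encoder/decoder exports some leaves are not tensors, so the new
        # lambda guards the .to(device) call with an isinstance check.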
        dummy_inputs = tree_map(
            lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
        )
    check_dummy_inputs_are_allowed(model, dummy_inputs)
    inputs = config.ordered_inputs(model)
    input_names = list(inputs.keys())
@@ -321,6 +378,60 @@ def export_tensorflow(
    return input_names, output_names


def export_encoder_decoder_model(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    config: OnnxConfig,
    opset: int,
    encoder_output: Path,
    decoder_output: Path,
    decoder_with_past_output: Optional[Path] = None,
    device: str = "cpu",
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Exports a PyTorch or TensorFlow encoder-decoder model to an ONNX Intermediate Representation.
    This method exports the encoder and decoder components of the model as separate
    ONNX files.

    Args:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model to export.
        config ([`~exporters.onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        encoder_output (`Path`):
            The path to store the exported encoder ONNX model.
        decoder_output (`Path`):
            The path to store the exported decoder ONNX model.
        decoder_with_past_output (`Optional[Path]`, defaults to `None`):
            The path to store the exported decoder with past ONNX model. Required when `past_key_values` are exported.
        device (`str`, *optional*, defaults to `cpu`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.

    Returns:
        `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of each submodel's inputs, and the
        named outputs from the ONNX configuration, one entry per exported submodel.
    """
    models_for_export = get_encoder_decoder_models_for_export(model, config)
    outputs = []

    # export encoder
    model, onnx_config = models_for_export["encoder"]
    outputs.append(export(model, onnx_config, opset, encoder_output, device=device))

    # export decoder
    model, onnx_config = models_for_export["decoder"]
    outputs.append(export(model, onnx_config, opset, decoder_output, device=device))

    if config.use_past:
        # export decoder with past
        model, onnx_config = models_for_export["decoder_with_past"]
        outputs.append(export(model, onnx_config, opset, decoder_with_past_output, device=device))
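    # zip(*outputs) below transposes the per-submodel (input_names, output_names)
    # pairs into ([enc_in, dec_in, ...], [enc_out, dec_out, ...]).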

    outputs = list(map(list, zip(*outputs)))
    return outputs


def export(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    config: OnnxConfig,