Add ORT export for encoder-decoder models in exporters (#497)
* Add ORT export in exporters for encoder-decoder models

* Updated error docstring

* Update encoder decoder config location

* Added tests

* Update arguments help

* Updated docstring and removed redundant code

* Updated config location

* Added methods for encoder/decoder ONNX export and validation

* Added Seq2Seq-lm encoder-decoder configs

* Uncommented tests

* Fixed test

* Updated argument help

* Removed use-past and updated docstring

* Updated input generator Seq2SeqDecoderConfig

* Update docstrings to use Optional

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Added optional import

* Remove redundant task from the export function

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
mht-sharma and michaelbenayoun committed Nov 30, 2022
1 parent 31eb67c commit 6ee424b
Showing 8 changed files with 505 additions and 73 deletions.
7 changes: 6 additions & 1 deletion optimum/exporters/onnx/__init__.py
@@ -15,4 +15,9 @@

from .base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast # noqa
from .config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig # noqa
from .convert import export, validate_model_outputs # noqa
from .convert import ( # noqa
    export,
    export_encoder_decoder_model,
    validate_encoder_decoder_model_outputs,
    validate_model_outputs,
)
49 changes: 39 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -22,7 +22,12 @@
from ...utils import logging
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
from .convert import export, validate_model_outputs
from .convert import (
    export,
    export_encoder_decoder_model,
    validate_encoder_decoder_model_outputs,
    validate_model_outputs,
)


logger = logging.get_logger() # pylint: disable=invalid-name
@@ -64,6 +69,14 @@ def main():
        ),
    )
    parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
    parser.add_argument(
        "--for-ort",
        action="store_true",
        help=(
            "This exports models ready to be run with optimum.onnxruntime. Useful for encoder-decoder models for "
            "conditional generation. If enabled, the encoder and decoder of the model are exported separately."
        ),
    )
    parser.add_argument("output", type=Path, help="Path indicating the directory where to store generated ONNX model.")

    # Retrieve CLI arguments
@@ -115,12 +128,17 @@ def main():
f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
)

    onnx_inputs, onnx_outputs = export(
        model,
        onnx_config,
        args.opset,
        args.output,
    )
    if model.config.is_encoder_decoder and args.for_ort:
        onnx_inputs, onnx_outputs = export_encoder_decoder_model(
            model,
            onnx_config,
            args.opset,
            args.output.parent.joinpath("encoder_model.onnx"),
            args.output.parent.joinpath("decoder_model.onnx"),
            args.output.parent.joinpath("decoder_with_past_model.onnx"),
        )
    else:
        onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output)

    # Saving the model config as this is needed sometimes.
    model.config.save_pretrained(args.output.parent)
@@ -144,11 +162,22 @@ def main():
        args.atol = args.atol[task.replace("-with-past", "")]

    try:
        validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
        if model.config.is_encoder_decoder and args.for_ort:
            validate_encoder_decoder_model_outputs(
                onnx_config,
                model,
                onnx_outputs,
                args.atol,
                args.output.parent.joinpath("encoder_model.onnx"),
                args.output.parent.joinpath("decoder_model.onnx"),
                args.output.parent.joinpath("decoder_with_past_model.onnx"),
            )
        else:
            validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
    except ValueError:
        logger.error(f"An error occurred, but the model was saved at: {args.output.as_posix()}")
        logger.error(f"An error occurred, but the model was saved at: {args.output.parent.as_posix()}")
        return
    logger.info(f"All good, model saved at: {args.output.as_posix()}")
    logger.info(f"All good, model saved at: {args.output.parent.as_posix()}")


if __name__ == "__main__":
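For reference, a minimal sketch of the programmatic flow that `--for-ort` enables, assuming `model` is an already-loaded encoder-decoder `PreTrainedModel` and `onnx_config` its ONNX config (the opset, tolerance, and paths below are illustrative values, not part of this commit):

from pathlib import Path

from optimum.exporters.onnx import (
    export_encoder_decoder_model,
    validate_encoder_decoder_model_outputs,
)

output = Path("onnx/model.onnx")
encoder_path = output.parent.joinpath("encoder_model.onnx")
decoder_path = output.parent.joinpath("decoder_model.onnx")
decoder_with_past_path = output.parent.joinpath("decoder_with_past_model.onnx")

# Export the encoder, the decoder, and (when onnx_config.use_past) the
# decoder-with-past as three separate ONNX files.
onnx_inputs, onnx_outputs = export_encoder_decoder_model(
    model, onnx_config, 13, encoder_path, decoder_path, decoder_with_past_path
)

# Check each exported part against the reference model's outputs.
validate_encoder_decoder_model_outputs(
    onnx_config, model, onnx_outputs, 1e-4, encoder_path, decoder_path, decoder_with_past_path
)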
52 changes: 52 additions & 0 deletions optimum/exporters/onnx/base.py
@@ -303,6 +303,18 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) ->
"""
return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))}

    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Generates inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
        models which have the encoder and decoder exported as separate ONNX files.

        Args:
            reference_model_inputs (`Mapping[str, Tensor]`):
                Reference inputs for the model.

        Returns:
            `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function.
        """
        return reference_model_inputs
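For illustration, a decoder-specific config could override this hook to remap dummy input names to the ones the separately exported decoder graph expects; the subclass and the remapping below are hypothetical, not part of this commit:

class MySeq2SeqDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def generate_dummy_inputs_for_validation(self, reference_model_inputs):
        # Hypothetical remapping: the reference PyTorch decoder takes
        # `decoder_input_ids`, while a standalone decoder ONNX graph may
        # name the same tensor `input_ids`.
        reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
        return reference_model_inputs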


class OnnxConfigWithPast(OnnxConfig, ABC):
    PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True
@@ -454,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]

    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
        """
        Returns the ONNX encoder config for `Seq2Seq` models. Implement this method to export the encoder
        of the model separately.

        Args:
            config (`PretrainedConfig`):
                The encoder model's configuration to use when exporting to ONNX.

        Returns:
            `OnnxConfig`: An instance of the ONNX configuration object.
        """
        raise NotImplementedError(
            f"{config.model_type} encoder export is not supported yet. "
            f"If you want to support {config.model_type} please propose a PR or open up an issue."
        )

    def get_decoder_onnx_config(
        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        """
        Returns the ONNX decoder config for `Seq2Seq` models. Implement this method to export the decoder
        of the model separately.

        Args:
            config (`PretrainedConfig`):
                The decoder model's configuration to use when exporting to ONNX.
            task (`str`, defaults to `"default"`):
                The task the model should be exported for.
            use_past (`bool`, defaults to `False`):
                Whether to export the model with past_key_values.

        Returns:
            `OnnxConfig`: An instance of the ONNX configuration object.
        """
        raise NotImplementedError(
            f"{config.model_type} decoder export is not supported yet. "
            f"If you want to support {config.model_type} please propose a PR or open up an issue."
        )
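For illustration, a model-specific seq2seq ONNX config would implement both hooks along these lines; `MyEncoderOnnxConfig` and `MyDecoderOnnxConfig` are hypothetical stand-ins for the model's real encoder/decoder configs, not part of this commit:

class MyModelOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
        # Describe only the encoder subgraph of the model.
        return MyEncoderOnnxConfig(config, task="default")

    def get_decoder_onnx_config(
        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        # Describe only the decoder subgraph, optionally with past_key_values.
        return MyDecoderOnnxConfig(config, task=task, use_past=use_past)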
119 changes: 115 additions & 4 deletions optimum/exporters/onnx/convert.py
@@ -17,14 +17,14 @@
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
from transformers.utils import is_tf_available, is_torch_available

from ...utils import logging
from .base import OnnxConfig
from .utils import MIN_TORCH_VERSION, is_torch_onnx_support_available
from .utils import MIN_TORCH_VERSION, get_encoder_decoder_models_for_export, is_torch_onnx_support_available


if is_torch_available():
@@ -61,6 +61,58 @@ def check_dummy_inputs_are_allowed(
        )


def validate_encoder_decoder_model_outputs(
    config: OnnxConfig,
    reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
    onnx_named_outputs: List[str],
    atol: float,
    encoder_onnx_model: Path,
    decoder_onnx_model: Path,
    decoder_with_past_onnx_model: Optional[Path] = None,
):
    """
    Validates the export by checking that the outputs from both the reference and the exported model match.
    This method validates the ONNX models exported using the [`export_encoder_decoder_model`] method.

    Args:
        config ([`~OnnxConfig`]):
            The configuration used to export the model.
        reference_model ([`~PreTrainedModel`] or [`~TFPreTrainedModel`]):
            The model used for the export.
        onnx_named_outputs (`List[str]`):
            The names of the outputs to check.
        atol (`float`):
            The absolute tolerance in terms of outputs difference between the reference and the exported model.
        encoder_onnx_model (`Path`):
            The path to the exported encoder ONNX model.
        decoder_onnx_model (`Path`):
            The path to the exported decoder ONNX model.
        decoder_with_past_onnx_model (`Optional[Path]`, defaults to `None`):
            The path to the exported decoder with past ONNX model. Required when `past_key_values` are exported.

    Raises:
        ValueError: If the output shapes or values do not match between the reference and the exported model.
    """
    models_for_validation = get_encoder_decoder_models_for_export(reference_model, config)

    if len(onnx_named_outputs) != len(models_for_validation.keys()):
        raise ValueError(
            f"Invalid number of ONNX named outputs. Required {len(models_for_validation.keys())}, provided {len(onnx_named_outputs)}"
        )

    # Validate the encoder
    model, onnx_config = models_for_validation["encoder"]
    validate_model_outputs(onnx_config, model, encoder_onnx_model, onnx_named_outputs[0], atol)

    # Validate the decoder
    model, onnx_config = models_for_validation["decoder"]
    validate_model_outputs(onnx_config, model, decoder_onnx_model, onnx_named_outputs[1], atol)

    if config.use_past:
        # Validate the decoder with past key/values
        model, onnx_config = models_for_validation["decoder_with_past"]
        validate_model_outputs(onnx_config, model, decoder_with_past_onnx_model, onnx_named_outputs[2], atol)


def validate_model_outputs(
    config: OnnxConfig,
    reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
@@ -115,9 +167,12 @@ def validate_model_outputs(
            else:
                ref_outputs_dict[name] = value

    # Create onnxruntime inputs from the reference model inputs
    reference_model_inputs_for_validation = config.generate_dummy_inputs_for_validation(reference_model_inputs)

    # We flatten potential collection of inputs (i.e. past_keys)
    onnx_inputs = {}
    for name, value in reference_model_inputs.items():
    for name, value in reference_model_inputs_for_validation.items():
        if isinstance(value, (list, tuple)):
            value = config.flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
@@ -223,7 +278,9 @@ def export_pytorch(
    device = torch.device(device)
    if device.type == "cuda" and torch.cuda.is_available():
        model.to(device)
        dummy_inputs = tree_map(lambda value: value.to(device), dummy_inputs)
        # Only move torch.Tensor leaves to the device; non-tensor values are left untouched.
        dummy_inputs = tree_map(
            lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
        )
    check_dummy_inputs_are_allowed(model, dummy_inputs)
    inputs = config.ordered_inputs(model)
    input_names = list(inputs.keys())
@@ -321,6 +378,60 @@
    return input_names, output_names


def export_encoder_decoder_model(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    config: OnnxConfig,
    opset: int,
    encoder_output: Path,
    decoder_output: Path,
    decoder_with_past_output: Optional[Path] = None,
    device: str = "cpu",
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Exports a PyTorch or TensorFlow encoder-decoder model to an ONNX Intermediate Representation.
    The encoder and decoder components of the model are exported as separate ONNX files.

    Args:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model to export.
        config ([`~exporters.onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        encoder_output (`Path`):
            The path where to store the exported encoder ONNX model.
        decoder_output (`Path`):
            The path where to store the exported decoder ONNX model.
        decoder_with_past_output (`Optional[Path]`, defaults to `None`):
            The path where to store the exported decoder with past ONNX model. Required when `past_key_values` are exported.
        device (`str`, *optional*, defaults to `cpu`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.

    Returns:
        `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named
        outputs from the ONNX configuration.
    """
    models_for_export = get_encoder_decoder_models_for_export(model, config)
    outputs = []

    # Export the encoder
    model, onnx_config = models_for_export["encoder"]
    outputs.append(export(model, onnx_config, opset, encoder_output, device=device))

    # Export the decoder
    model, onnx_config = models_for_export["decoder"]
    outputs.append(export(model, onnx_config, opset, decoder_output, device=device))

    if config.use_past:
        # Export the decoder with past key/values
        model, onnx_config = models_for_export["decoder_with_past"]
        outputs.append(export(model, onnx_config, opset, decoder_with_past_output, device=device))

    # Transpose the per-part (inputs, outputs) pairs into ([all inputs], [all outputs]),
    # e.g. [(enc_in, enc_out), (dec_in, dec_out)] -> [[enc_in, dec_in], [enc_out, dec_out]].
    outputs = list(map(list, zip(*outputs)))
    return outputs


def export(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    config: OnnxConfig,
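Downstream, the point of `--for-ort` is that the exported files can be consumed by `optimum.onnxruntime`. A hedged usage sketch, assuming the output directory holds the encoder/decoder ONNX files and `config.json` produced above, and that `ORTModelForSeq2SeqLM` accepts this layout; the model name and directory are placeholders:

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForSeq2SeqLM

# "t5-small" and "onnx_dir" are illustrative only; onnx_dir is assumed to contain
# encoder_model.onnx, decoder_model.onnx and decoder_with_past_model.onnx.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = ORTModelForSeq2SeqLM.from_pretrained("onnx_dir")

inputs = tokenizer("translate English to French: Hello", return_tensors="pt")
generated_ids = model.generate(**inputs)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))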
