Add ort export in exporters for encoder-decoder models #497

Merged
Changes from 12 commits
7 changes: 6 additions & 1 deletion optimum/exporters/onnx/__init__.py
@@ -15,4 +15,9 @@

from .base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast # noqa
from .config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig # noqa
from .convert import export, validate_model_outputs # noqa
from .convert import ( # noqa
export,
export_encoder_decoder_model,
validate_encoder_decoder_model_outputs,
validate_model_outputs,
)
54 changes: 44 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -22,7 +22,12 @@
from ...utils import logging
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
from .convert import export, validate_model_outputs
from .convert import (
export,
export_encoder_decoder_model,
validate_encoder_decoder_model_outputs,
validate_model_outputs,
)


logger = logging.get_logger() # pylint: disable=invalid-name
@@ -64,6 +69,14 @@ def main():
),
)
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
parser.add_argument(
"--for-ort",
action="store_true",
help=(
"This exports models ready to be run with optimum.onnxruntime ORTModelXXX. Useful for encoder-decoder models for"
"conditional generation. If enabled the encoder and decoder of the model are exported separately."
),
)
Contributor Author:
Could the name for-ort be misleading, since the models can have other tasks apart from conditional generation? I mentioned "Useful for encoder-decoder models for conditional generation" in the help text, but I am not sure whether that is enough.

Perhaps update ORTModelXXX -> ORTModelForConditionalGeneration?

Member:
I would just say "to run with optimum.onnxruntime". I think for-ort is good enough, or at least I do not have a better name in mind.

parser.add_argument("output", type=Path, help="Path indicating the directory where to store generated ONNX model.")

# Retrieve CLI arguments
Expand Down Expand Up @@ -115,12 +128,20 @@ def main():
f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
)

onnx_inputs, onnx_outputs = export(
model,
onnx_config,
args.opset,
args.output,
)
use_past = True if "-with-past" in task else False
Member:
I think you do not need that, since it is already in the ONNX config, no?

Contributor Author:
Yes, I was probably thinking of how ORTModelForConditionalGeneration is currently implemented, which always creates the ONNX config without past. But that can be modified later.

Updated to use use_past from the ONNX config.
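A minimal sketch of that suggestion (assuming the config object is an OnnxConfigWithPast, which is already imported in this file, and that it exposes a use_past attribute; both are assumptions here, not something shown in this diff):

# Hypothetical alternative to parsing the task string:
use_past = isinstance(onnx_config, OnnxConfigWithPast) and onnx_config.use_past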

if model.config.is_encoder_decoder and args.for_ort:
onnx_inputs, onnx_outputs = export_encoder_decoder_model(
model,
onnx_config,
args.opset,
task,
use_past,
args.output.parent.joinpath("encoder_model.onnx"),
args.output.parent.joinpath("decoder_model.onnx"),
args.output.parent.joinpath("decoder_with_past_model.onnx"),
)
else:
onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output)

# Saving the model config as this is needed sometimes.
model.config.save_pretrained(args.output.parent)
@@ -144,11 +165,24 @@
args.atol = args.atol[task.replace("-with-past", "")]

try:
validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
if model.config.is_encoder_decoder and args.for_ort:
validate_encoder_decoder_model_outputs(
onnx_config,
model,
onnx_outputs,
args.atol,
task,
use_past,
args.output.parent.joinpath("encoder_model.onnx"),
args.output.parent.joinpath("decoder_model.onnx"),
args.output.parent.joinpath("decoder_with_past_model.onnx"),
)
else:
validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
except ValueError:
logger.error(f"An error occured, but the model was saved at: {args.output.as_posix()}")
logger.error(f"An error occured, but the model was saved at: {args.output.parent.as_posix()}")
return
logger.info(f"All good, model saved at: {args.output.as_posix()}")
logger.info(f"All good, model saved at: {args.output.parent.as_posix()}")


if __name__ == "__main__":
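For context, a small sketch of where the exported files land when --for-ort is passed, given the joinpath calls above (the directory name is hypothetical):

from pathlib import Path

output = Path("t5_onnx/model.onnx")  # hypothetical value of the CLI "output" argument
encoder_path = output.parent.joinpath("encoder_model.onnx")                      # t5_onnx/encoder_model.onnx
decoder_path = output.parent.joinpath("decoder_model.onnx")                      # t5_onnx/decoder_model.onnx
decoder_with_past_path = output.parent.joinpath("decoder_with_past_model.onnx")  # only used when the task ends in -with-past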
52 changes: 52 additions & 0 deletions optimum/exporters/onnx/base.py
@@ -303,6 +303,18 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) ->
"""
return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))}

def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
Contributor Author:
Discussion with lewtun regarding the use of this function: huggingface/transformers#19525 (comment)

Member:
Is it needed to generate the inputs for the separate encoder and decoder models?
Is it only used for validation?

Contributor Author:
It is needed only for validation with ONNX Runtime, since the ONNX model and the torch model have different input signatures when encoder_outputs is used to export the model.

Member:
What about calling it generate_dummy_inputs_for_validation?
I think it can be misleading otherwise (more so now that we use a for-ort argument that does not mean the same thing).

Contributor Author:
Agree. Updated!

"""
Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
models which have the encoder and decoder exported as separate ONNX files.
Args:
reference_model_inputs (`Mapping[str, Tensor]`):
Reference inputs for the model.
Returns:
`Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function.
"""
return reference_model_inputs
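To illustrate the discussion above, a hypothetical override for a decoder ONNX config whose PyTorch model takes input_ids and encoder_outputs while the exported decoder graph expects decoder_input_ids and encoder_hidden_states (the input names here are illustrative, not taken from this PR):

def generate_dummy_inputs_onnxruntime(self, reference_model_inputs):
    # Remap the reference (PyTorch) input names to the names expected by the exported ONNX decoder.
    reference_model_inputs["decoder_input_ids"] = reference_model_inputs.pop("input_ids")
    reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
    return reference_model_inputs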


class OnnxConfigWithPast(OnnxConfig, ABC):
PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True
@@ -454,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]

def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
"""
Returns the ONNX encoder config for `Seq2Seq` models. Implement this method to export the encoder
of the model separately.

Args:
config (`PretrainedConfig`):
The encoder model's configuration to use when exporting to ONNX.

Returns:
`OnnxConfig`: An instance of the ONNX configuration object.
"""
raise NotImplementedError(
f"{config.model_type} encoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)

Member:
Suggested change:
@abstractmethod

def get_decoder_onnx_config(
self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
) -> OnnxConfig:
"""
Returns the ONNX decoder config for `Seq2Seq` models. Implement this method to export the decoder
of the model separately.

Args:
config (`PretrainedConfig`):
The decoder model's configuration to use when exporting to ONNX.
task (`str`, defaults to `"default"`):
The task the model should be exported for.
use_past (`bool`, defaults to `False`):
Whether to export the model with past_key_values.

Returns:
`OnnxConfig`: An instance of the ONNX configuration object.
"""
raise NotImplementedError(
f"{config.model_type} decoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)
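As an illustration of how these two hooks are meant to be used, a sketch of a model-specific Seq2Seq config implementing them (the class names and constructor signatures below are hypothetical, not taken from this PR):

class MySeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
        # Config describing only the encoder sub-model.
        return MyEncoderOnnxConfig(config, task="default")

    def get_decoder_onnx_config(
        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        # Config describing only the decoder sub-model, optionally with past key/values.
        return MyDecoderOnnxConfig(config, task=task, use_past=use_past)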
133 changes: 130 additions & 3 deletions optimum/exporters/onnx/convert.py
@@ -24,7 +24,7 @@

from ...utils import logging
from .base import OnnxConfig
from .utils import MIN_TORCH_VERSION, is_torch_onnx_support_available
from .utils import MIN_TORCH_VERSION, get_encoder_decoder_models_for_export, is_torch_onnx_support_available


if is_torch_available():
@@ -61,6 +61,66 @@ def check_dummy_inputs_are_allowed(
)


def validate_encoder_decoder_model_outputs(
config: OnnxConfig,
reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
onnx_named_outputs: List[str],
atol: float,
task: str,
use_past: bool,
encoder_onnx_model: Path,
decoder_onnx_model: Path,
decoder_with_past_onnx_model: Path = None,
):
"""
Validates the export by checking that the outputs from both the reference and the exported model match.
This function validates the ONNX models exported with the `export_encoder_decoder_model` function.

Args:
config ([`~OnnxConfig`]):
The configuration used to export the model.
reference_model ([`~PreTrainedModel`] or [`~TFPreTrainedModel`]):
The model used for the export.
onnx_named_outputs (`List[str]`):
The names of the outputs to check.
atol (`float`):
The absolute tolerance in terms of outputs difference between the reference and the exported model.
task (`str`):
The type of task to export the model with.
use_past (`bool`):
Whether to export the model with past_key_values.
encoder_onnx_model (`Path`):
The path to the exported encoder ONNX model.
decoder_onnx_model (`Path`):
The path to the exported decoder ONNX model.
decoder_with_past_onnx_model (`Path`, *optional*, defaults to `None`):
The path to the exported decoder with past ONNX model. Required when `use_past` is True.
Raises:
ValueError: If the outputs shapes or values do not match between the reference and the exported model.
"""
task = task.replace("-with-past", "")

models_for_validation = get_encoder_decoder_models_for_export(reference_model, config, task, use_past)

if len(onnx_named_outputs) != len(models_for_validation.keys()):
raise ValueError(
f"Invalid number of ONNX named outputs. Required {len(models_for_validation.keys())}, Provided {len(onnx_named_outputs)}"
)

# Validate encoder
model, onnx_config = models_for_validation["encoder"]
validate_model_outputs(onnx_config, model, encoder_onnx_model, onnx_named_outputs[0], atol)

# Validate decoder
model, onnx_config = models_for_validation["decoder"]
validate_model_outputs(onnx_config, model, decoder_onnx_model, onnx_named_outputs[1], atol)

if use_past:
# Validate decoder with past
model, onnx_config = models_for_validation["decoder_with_past"]
validate_model_outputs(onnx_config, model, decoder_with_past_onnx_model, onnx_named_outputs[2], atol)


def validate_model_outputs(
config: OnnxConfig,
reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
@@ -115,9 +175,12 @@ def validate_model_outputs(
else:
ref_outputs_dict[name] = value

# Create onnxruntime inputs from the reference model inputs
reference_model_inputs_onnxruntime = config.generate_dummy_inputs_onnxruntime(reference_model_inputs)

# We flatten potential collection of inputs (i.e. past_keys)
onnx_inputs = {}
for name, value in reference_model_inputs.items():
for name, value in reference_model_inputs_onnxruntime.items():
if isinstance(value, (list, tuple)):
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
@@ -223,7 +286,9 @@ def export_pytorch(
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
model.to(device)
dummy_inputs = tree_map(lambda value: value.to(device), dummy_inputs)
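# Only tensor values are moved to the target device; non-tensor leaves are mapped to None.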
dummy_inputs = tree_map(
lambda value: value.to(device) if isinstance(value, torch.Tensor) else None, dummy_inputs
mht-sharma marked this conversation as resolved.
Show resolved Hide resolved
)
check_dummy_inputs_are_allowed(model, dummy_inputs)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
@@ -321,6 +386,68 @@ def export_tensorflow(
return input_names, output_names


def export_encoder_decoder_model(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
config: OnnxConfig,
opset: int,
task: str,
use_past: bool,
encoder_output: Path,
decoder_output: Path,
decoder_with_past_output: Path = None,
mht-sharma marked this conversation as resolved.
Show resolved Hide resolved
device: str = "cpu",
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Exports a PyTorch or TensorFlow encoder-decoder model to an ONNX Intermediate Representation.
This function exports the encoder and decoder components of the model as separate
ONNX files.

Args:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
config ([`~exporters.onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
task (`str`):
The type of task to export the model with.
use_past (`bool`):
Whether to export the model with past_key_values.
encoder_output (`Path`):
Directory to store the exported encoder ONNX model.
decoder_output (`Path`):
Directory to store the exported decoder ONNX model.
decoder_with_past_output (`Path`, *optional*, defaults to `None`):
Directory to store the exported decoder with past ONNX model. Required when `use_past` is True.
device (`str`, *optional*, defaults to `cpu`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
Returns:
`Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named
outputs from the ONNX configuration.
"""
task = task.replace("-with-past", "")

models_for_export = get_encoder_decoder_models_for_export(model, config, task, use_past)
outputs = []

# export encoder
model, onnx_config = models_for_export["encoder"]
outputs.append(export(model, onnx_config, opset, encoder_output, device=device))

# export decoder
model, onnx_config = models_for_export["decoder"]
outputs.append(export(model, onnx_config, opset, decoder_output, device=device))

if use_past:
# export decoder with past
model, onnx_config = models_for_export["decoder_with_past"]
outputs.append(export(model, onnx_config, opset, decoder_with_past_output, device=device))

outputs = list(map(list, zip(*outputs)))
return outputs
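
A rough usage sketch of this function (not part of the PR; how model and onnx_config are obtained is assumed here, the CLI in __main__.py builds them via TasksManager, and the task name below is only an example):

from pathlib import Path

out_dir = Path("onnx_out")
out_dir.mkdir(exist_ok=True)

onnx_inputs, onnx_outputs = export_encoder_decoder_model(
    model,
    onnx_config,
    opset=13,
    task="seq2seq-lm",
    use_past=True,
    encoder_output=out_dir / "encoder_model.onnx",
    decoder_output=out_dir / "decoder_model.onnx",
    decoder_with_past_output=out_dir / "decoder_with_past_model.onnx",
)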


def export(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
config: OnnxConfig,