
Add ort export in exporters for encoder-decoder models #497

Conversation

@mht-sharma (Contributor) commented Nov 21, 2022

What does this PR do?

Adds support for exporting the encoder and decoder of encoder-decoder models separately with the exporters CLI, controlled by a command-line argument.

Fixes #496

Example usage:

python -m optimum.exporters.onnx --model="openai/whisper-tiny.en" --task=speech2seq-lm-with-past output_whisper  --for-ort
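For context, a minimal sketch of how the exported directory could then be consumed from optimum.onnxruntime (the ORTModelForSpeechSeq2Seq class and the output file layout are assumptions here, not something this PR itself guarantees):

from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

# Load the separately exported encoder/decoder ONNX files produced by the
# command above; "output_whisper" is the output directory passed to the CLI.
model = ORTModelForSpeechSeq2Seq.from_pretrained("output_whisper")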

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev commented Nov 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -303,6 +303,56 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) ->
"""
return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))}

def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
@mht-sharma (Contributor, Author):
Discussion with lewtun regarding the use of the function. huggingface/transformers#19525 (comment)

Member:

Is it needed to generate the inputs for the separate encoder and decoder models? Or is it only used for validation?

@mht-sharma (Contributor, Author):

It is needed only for validation with onnxruntime, since the ONNX model and the torch model have different input signatures when encoder_outputs is used for exporting the model.
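A minimal sketch of what such an override could look like given that description (the remapping below is an illustration, not the PR's exact implementation):

from typing import Any, Mapping

def generate_dummy_inputs_onnxruntime(
    self, reference_model_inputs: Mapping[str, Any]
) -> Mapping[str, Any]:
    # The torch decoder is traced with `encoder_outputs` (a ModelOutput/tuple),
    # while the exported ONNX decoder expects a flat `encoder_hidden_states`
    # tensor, so the validation inputs are adapted before running onnxruntime.
    reference_model_inputs = dict(reference_model_inputs)
    if "encoder_outputs" in reference_model_inputs:
        reference_model_inputs["encoder_hidden_states"] = (
            reference_model_inputs.pop("encoder_outputs")[0]
        )
    return reference_model_inputs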

Member:

What about calling it generate_dummy_inputs_for_validation?
I think the name can be misleading otherwise (all the more now that we use a --for-ort argument that does not mean the same thing).

@mht-sharma (Contributor, Author):

Agree. Updated!

@mht-sharma marked this pull request as ready for review on November 22, 2022 at 08:14.

@mht-sharma (Contributor, Author) commented Nov 22, 2022

Currently only Whisper is exported for ORT in this PR. Should I add support for all the models, or at least a few more (e.g. T5), in this PR? Or should they get their own PRs? @michaelbenayoun WDYT?

optimum/exporters/onnx/__main__.py (outdated review comments, resolved)
optimum/exporters/onnx/base.py (outdated review comment, resolved)
tests/exporters/test_onnx_export.py (outdated review comment, resolved)
@michaelbenayoun (Member):

To answer your question: you can also add support for other models if you feel like it!

@mht-sharma force-pushed the export-encoder-decoder-separately-seq2seq-models branch from 28dc7d0 to fc97e78 on November 28, 2022 09:12.
@mht-sharma (Contributor, Author):

> To answer your question: you can also add support for other models if you feel like it!

Added support for exporting the seq2seq-lm models as well. AFAIK I covered all the models with existing support; let me know if I missed something.
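For illustration, the equivalent CLI invocation for a text-to-text model would presumably look like this (model name and output directory are placeholders, not taken from the PR):

python -m optimum.exporters.onnx --model="t5-small" --task=seq2seq-lm-with-past --for-ort output_t5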

Comment on lines 72 to 79
parser.add_argument(
"--for-ort",
action="store_true",
help=(
"This exports models ready to be run with optimum.onnxruntime ORTModelXXX. Useful for encoder-decoder models for"
"conditional generation. If enabled the encoder and decoder of the model are exported separately."
),
)
@mht-sharma (Contributor, Author):

Could the name for-ort be misleading, as the models can have other tasks apart from conditional generation? I mentioned "Useful for encoder-decoder models for conditional generation" in the help text, but I am not sure this is enough.

Maybe update ORTModelXXX -> ORTModelForConditionalGeneration?

Member:

I would just say to run with optimum.onnxruntime. I think for-ort is good enough, or at least I do not have a better name in mind.

@michaelbenayoun (Member) left a review comment:

That's really great!
Left a few comments and questions, but huge work @mht-sharma!


args.opset,
args.output,
)
use_past = True if "-with-past" in task else False
Member:

I think you do not need that, since it is already in the ONNX config, no?

@mht-sharma (Contributor, Author):

Yes, I was probably thinking of how ORTModelForConditionalGeneration is currently implemented, which always creates the ONNX config without past. But that can be modified later.

Updated to use use_past from the ONNX config.
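A minimal sketch of the change being described (the use_past attribute on the ONNX config object is the assumption here):

# Before: re-deriving the flag from the task string.
# use_past = True if "-with-past" in task else False
# After: read it from the ONNX config that was already built for the export.
use_past = getattr(onnx_config, "use_past", False)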


optimum/exporters/onnx/base.py (review comment, resolved)
f"{config.model_type} encoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)

Member:

Suggested change
@abstractmethod
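For context, a rough sketch of the abstract-method pattern the suggestion points at (the class and method names here are hypothetical, not taken from this PR):

from abc import ABC, abstractmethod

class EncoderDecoderOnnxConfig(ABC):  # hypothetical name for illustration
    @abstractmethod
    def get_encoder_onnx_config(self, config):
        # Subclasses must provide an encoder-specific ONNX config; the base
        # implementation only signals that the model type is unsupported.
        raise NotImplementedError(
            f"{config.model_type} encoder export is not supported yet. "
            f"If you want to support {config.model_type} please propose a PR or open up an issue."
        )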

optimum/exporters/onnx/convert.py (outdated review comments, resolved)
optimum/exporters/onnx/model_configs.py (review comment, resolved)
optimum/exporters/onnx/utils.py (outdated review comment, resolved)
optimum/utils/input_generators.py (outdated review comment, resolved)
@JingyaHuang mentioned this pull request on Nov 29, 2022.
@michaelbenayoun (Member) left a review comment:

LGTM!

@JingyaHuang (Collaborator) left a review comment:

LGTM, love the consistency of new ONNX configs.

@mht-sharma merged commit 6ee424b into huggingface:main on Nov 30, 2022.
Closes: Export encoder and decoder in the encoder-decoder models separately using exporters (#496)