Skip to content

Commit

Permalink
Updated config location
Browse files Browse the repository at this point in the history
  • Loading branch information
mht-sharma committed Nov 28, 2022
1 parent bbb89e1 commit fc97e78
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 39 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main():
"--for-ort",
action="store_true",
help=(
"This generates ONNX models to run inference with ONNX Runtime ORTModelXXX for encoder-decoder models."
"This exports models ready to be run with optimum.onnxruntime ORTModelXXX. Useful for encoder-decoder models."
" If enabled the encoder and decoder of the model are exported separately."
),
)
Expand Down
78 changes: 40 additions & 38 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,44 +315,6 @@ def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str,
"""
return reference_model_inputs

def get_encoder_onnx_config(self, config: "PretrainedConfig"):
"""
Returns ONNX encoder config for `Seq2Seq` models. Implement the method to export the encoder
of the model separately.
Args:
config (`PretrainedConfig`):
The encoder model's configuration to use when exporting to ONNX.
Returns:
`OnnxConfig`: An instance of the ONNX configuration object.
"""
raise NotImplementedError(
f"{config.model_type} encoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)

def get_decoder_onnx_config(self, config: "PretrainedConfig", task: str = "default", use_past: bool = False):
"""
Returns ONNX decoder config for `Seq2Seq` models. Implement the method to export the decoder
of the model separately.
Args:
config (`PretrainedConfig`):
The decoder model's configuration to use when exporting to ONNX.
task (`str`, defaults to `"default"`):
The task the model should be exported for.
use_past (`bool`, defaults to `False`):
Whether to export the model with past_key_values.
Returns:
`OnnxConfig`: An instance of the ONNX configuration object.
"""
raise NotImplementedError(
f"{config.model_type} decoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)


class OnnxConfigWithPast(OnnxConfig, ABC):
PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True
Expand Down Expand Up @@ -504,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]

def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
"""
Returns ONNX encoder config for `Seq2Seq` models. Implement the method to export the encoder
of the model separately.
Args:
config (`PretrainedConfig`):
The encoder model's configuration to use when exporting to ONNX.
Returns:
`OnnxConfig`: An instance of the ONNX configuration object.
"""
raise NotImplementedError(
f"{config.model_type} encoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)

def get_decoder_onnx_config(
self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
) -> OnnxConfig:
"""
Returns ONNX decoder config for `Seq2Seq` models. Implement the method to export the decoder
of the model separately.
Args:
config (`PretrainedConfig`):
The decoder model's configuration to use when exporting to ONNX.
task (`str`, defaults to `"default"`):
The task the model should be exported for.
use_past (`bool`, defaults to `False`):
Whether to export the model with past_key_values.
Returns:
`OnnxConfig`: An instance of the ONNX configuration object.
"""
raise NotImplementedError(
f"{config.model_type} decoder export is not supported yet. ",
f"If you want to support {config.model_type} please propose a PR or open up an issue.",
)

0 comments on commit fc97e78

Please sign in to comment.