Skip to content

Commit

Permalink
Updated input generator and config
Browse files Browse the repository at this point in the history
  • Loading branch information
mht-sharma committed Nov 11, 2022
1 parent d379719 commit af7ddd3
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 70 deletions.
14 changes: 14 additions & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,17 @@ def is_torch_support_available(self) -> bool:
return TORCH_VERSION >= self.MIN_TORCH_VERSION
return False

@property
def torch_to_onnx_input_map(self) -> Mapping[str, str]:
    """
    Mapping from torch input names to the names the matching ONNX inputs should get on export.

    This default implementation renames nothing. Override the property when the torch
    input names and the ONNX input names need to be different.

    Returns:
        `Mapping[str, str]`: A dictionary specifying the torch to ONNX input name map.
    """
    # A fresh, empty dictionary on every access: no input is renamed by default.
    no_renames: Mapping[str, str] = {}
    return no_renames

def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int, str]]:
"""
Re-orders the inputs using the model forward pass signature.
Expand All @@ -219,6 +230,8 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
`Mapping[str, Mapping[int, str]]`: The properly ordered inputs.
"""
inputs = self.inputs
torch_to_onnx_input_map = self.torch_to_onnx_input_map

ordered_inputs = {}
sig = inspect.signature(model.forward)
for param in sig.parameters:
Expand All @@ -230,6 +243,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
# TODO: figure out a smart way of re-ordering potential nested structures.
# to_insert = sorted(to_insert, key=lambda t: t[0])
for name, dynamic_axes in to_insert:
name = torch_to_onnx_input_map[name] if name in torch_to_onnx_input_map else name
ordered_inputs[name] = dynamic_axes
return ordered_inputs

Expand Down
3 changes: 1 addition & 2 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ def export_pytorch(
model.to(device)
dummy_inputs = tree_map(lambda value: value.to(device), dummy_inputs)
check_dummy_inputs_are_allowed(model, dummy_inputs)
# inputs = config.ordered_inputs(model)
inputs = config.inputs
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

Expand Down
72 changes: 11 additions & 61 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from ...utils import (
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
NormalizedAudioConfig,
NormalizedConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
Expand Down Expand Up @@ -655,49 +656,8 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
return common_inputs


class Seq2SeqDecoderTextInputGenerator(DummyTextInputGenerator):
    """
    Text input generator that can additionally produce a dummy `encoder_hidden_states` input.

    Every other supported input name is generated by `DummyTextInputGenerator`.
    """

    SUPPORTED_INPUT_NAMES = (
        "input_ids",
        "encoder_attention_mask",
        "encoder_hidden_states",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = 2,
        sequence_length: int = 16,
        num_choices: int = 4,
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        random_num_choices_range: Optional[Tuple[int, int]] = None,
    ):
        """
        Forwards every argument unchanged to the parent generator, then records the hidden size.

        Args:
            task: Forwarded to the parent generator.
            normalized_config: Normalized model config; its `hidden_size` is stored on the instance.
            batch_size: Forwarded to the parent generator.
            sequence_length: Forwarded to the parent generator.
            num_choices: Forwarded to the parent generator.
            random_batch_size_range: Forwarded to the parent generator.
            random_sequence_length_range: Forwarded to the parent generator.
            random_num_choices_range: Forwarded to the parent generator.
        """
        parent_kwargs = dict(
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_choices=num_choices,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
            random_num_choices_range=random_num_choices_range,
        )
        super().__init__(task, normalized_config, **parent_kwargs)

        # Last dimension of the dummy `encoder_hidden_states` tensor built in `generate`.
        self.hidden_size = normalized_config.hidden_size

    def generate(self, input_name: str, framework: str = "pt"):
        """
        Generates the dummy input for `input_name`.

        Args:
            input_name: Name of the input to generate.
            framework: Framework identifier forwarded to the tensor helpers.

        Returns:
            For `"encoder_hidden_states"`, a 3-tuple whose first item is a random float tensor
            of shape `(batch_size, sequence_length, hidden_size)` and whose other two items
            are `None`. For any other name, whatever the parent generator returns.
        """
        if input_name == "encoder_hidden_states":
            hidden_states_shape = (self.batch_size, self.sequence_length, self.hidden_size)
            hidden_states = self.random_float_tensor(
                hidden_states_shape, min_value=0, max_value=1, framework=framework
            )
            # NOTE(review): the two `None` items presumably fill the remaining slots of an
            # encoder-output tuple; confirm against the consuming model.
            return (hidden_states, None, None)

        return super().generate(input_name, framework=framework)


class SpeechSeq2SeqEncoderOnnxConfig(AudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedAudioConfig
NORMALIZED_CONFIG_CLASS = NormalizedConfig

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
Expand All @@ -710,43 +670,33 @@ class SpeechSeq2SeqDecoderOnnxConfig(Seq2SeqOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

DUMMY_INPUT_GENERATOR_CLASSES = (
Seq2SeqDecoderTextInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
)

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
from collections import OrderedDict

common_inputs = {
"input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"},
"encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
"decoder_input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"},
"encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"},
}

if self.use_past:
self.add_past_key_values(common_inputs, direction="inputs")

return common_inputs

@property
def torch_to_onnx_input_map(self) -> Mapping[str, str]:
    """
    Torch to ONNX input name map for this config.

    Returns:
        `Mapping[str, str]`: Maps `decoder_input_ids` to `input_ids` and `encoder_outputs`
        to `encoder_hidden_states`.
    """
    renames = {}
    renames["decoder_input_ids"] = "input_ids"
    renames["encoder_outputs"] = "encoder_hidden_states"
    return renames

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super(OnnxConfigWithPast, self).outputs
common_outputs = super().outputs
self.add_past_key_values(common_outputs, direction="outputs")
return common_outputs

def generate_dummy_inputs(self, framework: str = "pt"):
    """
    Generates the parent's dummy inputs and re-keys them.

    `input_ids` is returned as `decoder_input_ids` and `encoder_hidden_states` as
    `encoder_outputs`. `past_key_values` is carried over only when the parent produced it.
    Any other entry produced by the parent is not returned.

    Args:
        framework: Framework identifier forwarded to the parent implementation.

    Returns:
        A dictionary holding the re-keyed dummy inputs.
    """
    generated = super().generate_dummy_inputs(framework)

    # The literal is evaluated top to bottom, so `input_ids` is popped first.
    renamed = {
        "decoder_input_ids": generated.pop("input_ids"),
        "encoder_outputs": generated.pop("encoder_hidden_states"),
    }

    if "past_key_values" in generated:
        renamed["past_key_values"] = generated.pop("past_key_values")

    return renamed

@property
def values_override(self) -> Optional[Mapping[str, Any]]:
if hasattr(self._config, "use_cache"):
Expand Down
2 changes: 1 addition & 1 deletion optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def is_pydantic_available():
DummyBboxInputGenerator,
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
NormalizedAudioConfig,
NormalizedConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
Expand Down
46 changes: 41 additions & 5 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ def __getattr__(self, attr_name):
return super().__getattr__(attr_name)


class NormalizedAudioConfig(NormalizedConfig):
    """Subclass of `NormalizedConfig` that adds no attributes or behavior of its own."""


def check_framework_is_available(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -215,6 +211,7 @@ def __init__(
):
self.task = task
self.vocab_size = normalized_config.vocab_size
self.hidden_size = normalized_config.hidden_size
if random_batch_size_range:
low, high = random_batch_size_range
self.batch_size = random.randint(low, high)
Expand Down Expand Up @@ -247,6 +244,45 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
)


class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
    """
    Decoder text input generator that can additionally produce a dummy `encoder_outputs` input.

    Every other supported input name is generated by `DummyDecoderTextInputGenerator`.
    """

    SUPPORTED_INPUT_NAMES = (
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = 2,
        sequence_length: int = 16,
        num_choices: int = 4,
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        random_num_choices_range: Optional[Tuple[int, int]] = None,
    ):
        """
        Forwards every argument unchanged to the parent generator, then records the hidden size.

        Args:
            task: Forwarded to the parent generator.
            normalized_config: Normalized model config; its `hidden_size` is stored on the instance.
            batch_size: Forwarded to the parent generator.
            sequence_length: Forwarded to the parent generator.
            num_choices: Forwarded to the parent generator.
            random_batch_size_range: Forwarded to the parent generator.
            random_sequence_length_range: Forwarded to the parent generator.
            random_num_choices_range: Forwarded to the parent generator.
        """
        size_options = dict(
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_choices=num_choices,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
            random_num_choices_range=random_num_choices_range,
        )
        super().__init__(task, normalized_config, **size_options)

        # Last dimension of the dummy `encoder_outputs` tensor built in `generate`.
        self.hidden_size = normalized_config.hidden_size

    def generate(self, input_name: str, framework: str = "pt"):
        """
        Generates the dummy input for `input_name`.

        Args:
            input_name: Name of the input to generate.
            framework: Framework identifier forwarded to the tensor helpers.

        Returns:
            For `"encoder_outputs"`, a 3-tuple whose first item is a random float tensor of
            shape `(batch_size, sequence_length, hidden_size)` and whose other two items are
            `None`. For any other name, whatever the parent generator returns.
        """
        if input_name != "encoder_outputs":
            return super().generate(input_name, framework=framework)

        hidden_states_shape = (self.batch_size, self.sequence_length, self.hidden_size)
        hidden_states = self.random_float_tensor(
            hidden_states_shape, min_value=0, max_value=1, framework=framework
        )
        # NOTE(review): the two `None` items presumably fill the remaining slots of an
        # encoder-output tuple; confirm against the consuming model.
        return (hidden_states, None, None)


class DummyPastKeyValuesGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("past_key_values",)

Expand Down Expand Up @@ -421,7 +457,7 @@ class DummyAudioInputGenerator(DummyInputGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedAudioConfig,
normalized_config: NormalizedConfig,
batch_size: int = 2,
feature_size: int = 80,
nb_max_frames: int = 3000,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"huggingface_hub>=0.8.0",
]

TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow"]
TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow", "soundfile"]

QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.5.4"]

Expand Down

0 comments on commit af7ddd3

Please sign in to comment.