Skip to content

Commit

Permalink
Fix regression for onnx export with caffe2 builds
Browse files Browse the repository at this point in the history
Signed-off-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
  • Loading branch information
thiagocrepaldi committed Nov 1, 2022
1 parent 77af8e1 commit ae7b749
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
4 changes: 4 additions & 0 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Expand Up @@ -35,6 +35,7 @@ def export_to_onnx(
mocks: Optional[Iterable] = None,
operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
opset_version: int = GLOBALS.export_onnx_opset_version,
**torch_onnx_export_kwargs,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.
Expand All @@ -47,6 +48,7 @@ def export_to_onnx(
mocks: list of mocks to use during export
operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)`
opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)`
torch_onnx_export_kwargs: extra torch.onnx.export kwargs arguments
Returns:
A valid ONNX model (`onnx.ModelProto`)
"""
Expand All @@ -63,6 +65,7 @@ def export_to_onnx(
f,
operator_export_type=operator_export_type,
opset_version=opset_version,
**torch_onnx_export_kwargs,
)

# Validate ONNX graph before returning it
Expand Down Expand Up @@ -436,6 +439,7 @@ def forward(self, x):
)
],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
verbose=True,
)
self.assertAtenOp(onnx_model, "clamp", "Tensor")

Expand Down
24 changes: 10 additions & 14 deletions torch/onnx/utils.py
Expand Up @@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:

@_beartype.beartype
def _should_aten_fallback(
name: str,
opset_version: int,
operator_export_type: _C_onnx.OperatorExportTypes,
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
# For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
# an aten::ATen operator is created regardless of symbolics existence
# For BUILD_CAFFE2=1, the same applies only if there is no symbolic available

is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK

return name.startswith("aten::") and (
is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export)
((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
or (not is_exportable_aten_op and is_aten_fallback_export)
)


Expand Down Expand Up @@ -1845,16 +1850,7 @@ def _run_symbolic_function(
)

# Direct ATen export requested
# For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
# an aten::ATen operator is created regardless of symbolics existence
# For BUILD_CAFFE2=1, the same applies only if there is no symbolic available
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and not registration.registry.is_registered_op(ns_op_name, opset_version)
) or (
not _C_onnx._CAFFE2_ATEN_FALLBACK
and _should_aten_fallback(ns_op_name, opset_version, operator_export_type)
):
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
attrs = {
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
for k in node.attributeNames()
Expand Down

0 comments on commit ae7b749

Please sign in to comment.