Skip to content

Commit

Permalink
Fix regression for onnx export with caffe2 builds
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagocrepaldi committed Oct 31, 2022
1 parent ccdd1c1 commit 4379d65
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions torch/onnx/utils.py
Expand Up @@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:

@_beartype.beartype
def _should_aten_fallback(
name: str,
opset_version: int,
operator_export_type: _C_onnx.OperatorExportTypes,
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
# For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
# an aten::ATen operator is created regardless of symbolics existence
# For BUILD_CAFFE2=1, the same applies only if there is no symbolic available

is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK

return name.startswith("aten::") and (
is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export)
(is_onnx_aten_export and not is_caffe2_build)
or (not is_exportable_aten_op and is_aten_fallback_export)
)


Expand Down Expand Up @@ -1845,16 +1850,7 @@ def _run_symbolic_function(
)

# Direct ATen export requested
# For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
# an aten::ATen operator is created regardless of symbolics existence
# For BUILD_CAFFE2=1, the same applies only if there is no symbolic available
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and not registration.registry.is_registered_op(ns_op_name, opset_version)
) or (
not _C_onnx._CAFFE2_ATEN_FALLBACK
and _should_aten_fallback(ns_op_name, opset_version, operator_export_type)
):
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
attrs = {
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
for k in node.attributeNames()
Expand Down

0 comments on commit 4379d65

Please sign in to comment.