From 4379d65d24ce38ba7902e81d23ed8d699cd6c605 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Mon, 31 Oct 2022 18:19:44 -0400 Subject: [PATCH] Fix regression for onnx export with caffe2 builds --- torch/onnx/utils.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 201028173cd0d03..dfbf31e5ffc43f4 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: @_beartype.beartype def _should_aten_fallback( - name: str, - opset_version: int, - operator_export_type: _C_onnx.OperatorExportTypes, + name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes ): + # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available + is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN is_aten_fallback_export = ( operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK ) + is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK + return name.startswith("aten::") and ( - is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export) + (is_onnx_aten_export and not is_caffe2_build) + or (not is_exportable_aten_op and is_aten_fallback_export) ) @@ -1845,16 +1850,7 @@ def _run_symbolic_function( ) # Direct ATen export requested - # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN, - # an aten::ATen operator is created regardless of symbolics existence - # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available - if ( - _C_onnx._CAFFE2_ATEN_FALLBACK - and not registration.registry.is_registered_op(ns_op_name, opset_version) - ) or ( - not _C_onnx._CAFFE2_ATEN_FALLBACK - and _should_aten_fallback(ns_op_name, opset_version, operator_export_type) - ): + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): attrs = { k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) for k in node.attributeNames()