diff --git a/test/jit/test_export_modes.py b/test/jit/test_export_modes.py index dbf10cddc059b18..99b6ab9b912efed 100644 --- a/test/jit/test_export_modes.py +++ b/test/jit/test_export_modes.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: jit"] import io +import onnx import os import shutil import sys @@ -79,11 +80,16 @@ def forward(self, x, y): x = torch.rand(3, 4) y = torch.rand(3, 4) - torch.onnx.export_to_pretty_string( - ModelWithAtenNotONNXOp(), (x, y), - add_node_names=False, + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, do_constant_folding=False, operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[2].op_type == "ATen" + @skipIfCaffe2 @skipIfNoLapack @@ -96,13 +102,17 @@ def forward(self, x, y): x = torch.rand(3, 4) y = torch.rand(3, 4) - torch.onnx.export_to_pretty_string( - ModelWithAtenNotONNXOp(), (x, y), - add_node_names=False, + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, do_constant_folding=False, operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, # support for linalg.qr was added in later op set versions. opset_version=9) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[2].op_type == "ATen" # torch.fmod is using to test ONNX_ATEN. # If you plan to remove fmod from aten, or found this test failed. 
@@ -114,8 +124,13 @@ def forward(self, x, y): x = torch.randn(3, 4, dtype=torch.float32) y = torch.randn(3, 4, dtype=torch.float32) - torch.onnx.export_to_pretty_string( - ModelWithAtenFmod(), (x, y), - add_node_names=False, + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, do_constant_folding=False, operator_export_type=OperatorExportTypes.ONNX_ATEN) + + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[0].op_type == "ATen" diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 04fc984ded2b921..8515fe82fbc3f3a 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1743,13 +1743,14 @@ def _should_aten_fallback( opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes, ): + namespace, _ = name.split("::") is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN is_aten_fallback_export = ( operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK ) - return is_onnx_aten_export or ( - not is_exportable_aten_op and is_aten_fallback_export + return namespace == "aten" and ( + is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export) ) @@ -1844,6 +1845,24 @@ def _run_symbolic_function( env=env, ) + # when domain is "aten" and operator_export_type is ONNX_ATEN or ONNX_ATEN_FALLBACK, + # create an org.pytorch.aten::ATen operator regardless of whether symbolics exist or not + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): + # Direct ATen export requested + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + outputs = node.outputsSize() + attrs["outputs"] = outputs + # `overload_name` is set for non-Caffe2 builds only + return graph_context.at( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + try: # Caffe2-specific: Quantized op 
symbolics are registered for opset 9 only. if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9: @@ -1861,6 +1880,7 @@ def _run_symbolic_function( if symbolic_function_group is not None: symbolic_fn = symbolic_function_group.get(opset_version) if symbolic_fn is not None: + # TODO(justinchuby): Wrap almost identical attrs assignment or comment the difference. attrs = { k: symbolic_helper._node_get(node, k) for k in node.attributeNames() } @@ -1874,18 +1894,6 @@ def _run_symbolic_function( # Clone node to trigger ONNX shape inference return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined] - if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): - # Direct ATen export requested - outputs = node.outputsSize() - attrs["outputs"] = outputs - # `overload_name` is set for non-Caffe2 builds only - return graph_context.at( - op_name, - *inputs, - overload_name=_get_aten_op_overload_name(node), - **attrs, - ) - raise errors.UnsupportedOperatorError( domain, op_name,