From 77af8e14d98d64d0b0369b46a3e1e3c47fb51555 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Tue, 25 Oct 2022 18:07:02 -0400 Subject: [PATCH 1/2] Fix ONNX operator_export_type on the new registry --- test/onnx/test_pytorch_onnx_no_runtime.py | 85 +++++++++++++++++++++++ torch/onnx/utils.py | 41 +++++++---- 2 files changed, 112 insertions(+), 14 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index c30ee46a3422664..8ed183156f0cacc 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -994,6 +994,91 @@ def test_lower_graph_conv3d(self): data = torch.from_numpy(data_numpy).to(dtype=torch.float) self._test_lower_graph_impl(model, data) + @common_utils.skipIfNoCaffe2 + def test_caffe2_aten_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfNoCaffe2 + def test_caffe2_onnx_aten(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, + opset_version=10, # or higher + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[0].op_type == "Mod" + + @common_utils.skipIfCaffe2 + def test_aten_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfCaffe2 + def test_onnx_aten(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "fmod", "Tensor") + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 251e6be09e9828d..201028173cd0d03 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1748,8 +1748,8 @@ def _should_aten_fallback( is_aten_fallback_export = ( operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK ) - return is_onnx_aten_export or ( - not is_exportable_aten_op and is_aten_fallback_export + return name.startswith("aten::") and ( + is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export) ) @@ -1844,6 +1844,30 @@ def _run_symbolic_function( env=env, ) + # Direct ATen export requested + # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available + if ( + _C_onnx._CAFFE2_ATEN_FALLBACK + and not registration.registry.is_registered_op(ns_op_name, opset_version) + ) or ( + not _C_onnx._CAFFE2_ATEN_FALLBACK + and _should_aten_fallback(ns_op_name, opset_version, operator_export_type) + ): + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + outputs = node.outputsSize() + attrs["outputs"] = outputs + return graph_context.at( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + try: # Caffe2-specific: Quantized op symbolics are registered for opset 9 only. if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9: @@ -1861,6 +1885,7 @@ def _run_symbolic_function( if symbolic_function_group is not None: symbolic_fn = symbolic_function_group.get(opset_version) if symbolic_fn is not None: + # TODO Wrap almost identical attrs assignment or comment the difference. attrs = { k: symbolic_helper._node_get(node, k) for k in node.attributeNames() } @@ -1874,18 +1899,6 @@ def _run_symbolic_function( # Clone node to trigger ONNX shape inference return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined] - if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): - # Direct ATen export requested - outputs = node.outputsSize() - attrs["outputs"] = outputs - # `overload_name` is set for non-Caffe2 builds only - return graph_context.at( - op_name, - *inputs, - overload_name=_get_aten_op_overload_name(node), - **attrs, - ) - raise errors.UnsupportedOperatorError( symbolic_function_name, opset_version, From 14ec6aa6c9647e21a2c8bd2e44785210676fa3cc Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Mon, 31 Oct 2022 18:19:44 -0400 Subject: [PATCH 2/2] Fix regression for onnx export with caffe2 builds Signed-off-by: Thiago Crepaldi --- test/onnx/test_pytorch_onnx_no_runtime.py | 3 +++ torch/onnx/utils.py | 24 ++++++++++------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 8ed183156f0cacc..1095406e311dbe3 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -35,6 +35,7 @@ def export_to_onnx( mocks: Optional[Iterable] = None, operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX, opset_version: int = GLOBALS.export_onnx_opset_version, + **torch_onnx_export_kwargs, ) -> onnx.ModelProto: """Exports `model(input)` to ONNX and returns it. @@ -47,6 +48,7 @@ def export_to_onnx( mocks: list of mocks to use during export operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)` opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)` + torch_onnx_export_kwargs: extra torch.onnx.export kwargs arguments Returns: A valid ONNX model (`onnx.ModelProto`) """ @@ -63,6 +65,7 @@ def export_to_onnx( f, operator_export_type=operator_export_type, opset_version=opset_version, + **torch_onnx_export_kwargs, ) # Validate ONNX graph before returning it diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 201028173cd0d03..ff0ef755968d3a5 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: @_beartype.beartype def _should_aten_fallback( - name: str, - opset_version: int, - operator_export_type: _C_onnx.OperatorExportTypes, + name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes ): + # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available + is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN is_aten_fallback_export = ( operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK ) + is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK + return name.startswith("aten::") and ( - is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export) + ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build) + or (not is_exportable_aten_op and is_aten_fallback_export) ) @@ -1845,16 +1850,7 @@ def _run_symbolic_function( ) # Direct ATen export requested - # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN, - # an aten::ATen operator is created regardless of symbolics existence - # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available - if ( - _C_onnx._CAFFE2_ATEN_FALLBACK - and not registration.registry.is_registered_op(ns_op_name, opset_version) - ) or ( - not _C_onnx._CAFFE2_ATEN_FALLBACK - and _should_aten_fallback(ns_op_name, opset_version, operator_export_type) - ): + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): attrs = { k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) for k in node.attributeNames()