From 93b38923afe74c77fdf4426c5617a7c9a81cd387 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 11 Nov 2022 09:43:46 -0800 Subject: [PATCH] Fix ATen Fallback for BUILD_CAFFE2=0 for ONNX-only ops (#88504) Follow-up for #87735 Once again, because BUILD_CAFFE2=0 is not tested for ONNX exporter, one scenario slipped through. A use case where the model can be exported without aten fallback when operator_export_type=ONNX_ATEN_FALLBACK and BUILD_CAFFE2=0 A new unit test has been added, but it won't prevent regressions if BUILD_CAFFE2=0 is not executed on CI again Fixes #87313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88504 Approved by: https://github.com/justinchuby, https://github.com/BowenBao (cherry picked from commit 5f0783bd6d27a0a239263b943d626c533b8b9a90) --- test/onnx/test_pytorch_onnx_no_runtime.py | 172 ++++++++++++++++++---- torch/onnx/utils.py | 19 ++- 2 files changed, 156 insertions(+), 35 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 38564bca4f75470..9dac36da52dbbf5 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -15,7 +15,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.onnx import symbolic_helper, utils +from torch.onnx import OperatorExportTypes, symbolic_helper, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import registration from torch.testing._internal import common_utils @@ -781,50 +781,60 @@ def forward(self, x): ) @common_utils.skipIfNoCaffe2 - def test_caffe2_aten_fallback(self): + def test_caffe2_aten_fallback_must_fallback(self): class ModelWithAtenNotONNXOp(torch.nn.Module): def forward(self, x, y): abcd = x + y defg = torch.linalg.qr(abcd) return defg - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # support for linalg.qr was added in later op set versions. - opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN, + OperatorExportTypes.ONNX_ATEN_FALLBACK, + ): + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") @common_utils.skipIfNoCaffe2 - def test_caffe2_onnx_aten(self): + def test_caffe2_onnx_aten_must_not_fallback(self): class ModelWithAtenFmod(torch.nn.Module): def forward(self, x, y): return torch.fmod(x, y) - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, - opset_version=10, # or higher - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - assert onnx_model.graph.node[0].op_type == "Mod" + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN_FALLBACK, + OperatorExportTypes.ONNX_ATEN, + ): + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + opset_version=10, # or higher + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[0].op_type == "Mod" @common_utils.skipIfCaffe2 - def test_aten_fallback(self): + def test_aten_fallback_must_fallback(self): class ModelWithAtenNotONNXOp(torch.nn.Module): def forward(self, x, y): abcd = x + y @@ -865,6 +875,106 @@ def forward(self, x, y): onnx_model = onnx.load(io.BytesIO(f.getvalue())) self.assertAtenOp(onnx_model, "fmod", "Tensor") + @common_utils.skipIfCaffe2 + def test_onnx_aten_fallback_must_not_fallback(self): + # For BUILD_CAFFE2=0, aten fallback only when not exportable + class ONNXExportable(torch.nn.Module): + def __init__(self): + super(ONNXExportable, self).__init__() + self.quant = torch.quantization.QuantStub() + self.fc1 = torch.nn.Linear(12, 8) + self.fc2 = torch.nn.Linear(8, 4) + self.fc3 = torch.nn.Linear(4, 6) + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = x.view((-1, 12)) + h = F.relu(self.fc1(x)) + h = F.relu(self.fc2(h)) + h = F.relu(self.fc3(h)) + h = self.dequant(h) + return h + + dummy_input = torch.randn(12) + f = io.BytesIO() + torch.onnx.export( + ONNXExportable(), + (dummy_input,), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + all_aten_nodes = [ + p + for p in onnx_model.graph.node + if p.op_type == "ATen" and p.domain == "org.pytorch.aten" + ] + self.assertEqual(len(all_aten_nodes), 0) + + +class TestQuantizeEagerONNXExport(common_utils.TestCase): + def _test_lower_graph_impl(self, model, data): + model.qconfig = torch.ao.quantization.default_qconfig + model = torch.ao.quantization.prepare(model) + model = torch.ao.quantization.convert(model) + + _ = model(data) + input_names = ["x"] + + def _export_to_onnx(model, input, input_names): + traced = torch.jit.trace(model, input) + buf = io.BytesIO() + torch.jit.save(traced, buf) + buf.seek(0) + + model = torch.jit.load(buf) + f = io.BytesIO() + torch.onnx.export( + model, + input, + f, + input_names=input_names, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + opset_version=9, + ) + + _export_to_onnx(model, data, input_names) + + @common_quantization.skipIfNoFBGEMM + @common_utils.skipIfNoCaffe2 + def test_lower_graph_linear(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Linear(5, 10, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 2, 5).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + + @common_quantization.skipIfNoFBGEMM + @common_utils.skipIfNoCaffe2 + def test_lower_graph_conv2d(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Conv2d(3, 5, 2, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + + @common_quantization.skipIfNoFBGEMM + @unittest.skip( + "onnx opset9 does not support quantize_per_tensor and caffe2 \ + does not support conv3d" + ) + def test_lower_graph_conv3d(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Conv3d(3, 5, 2, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 83482cac0598f03..d43d7d09fd76826 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1752,10 +1752,21 @@ def _should_aten_fallback( ) is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK - return name.startswith("aten::") and ( - ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build) - or (not is_exportable_aten_op and is_aten_fallback_export) - ) + if not name.startswith("aten::"): + return False + + if is_caffe2_build: + if ( + is_onnx_aten_export or is_aten_fallback_export + ) and not is_exportable_aten_op: + return True + else: + if is_onnx_aten_export or ( + is_aten_fallback_export and not is_exportable_aten_op + ): + return True + + return False @_beartype.beartype