diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py
index 5f2ce3fa657a1ee..ff48531cdd8555e 100644
--- a/test/onnx/test_pytorch_onnx_no_runtime.py
+++ b/test/onnx/test_pytorch_onnx_no_runtime.py
@@ -32,6 +32,7 @@ def export_to_onnx(
     mocks: Optional[Iterable] = None,
    operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
     opset_version: int = GLOBALS.export_onnx_opset_version,
+    **torch_onnx_export_kwargs,
 ) -> onnx.ModelProto:
     """Exports `model(input)` to ONNX and returns it.

@@ -44,6 +45,7 @@ def export_to_onnx(
         mocks: list of mocks to use during export
         operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)`
         opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)`
+        torch_onnx_export_kwargs: extra keyword arguments forwarded to `torch.onnx.export`
     Returns:
         A valid ONNX model (`onnx.ModelProto`)
     """
@@ -60,6 +62,7 @@ def export_to_onnx(
             f,
             operator_export_type=operator_export_type,
             opset_version=opset_version,
+            **torch_onnx_export_kwargs,
         )

     # Validate ONNX graph before returning it
@@ -777,6 +780,301 @@ def forward(self, x):
             model, inputs, f, dynamic_axes={"x": [0, 1]}, input_names=["x"]
         )

+    def test_dropout_script(self):
+
+        eg = torch.zeros(1, 2, 3, requires_grad=True)
+
+        @jit_utils._trace(eg)
+        def foo(x):
+            x = torch.neg(x)
+            return F.dropout(x)
+
+        class MyDrop(torch.nn.Module):
+            def forward(self, x):
+                return foo(x)
+
+        f = io.BytesIO()
+        with warnings.catch_warnings(record=True):
+            torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
+
+    def test_pack_padded_pad_packed_trace(self):
+        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+        T, B, C = 3, 5, 7
+
+        class PadPackedWrapper(torch.nn.Module):
+            def __init__(self):
+                super(PadPackedWrapper, self).__init__()
+
+            def forward(self, x, seq_lens):
+                x = pack_padded_sequence(x, seq_lens)
+                x, _ = pad_packed_sequence(x)
+                return x
+
+        x = np.ones((T, B, C))
+        seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
+        # set padding value so we can test equivalence
+        for b in range(B):
+            if seq_lens[b] < T:
+                x[seq_lens[b] :, b, :] = 0
+        seq_lens = torch.from_numpy(seq_lens)
+        x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
+
+        m = PadPackedWrapper()
+        m_traced = torch.jit.trace(
+            m,
+            (
+                x,
+                seq_lens,
+            ),
+        )
+
+        y = m(x, seq_lens)
+        loss = torch.sum(y)
+        loss.backward()
+        grad = x.grad.clone()
+        x.grad.zero_()
+
+        y_traced = m_traced(x, seq_lens)
+        loss_traced = torch.sum(y_traced)
+        loss_traced.backward()
+        grad_traced = x.grad.clone()
+
+        self.assertEqual(y_traced, x)
+        self.assertEqual(y_traced, y)
+        self.assertEqual(grad, grad_traced)
+
+        f = io.BytesIO()
+        torch.onnx.export(m, (x, seq_lens), f, verbose=False)
+
+    # Suppression: ONNX warns when exporting RNNs because of a potential batch-size mismatch.
+    @common_utils.suppress_warnings
+    def test_rnn_trace_override(self):
+        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+        num_layers = 3
+        T, B, C = 11, 5, 7
+
+        class RNNTraceWrapper(torch.nn.Module):
+            def __init__(self, cell_type):
+                super(RNNTraceWrapper, self).__init__()
+                if cell_type == "RNN":
+                    self.rnn = torch.nn.RNN(
+                        input_size=C, hidden_size=C, num_layers=num_layers
+                    )
+                elif cell_type == "LSTM":
+                    self.rnn = torch.nn.LSTM(
+                        input_size=C, hidden_size=C, num_layers=num_layers
+                    )
+                elif cell_type == "GRU":
+                    self.rnn = torch.nn.GRU(
+                        input_size=C, hidden_size=C, num_layers=num_layers
+                    )
+
+            def forward(self, x, seq_lens):
+                x = pack_padded_sequence(x, seq_lens)
+                x, _ = self.rnn(x)
+                x, _ = pad_packed_sequence(x)
+                return x
+
+        for cell_type in ["RNN", "LSTM", "GRU"]:
+            x = torch.ones(T, B, C, requires_grad=True)
+            seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
+
+            m = RNNTraceWrapper(cell_type)
+            m_traced = torch.jit.trace(
+                m,
+                (
+                    x,
+                    seq_lens,
+                ),
+            )
+
+            y = m(x, seq_lens)
+            loss = torch.sum(y)
+            loss.backward()
+            grad = x.grad.clone()
+            x.grad.zero_()
+
+            y_traced = m_traced(x, seq_lens)
+            loss_traced = torch.sum(y_traced)
+            loss_traced.backward()
+            grad_traced = x.grad.clone()
+
+            self.assertEqual(y_traced, y)
+            self.assertEqual(grad, grad_traced)
+
+            f = io.BytesIO()
+            torch.onnx.export(m, (x, seq_lens), f, verbose=False)
+
+    def test_trace_fork_wait_inline_onnx(self):
+        def fork_body(x):
+            return torch.neg(x), torch.neg(x)
+
+        class MyMod(torch.nn.Module):
+            def forward(self, x):
+                fut = torch.jit._fork(fork_body, x)
+                val = torch.jit._wait(fut)
+                return val[1]
+
+        # smoke test for ONNX export
+        f = io.BytesIO()
+        torch.onnx.export(MyMod(), (torch.rand(3, 4),), f)
+
+    def test_trace_detach_onnx_erase(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x, w):
+                return torch.matmul(x, w).detach()
+
+        torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
+
+
+class TestQuantizeEagerONNXExport(common_utils.TestCase):
+    def _test_lower_graph_impl(self, model, data):
+        model.qconfig = torch.ao.quantization.default_qconfig
+        model = torch.ao.quantization.prepare(model)
+        model = torch.ao.quantization.convert(model)
+
+        _ = model(data)
+        input_names = ["x"]
+
+        def _export_to_onnx(model, input, input_names):
+            traced = torch.jit.trace(model, input)
+            buf = io.BytesIO()
+            torch.jit.save(traced, buf)
+            buf.seek(0)
+
+            model = torch.jit.load(buf)
+            f = io.BytesIO()
+            torch.onnx.export(
+                model,
+                input,
+                f,
+                input_names=input_names,
+                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                opset_version=9,
+            )
+
+        _export_to_onnx(model, data, input_names)
+
+    @common_quantization.skipIfNoFBGEMM
+    @common_utils.skipIfNoCaffe2
+    def test_lower_graph_linear(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Linear(5, 10, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_quantization.skipIfNoFBGEMM
+    @common_utils.skipIfNoCaffe2
+    def test_lower_graph_conv2d(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Conv2d(3, 5, 2, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_quantization.skipIfNoFBGEMM
+    @unittest.skip(
+        "onnx opset9 does not support quantize_per_tensor and caffe2 \
+does not support conv3d"
+    )
+    def test_lower_graph_conv3d(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Conv3d(3, 5, 2, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_utils.skipIfNoCaffe2
+    def test_caffe2_aten_fallback(self):
+        class ModelWithAtenNotONNXOp(torch.nn.Module):
+            def forward(self, x, y):
+                abcd = x + y
+                defg = torch.linalg.qr(abcd)
+                return defg
+
+        x = torch.rand(3, 4)
+        y = torch.rand(3, 4)
+        f = io.BytesIO()
+        torch.onnx.export(
+            ModelWithAtenNotONNXOp(),
+            (x, y),
+            f,
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+            # support for linalg.qr was added in later opset versions.
+            opset_version=9,
+        )
+        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+        self.assertAtenOp(onnx_model, "linalg_qr")
+
+    @common_utils.skipIfNoCaffe2
+    def test_caffe2_onnx_aten(self):
+        class ModelWithAtenFmod(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.fmod(x, y)
+
+        x = torch.randn(3, 4, dtype=torch.float32)
+        y = torch.randn(3, 4, dtype=torch.float32)
+        f = io.BytesIO()
+        torch.onnx.export(
+            ModelWithAtenFmod(),
+            (x, y),
+            f,
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
+            opset_version=10,  # or higher
+        )
+        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+        assert onnx_model.graph.node[0].op_type == "Mod"
+
+    @common_utils.skipIfCaffe2
+    def test_aten_fallback(self):
+        class ModelWithAtenNotONNXOp(torch.nn.Module):
+            def forward(self, x, y):
+                abcd = x + y
+                defg = torch.linalg.qr(abcd)
+                return defg
+
+        x = torch.rand(3, 4)
+        y = torch.rand(3, 4)
+        f = io.BytesIO()
+        torch.onnx.export(
+            ModelWithAtenNotONNXOp(),
+            (x, y),
+            f,
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+            # support for linalg.qr was added in later opset versions.
+            opset_version=9,
+        )
+        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+        self.assertAtenOp(onnx_model, "linalg_qr")
+
+    @common_utils.skipIfCaffe2
+    def test_onnx_aten(self):
+        class ModelWithAtenFmod(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.fmod(x, y)
+
+        x = torch.randn(3, 4, dtype=torch.float32)
+        y = torch.randn(3, 4, dtype=torch.float32)
+        f = io.BytesIO()
+        torch.onnx.export(
+            ModelWithAtenFmod(),
+            (x, y),
+            f,
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
+        )
+        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+        self.assertAtenOp(onnx_model, "fmod", "Tensor")
+

 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 04fc984ded2b921..83482cac0598f03 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:

 @_beartype.beartype
 def _should_aten_fallback(
-    name: str,
-    opset_version: int,
-    operator_export_type: _C_onnx.OperatorExportTypes,
+    name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
 ):
+    # For BUILD_CAFFE2=0 builds, if domain == "aten" and operator_export_type == ONNX_ATEN,
+    # an aten::ATen operator is created regardless of whether a symbolic function exists.
+    # For BUILD_CAFFE2=1, the same applies only when no symbolic function is available.
+
     is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
     is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
     is_aten_fallback_export = (
         operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
     )
-    return is_onnx_aten_export or (
-        not is_exportable_aten_op and is_aten_fallback_export
+    is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK
+
+    return name.startswith("aten::") and (
+        ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
+        or (not is_exportable_aten_op and is_aten_fallback_export)
     )


@@ -1844,6 +1849,21 @@ def _run_symbolic_function(
             env=env,
         )

+    # Direct ATen export requested
+    if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
+        attrs = {
+            k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
+            for k in node.attributeNames()
+        }
+        outputs = node.outputsSize()
+        attrs["outputs"] = outputs
+        return graph_context.at(
+            op_name,
+            *inputs,
+            overload_name=_get_aten_op_overload_name(node),
+            **attrs,
+        )
+
     try:
         # Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
         if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
@@ -1861,6 +1881,7 @@ def _run_symbolic_function(
         if symbolic_function_group is not None:
             symbolic_fn = symbolic_function_group.get(opset_version)
             if symbolic_fn is not None:
+                # TODO: Deduplicate the nearly identical attrs construction above, or document the difference.
                 attrs = {
                     k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
                 }
@@ -1874,18 +1895,6 @@ def _run_symbolic_function(
             # Clone node to trigger ONNX shape inference
             return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize())  # type: ignore[attr-defined]

-        if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
-            # Direct ATen export requested
-            outputs = node.outputsSize()
-            attrs["outputs"] = outputs
-            # `overload_name` is set for non-Caffe2 builds only
-            return graph_context.at(
-                op_name,
-                *inputs,
-                overload_name=_get_aten_op_overload_name(node),
-                **attrs,
-            )
-
         raise errors.UnsupportedOperatorError(
             domain,
             op_name,
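For context on the first hunk: `export_to_onnx` now forwards arbitrary keyword arguments to `torch.onnx.export`, so a test can exercise exporter options without the helper growing a parameter for each one. A minimal sketch of a hypothetical call site (the `Scale` module and the argument values are illustrative, not part of the patch; `input_names` and `dynamic_axes` are ordinary `torch.onnx.export` kwargs):

```python
import torch


class Scale(torch.nn.Module):
    """Toy module used only to illustrate the passthrough."""

    def forward(self, x):
        return x * 2


# input_names and dynamic_axes are not parameters of export_to_onnx itself;
# they reach torch.onnx.export through the new **torch_onnx_export_kwargs.
model_proto = export_to_onnx(
    Scale(),
    (torch.rand(3, 4),),
    input_names=["x"],
    dynamic_axes={"x": [0]},
)
```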
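The rewritten `_should_aten_fallback` gates on the `aten::` prefix and on the build flavor. A standalone sketch of the same predicate, with plain strings and booleans standing in for the real enum, registry lookup, and `_CAFFE2_ATEN_FALLBACK` flag:

```python
def should_aten_fallback(
    name: str,
    is_registered_op: bool,  # stand-in for registration.registry.is_registered_op(...)
    export_type: str,        # stand-in for _C_onnx.OperatorExportTypes
    is_caffe2_build: bool,   # stand-in for _C_onnx._CAFFE2_ATEN_FALLBACK
) -> bool:
    is_onnx_aten = export_type == "ONNX_ATEN"
    is_fallback = export_type == "ONNX_ATEN_FALLBACK"
    # BUILD_CAFFE2=0: every aten::* op falls back under either ATen export mode,
    # even when a symbolic function is registered.
    # BUILD_CAFFE2=1: only unregistered ops fall back, and only under ONNX_ATEN_FALLBACK.
    return name.startswith("aten::") and (
        ((is_onnx_aten or is_fallback) and not is_caffe2_build)
        or (not is_registered_op and is_fallback)
    )


assert should_aten_fallback("aten::fmod", True, "ONNX_ATEN", False)
assert should_aten_fallback("aten::linalg_qr", False, "ONNX_ATEN_FALLBACK", True)
assert not should_aten_fallback("aten::add", True, "ONNX_ATEN_FALLBACK", True)
assert not should_aten_fallback("prim::Constant", True, "ONNX_ATEN", False)
```

This truth table explains why `test_onnx_aten` and `test_aten_fallback` expect an ATen node on non-Caffe2 builds even though symbolic functions exist for `fmod`, while the Caffe2 variants still rely on the registry lookup.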
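Note also that the relocated direct-ATen branch builds `attrs` differently from the symbolic-function path below it: each attribute key is suffixed with the first character of its JIT attribute kind (for example `i` for int, `f` for float), which is the spelling `graph_context.at` expects for typed ATen attributes. A sketch of just that key transformation, using made-up attribute names and kinds rather than a real `torch._C.Node`:

```python
# Each entry maps an attribute name to (JIT kind, value); node.kindOf(k) would
# return kind strings such as "i", "f", "s", or "t" on a real torch._C.Node.
fake_node_attrs = {"dim": ("i", 1), "keepdim": ("i", 0), "alpha": ("f", 0.5)}

# Mirrors: {k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) ...}
attrs = {k + "_" + kind[0]: v for k, (kind, v) in fake_node_attrs.items()}
assert attrs == {"dim_i": 1, "keepdim_i": 0, "alpha_f": 0.5}
```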