
Commit
resolving cherry-pick conflict:
CONFLICT (content): Merge conflict in test/onnx/test_pytorch_onnx_no_runtime.py
Adding test/functorch
error: could not apply 2aed670... Fix ONNX operator_export_type on the new registry (pytorch#87735)
thiagocrepaldi authored and izaitsevfb committed Dec 1, 2022
1 parent ae2fe40 commit 4c7a208
Showing 2 changed files with 324 additions and 17 deletions.
298 changes: 298 additions & 0 deletions test/onnx/test_pytorch_onnx_no_runtime.py
@@ -32,6 +32,7 @@ def export_to_onnx(
mocks: Optional[Iterable] = None,
operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
opset_version: int = GLOBALS.export_onnx_opset_version,
**torch_onnx_export_kwargs,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.
@@ -44,6 +45,7 @@ def export_to_onnx(
mocks: list of mocks to use during export
operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)`
opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)`
torch_onnx_export_kwargs: extra keyword arguments forwarded to `torch.onnx.export`
Returns:
A valid ONNX model (`onnx.ModelProto`)
"""
@@ -60,6 +62,7 @@ def export_to_onnx(
f,
operator_export_type=operator_export_type,
opset_version=opset_version,
**torch_onnx_export_kwargs,
)

# Validate ONNX graph before returning it
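# Usage sketch (illustration only; MyModel and the values are placeholders):
# the new **torch_onnx_export_kwargs passthrough lets a test forward any
# extra torch.onnx.export() option without widening export_to_onnx's
# signature, for example:
#
#   model_proto = export_to_onnx(
#       MyModel(),
#       (torch.randn(1, 3),),
#       training=torch.onnx.TrainingMode.EVAL,  # forwarded verbatim
#   )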
@@ -777,6 +780,301 @@ def forward(self, x):
model, inputs, f, dynamic_axes={"x": [0, 1]}, input_names=["x"]
)

def test_dropout_script(self):
eg = torch.zeros(1, 2, 3, requires_grad=True)

@jit_utils._trace(eg)
def foo(x):
x = torch.neg(x)
return F.dropout(x)

class MyDrop(torch.nn.Module):
def forward(self, x):
return foo(x)

f = io.BytesIO()
with warnings.catch_warnings(record=True):
torch.onnx.export(MyDrop(), (eg,), f, verbose=False)

def test_pack_padded_pad_packed_trace(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

T, B, C = 3, 5, 7

class PadPackedWrapper(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, seq_lens):
x = pack_padded_sequence(x, seq_lens)
x, _ = pad_packed_sequence(x)
return x

x = np.ones((T, B, C))
seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
# set padding value so we can test equivalence
for b in range(B):
if seq_lens[b] < T:
x[seq_lens[b] :, b, :] = 0
seq_lens = torch.from_numpy(seq_lens)
x = torch.from_numpy(x).requires_grad_()

m = PadPackedWrapper()
m_traced = torch.jit.trace(
m,
(
x,
seq_lens,
),
)

y = m(x, seq_lens)
loss = torch.sum(y)
loss.backward()
grad = x.grad.clone()
x.grad.zero_()

y_traced = m_traced(x, seq_lens)
loss_traced = torch.sum(y_traced)
loss_traced.backward()
grad_traced = x.grad.clone()

self.assertEqual(y_traced, x)
self.assertEqual(y_traced, y)
self.assertEqual(grad, grad_traced)

f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f, verbose=False)

# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
@common_utils.suppress_warnings
def test_rnn_trace_override(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

num_layers = 3
T, B, C = 11, 5, 7

class RNNTraceWrapper(torch.nn.Module):
def __init__(self, cell_type):
super().__init__()
if cell_type == "RNN":
self.rnn = torch.nn.RNN(
input_size=C, hidden_size=C, num_layers=num_layers
)
elif cell_type == "LSTM":
self.rnn = torch.nn.LSTM(
input_size=C, hidden_size=C, num_layers=num_layers
)
elif cell_type == "GRU":
self.rnn = torch.nn.GRU(
input_size=C, hidden_size=C, num_layers=num_layers
)

def forward(self, x, seq_lens):
x = pack_padded_sequence(x, seq_lens)
x, _ = self.rnn(x)
x, _ = pad_packed_sequence(x)
return x

for cell_type in ["RNN", "LSTM", "GRU"]:
x = torch.ones(T, B, C, requires_grad=True)
seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))

m = RNNTraceWrapper(cell_type)
m_traced = torch.jit.trace(
m,
(
x,
seq_lens,
),
)

y = m(x, seq_lens)
loss = torch.sum(y)
loss.backward()
grad = x.grad.clone()
x.grad.zero_()

y_traced = m_traced(x, seq_lens)
loss_traced = torch.sum(y_traced)
loss_traced.backward()
grad_traced = x.grad.clone()

self.assertEqual(y_traced, y)
self.assertEqual(grad, grad_traced)

f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f, verbose=False)

def test_trace_fork_wait_inline_onnx(self):
def fork_body(x):
return torch.neg(x), torch.neg(x)

class MyMod(torch.nn.Module):
def forward(self, x):
fut = torch.jit._fork(fork_body, x)
val = torch.jit._wait(fut)
return val[1]

# smoke test for ONNX export
f = io.BytesIO()
torch.onnx.export(MyMod(), (torch.rand(3, 4),), f)

def test_trace_detach_onnx_erase(self):
class Mod(torch.nn.Module):
def forward(self, x, w):
return torch.matmul(x, w).detach()

torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))


class TestQuantizeEagerONNXExport(common_utils.TestCase):
def _test_lower_graph_impl(self, model, data):
model.qconfig = torch.ao.quantization.default_qconfig
model = torch.ao.quantization.prepare(model)
model = torch.ao.quantization.convert(model)

_ = model(data)
input_names = ["x"]

def _export_to_onnx(model, input, input_names):
traced = torch.jit.trace(model, input)
buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)

model = torch.jit.load(buf)
f = io.BytesIO()
torch.onnx.export(
model,
input,
f,
input_names=input_names,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
opset_version=9,
)

_export_to_onnx(model, data, input_names)

@common_quantization.skipIfNoFBGEMM
@common_utils.skipIfNoCaffe2
def test_lower_graph_linear(self):
model = torch.ao.quantization.QuantWrapper(
torch.nn.Linear(5, 10, bias=True)
).to(dtype=torch.float)
data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)

@common_quantization.skipIfNoFBGEMM
@common_utils.skipIfNoCaffe2
def test_lower_graph_conv2d(self):
model = torch.ao.quantization.QuantWrapper(
torch.nn.Conv2d(3, 5, 2, bias=True)
).to(dtype=torch.float)
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)

@common_quantization.skipIfNoFBGEMM
@unittest.skip(
    "onnx opset9 does not support quantize_per_tensor and caffe2 "
    "does not support conv3d"
)
def test_lower_graph_conv3d(self):
model = torch.ao.quantization.QuantWrapper(
torch.nn.Conv3d(3, 5, 2, bias=True)
).to(dtype=torch.float)
data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32)
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)

@common_utils.skipIfNoCaffe2
def test_caffe2_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")
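# What assertAtenOp (a helper assumed to come from the test base class) is
# checking, roughly: an ATen fallback is serialized as a node with op_type
# "ATen" whose "operator" string attribute names the original ATen op, e.g.
#
#   aten = [n for n in onnx_model.graph.node if n.op_type == "ATen"]
#   assert any(
#       a.name == "operator" and a.s == b"linalg_qr"
#       for n in aten for a in n.attribute
#   )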

@common_utils.skipIfNoCaffe2
def test_caffe2_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
opset_version=10, # or higher
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "Mod"

@common_utils.skipIfCaffe2
def test_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfCaffe2
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")
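# Side-by-side sketch (same ModelWithAtenFmod as above, illustration only):
# with the default export type, fmod lowers to native ONNX ops, while
# ONNX_ATEN keeps it as an ATen node:
#
#   f_onnx, f_aten = io.BytesIO(), io.BytesIO()
#   torch.onnx.export(ModelWithAtenFmod(), (x, y), f_onnx)  # native lowering
#   torch.onnx.export(
#       ModelWithAtenFmod(), (x, y), f_aten,
#       operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
#   )  # every aten op becomes an ATen node, as asserted above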


if __name__ == "__main__":
common_utils.run_tests()
