Fix ATen Fallback for BUILD_CAFFE2=0 for ONNX-only ops (pytorch#88504)
Follow-up for pytorch#87735

Once again, because BUILD_CAFFE2=0 is not tested for the ONNX exporter, one scenario slipped through: a use case in which the model can be exported without ATen fallback when operator_export_type=ONNX_ATEN_FALLBACK and BUILD_CAFFE2=0.

A new unit test has been added, but it won't prevent regressions unless BUILD_CAFFE2=0 builds are exercised on CI again.

Fixes pytorch#87313

Pull Request resolved: pytorch#88504
Approved by: https://github.com/justinchuby, https://github.com/BowenBao

(cherry picked from commit 5f0783b)
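
For illustration, a minimal sketch of the scenario this commit covers, modeled on the new unit test in the diff below. The model and its names are hypothetical; the assumption is a PyTorch wheel built with BUILD_CAFFE2=0. A model made only of ONNX-exportable ops should export with no ATen fallback nodes even under ONNX_ATEN_FALLBACK:

import io

import onnx
import torch


# Hypothetical toy model: both add and relu have ONNX symbolics,
# so nothing should fall back to ATen.
class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)


f = io.BytesIO()
torch.onnx.export(
    AddRelu(),
    (torch.rand(3, 4), torch.rand(3, 4)),
    f,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
# Before this fix, a BUILD_CAFFE2=0 build could take the fallback path anyway.
assert not any(
    node.op_type == "ATen" and node.domain == "org.pytorch.aten"
    for node in onnx_model.graph.node
)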
thiagocrepaldi authored and izaitsevfb committed Dec 3, 2022
1 parent 642edcd commit 93b3892
Showing 2 changed files with 156 additions and 35 deletions.
172 changes: 141 additions & 31 deletions test/onnx/test_pytorch_onnx_no_runtime.py
@@ -15,7 +15,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.onnx import symbolic_helper, utils
+from torch.onnx import OperatorExportTypes, symbolic_helper, utils
 from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import registration
 from torch.testing._internal import common_utils
@@ -781,50 +781,60 @@ def forward(self, x):
         )
 
     @common_utils.skipIfNoCaffe2
-    def test_caffe2_aten_fallback(self):
+    def test_caffe2_aten_fallback_must_fallback(self):
         class ModelWithAtenNotONNXOp(torch.nn.Module):
             def forward(self, x, y):
                 abcd = x + y
                 defg = torch.linalg.qr(abcd)
                 return defg
 
-        x = torch.rand(3, 4)
-        y = torch.rand(3, 4)
-        f = io.BytesIO()
-        torch.onnx.export(
-            ModelWithAtenNotONNXOp(),
-            (x, y),
-            f,
-            do_constant_folding=False,
-            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
-            # support for linalg.qr was added in later op set versions.
-            opset_version=9,
-        )
-        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
-        self.assertAtenOp(onnx_model, "linalg_qr")
+        # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
+        for operator_export_type in (
+            OperatorExportTypes.ONNX_ATEN,
+            OperatorExportTypes.ONNX_ATEN_FALLBACK,
+        ):
+            x = torch.rand(3, 4)
+            y = torch.rand(3, 4)
+            f = io.BytesIO()
+            torch.onnx.export(
+                ModelWithAtenNotONNXOp(),
+                (x, y),
+                f,
+                do_constant_folding=False,
+                operator_export_type=operator_export_type,
+                # support for linalg.qr was added in later op set versions.
+                opset_version=9,
+            )
+            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+            self.assertAtenOp(onnx_model, "linalg_qr")
 
     @common_utils.skipIfNoCaffe2
-    def test_caffe2_onnx_aten(self):
+    def test_caffe2_onnx_aten_must_not_fallback(self):
         class ModelWithAtenFmod(torch.nn.Module):
             def forward(self, x, y):
                 return torch.fmod(x, y)
 
-        x = torch.randn(3, 4, dtype=torch.float32)
-        y = torch.randn(3, 4, dtype=torch.float32)
-        f = io.BytesIO()
-        torch.onnx.export(
-            ModelWithAtenFmod(),
-            (x, y),
-            f,
-            do_constant_folding=False,
-            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
-            opset_version=10,  # or higher
-        )
-        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
-        assert onnx_model.graph.node[0].op_type == "Mod"
+        # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
+        for operator_export_type in (
+            OperatorExportTypes.ONNX_ATEN_FALLBACK,
+            OperatorExportTypes.ONNX_ATEN,
+        ):
+            x = torch.randn(3, 4, dtype=torch.float32)
+            y = torch.randn(3, 4, dtype=torch.float32)
+            f = io.BytesIO()
+            torch.onnx.export(
+                ModelWithAtenFmod(),
+                (x, y),
+                f,
+                do_constant_folding=False,
+                operator_export_type=operator_export_type,
+                opset_version=10,  # or higher
+            )
+            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+            assert onnx_model.graph.node[0].op_type == "Mod"
 
     @common_utils.skipIfCaffe2
-    def test_aten_fallback(self):
+    def test_aten_fallback_must_fallback(self):
         class ModelWithAtenNotONNXOp(torch.nn.Module):
             def forward(self, x, y):
                 abcd = x + y
@@ -865,6 +875,106 @@ def forward(self, x, y):
         onnx_model = onnx.load(io.BytesIO(f.getvalue()))
         self.assertAtenOp(onnx_model, "fmod", "Tensor")
 
+    @common_utils.skipIfCaffe2
+    def test_onnx_aten_fallback_must_not_fallback(self):
+        # For BUILD_CAFFE2=0, aten fallback only when not exportable
+        class ONNXExportable(torch.nn.Module):
+            def __init__(self):
+                super(ONNXExportable, self).__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.fc1 = torch.nn.Linear(12, 8)
+                self.fc2 = torch.nn.Linear(8, 4)
+                self.fc3 = torch.nn.Linear(4, 6)
+                self.dequant = torch.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = x.view((-1, 12))
+                h = F.relu(self.fc1(x))
+                h = F.relu(self.fc2(h))
+                h = F.relu(self.fc3(h))
+                h = self.dequant(h)
+                return h
+
+        dummy_input = torch.randn(12)
+        f = io.BytesIO()
+        torch.onnx.export(
+            ONNXExportable(),
+            (dummy_input,),
+            f,
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+        )
+        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+        all_aten_nodes = [
+            p
+            for p in onnx_model.graph.node
+            if p.op_type == "ATen" and p.domain == "org.pytorch.aten"
+        ]
+        self.assertEqual(len(all_aten_nodes), 0)
+
+
+class TestQuantizeEagerONNXExport(common_utils.TestCase):
+    def _test_lower_graph_impl(self, model, data):
+        model.qconfig = torch.ao.quantization.default_qconfig
+        model = torch.ao.quantization.prepare(model)
+        model = torch.ao.quantization.convert(model)
+
+        _ = model(data)
+        input_names = ["x"]
+
+        def _export_to_onnx(model, input, input_names):
+            traced = torch.jit.trace(model, input)
+            buf = io.BytesIO()
+            torch.jit.save(traced, buf)
+            buf.seek(0)
+
+            model = torch.jit.load(buf)
+            f = io.BytesIO()
+            torch.onnx.export(
+                model,
+                input,
+                f,
+                input_names=input_names,
+                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                opset_version=9,
+            )
+
+        _export_to_onnx(model, data, input_names)
+
+    @common_quantization.skipIfNoFBGEMM
+    @common_utils.skipIfNoCaffe2
+    def test_lower_graph_linear(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Linear(5, 10, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_quantization.skipIfNoFBGEMM
+    @common_utils.skipIfNoCaffe2
+    def test_lower_graph_conv2d(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Conv2d(3, 5, 2, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_quantization.skipIfNoFBGEMM
+    @unittest.skip(
+        "onnx opset9 does not support quantize_per_tensor and caffe2 \
+does not support conv3d"
+    )
+    def test_lower_graph_conv3d(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Conv3d(3, 5, 2, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+
 if __name__ == "__main__":
     common_utils.run_tests()
19 changes: 15 additions & 4 deletions torch/onnx/utils.py
@@ -1752,10 +1752,21 @@ def _should_aten_fallback(
     )
     is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK
 
-    return name.startswith("aten::") and (
-        ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
-        or (not is_exportable_aten_op and is_aten_fallback_export)
-    )
+    if not name.startswith("aten::"):
+        return False
+
+    if is_caffe2_build:
+        if (
+            is_onnx_aten_export or is_aten_fallback_export
+        ) and not is_exportable_aten_op:
+            return True
+    else:
+        if is_onnx_aten_export or (
+            is_aten_fallback_export and not is_exportable_aten_op
+        ):
+            return True
+
+    return False
 
 
 @_beartype.beartype
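Read on its own, the new decision logic in _should_aten_fallback can be sketched as a pure function (a paraphrase of the diff above, with the flags passed in explicitly instead of being derived from the export settings):

def should_aten_fallback(
    name: str,
    is_onnx_aten_export: bool,  # operator_export_type == ONNX_ATEN
    is_aten_fallback_export: bool,  # operator_export_type == ONNX_ATEN_FALLBACK
    is_caffe2_build: bool,  # _C_onnx._CAFFE2_ATEN_FALLBACK
    is_exportable_aten_op: bool,  # an ONNX symbolic is registered for the op
) -> bool:
    # Only aten:: ops are ever candidates for fallback.
    if not name.startswith("aten::"):
        return False
    if is_caffe2_build:
        # Caffe2 builds fall back only when no ONNX symbolic exists,
        # for both ONNX_ATEN and ONNX_ATEN_FALLBACK.
        return (
            is_onnx_aten_export or is_aten_fallback_export
        ) and not is_exportable_aten_op
    # BUILD_CAFFE2=0: ONNX_ATEN always falls back, while ONNX_ATEN_FALLBACK
    # falls back only for ops without an ONNX symbolic.
    return is_onnx_aten_export or (
        is_aten_fallback_export and not is_exportable_aten_op
    )

This matches the tests above: on a Caffe2 build, fmod (which has an ONNX symbolic) exports as a standard Mod node even under ONNX_ATEN, while on a BUILD_CAFFE2=0 build the fully ONNX-exportable quantized model produces no ATen nodes under ONNX_ATEN_FALLBACK.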
