[Quant][Fx] Fix issue: qconfig_mappings of onednn backend are not correctly set for fused modules (#91297)

**Summary**
For onednn quantization backend only.
Currently, FX fusion requires that all separate ops in a fused module/op have the same `qconfig`. To support `linear - leaky_relu` and `linear - tanh` fusion with the onednn backend, we previously set the same `qconfig` explicitly on `linear`, `leaky_relu` and `tanh`. However, this causes two problems:
- It breaks fusion of `linear - relu`, since `relu` does not have the same `qconfig` as `linear`. Moreover, setting an explicit `qconfig` on all these ops is undesirable; they should use the global `qconfig` by default.
- `Tanh` requires `fixed_qparams_qconfig`, otherwise it is not quantized, so we cannot assign a different `qconfig` to `tanh` (see the sketch below).
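
For illustration, a minimal inspection sketch (not part of this PR) of how the default onednn mapping resolves these ops after the change, assuming `QConfigMapping` exposes the `object_type_qconfigs` and `global_qconfig` attributes:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping

mapping = get_default_qconfig_mapping('onednn')

# Linear and LeakyReLU no longer carry an explicit per-op qconfig, so they fall
# back to the global qconfig; Tanh keeps its entry from the fixed-qparams table,
# which is why it cannot simply be given another qconfig.
print(mapping.object_type_qconfigs.get(torch.nn.Linear))  # None -> global qconfig applies
print(mapping.object_type_qconfigs.get(torch.nn.Tanh))    # fixed-qparams qconfig
print(mapping.global_qconfig)
```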

There does not appear to be a straightforward way to solve both problems at once, so this PR addresses them as follows:
- Do not set `qconfig` on these ops, so that they use the global `qconfig` and both `linear - relu` and `linear - leaky_relu` are fused correctly.
- Require users to set the same `qconfig` on `linear` and `tanh` manually when they want to fuse `linear - tanh` with the onednn backend (a usage sketch follows below).

A known issue still exists: users cannot fuse `linear - tanh` and quantize standalone `tanh` at the same time.
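
A minimal usage sketch of that manual workaround, mirroring the updated test in this PR; `LinearTanh` is a hypothetical example model, not part of the change:

```python
import torch
from torch.ao.quantization import get_default_qconfig, get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

class LinearTanh(torch.nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        return self.tanh(self.linear(x))

torch.backends.quantized.engine = 'onednn'  # requires a build with the onednn engine
m = LinearTanh().eval()

# Give Linear and Tanh the same qconfig on top of the default onednn mapping so
# that FX fusion can merge linear - tanh; standalone tanh then loses its
# fixed-qparams qconfig (the known issue above).
qconfig = get_default_qconfig('onednn')
qconfig_mapping = get_default_qconfig_mapping('onednn') \
    .set_object_type(torch.nn.Linear, qconfig) \
    .set_object_type(torch.nn.Tanh, qconfig)

prepared = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1, 16),))
```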

**Test plan**
`python test/test_quantization.py -k test_qconfig_dict_with_fused_modules`

Pull Request resolved: #91297
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Xia-Weiwen authored and pytorchmergebot committed Jan 26, 2023
1 parent 913866e commit 1d03a6a
Showing 2 changed files with 15 additions and 10 deletions.
14 changes: 13 additions & 1 deletion test/quantization/fx/test_quantize_fx.py
@@ -1912,6 +1912,7 @@ def forward(self, x):
         self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
 
+    @override_qengines
     def test_qconfig_dict_with_fused_modules(self):
         class LinearReLUModel(torch.nn.Module):
             def __init__(self, relu):
@@ -1951,7 +1952,8 @@ def forward(self, x):
         for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
             for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
                 m = model(relu).eval()
-                qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
+                qengine = torch.backends.quantized.engine
+                qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping(qengine)
                 # should not crash as in https://github.com/pytorch/pytorch/issues/75825
                 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
 
@@ -5796,6 +5798,16 @@ def test_linear_tanh_lowering(self):
         """
         from torch.ao.quantization.backend_config import get_onednn_backend_config
         qconfig_mapping = get_default_qconfig_mapping('onednn')
+        # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
+        # Need to be able to support fusion of ops with different qconfigs
+        # Since tanh must have 'fixed_qparams_qconfig' while linear should use
+        # the global qconfig, we need to set qconfigs for them manually here for
+        # fusion and cannot put such configs in onednn's default qconfig_mapping.
+        # Known issue:
+        # Cannot fuse linear - tanh and quantize standalone tanh at the same time.
+        qconfig = get_default_qconfig('onednn')
+        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig)
+        qconfig_mapping.set_object_type(torch.nn.Tanh, qconfig)
         with override_quantized_engine('onednn'):
             m = LinearTanhModel()
             self._test_linear_activation_fusion_lowering_helper(
11 changes: 2 additions & 9 deletions torch/ao/quantization/qconfig_mapping.py
@@ -106,15 +106,8 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC
         fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
         qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)
 
-    # QConfig for fused ops for onednn backend
-    # Separate ops are required to have the same qconfig as fused ops
-    # TODO: we should be able to configure qconfig for patterns
-    if backend == 'onednn':
-        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig) \
-            .set_object_type(torch.nn.LeakyReLU, qconfig) \
-            .set_object_type(torch.nn.functional.leaky_relu, qconfig) \
-            .set_object_type(torch.nn.Tanh, qconfig) \
-            .set_object_type(torch.nn.functional.tanh, qconfig)
+    # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
+    # Need to be able to support fusion of ops with different qconfigs
 
     return qconfig_mapping

