[Quant][FX] Lower QLinearLeakyReLU for onednn backend (#88668)
**Summary**
Add quantization mappings for `QLinearLeakyReLU` for int8 inference with the onednn backend. Fusion and lowering are supported in FX mode only.

**Test plan**
python test_quantization.py TestQuantizeFx
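
For orientation, here is a minimal end-to-end sketch of the path this commit enables, assuming a PyTorch build with the onednn quantized engine; the model class below is a hypothetical stand-in for the `LinearBnLeakyReluModel` helper used in the test:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_onednn_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Hypothetical stand-in for the LinearBnLeakyReluModel helper used by the test.
class LinearLeakyReluModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.leaky_relu = nn.LeakyReLU(0.01)

    def forward(self, x):
        return self.leaky_relu(self.linear(x))

# Requires a build where the onednn quantized engine is available
# (the test guards this with @skipIfNoONEDNN).
torch.backends.quantized.engine = 'onednn'

m = LinearLeakyReluModel().eval()
example_inputs = (torch.randn(4, 16),)

qconfig_mapping = get_default_qconfig_mapping('onednn')
m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs,
               backend_config=get_onednn_backend_config())
m(*example_inputs)  # calibrate observers
m = convert_fx(m, backend_config=get_onednn_backend_config())
# After convert_fx, the Linear + LeakyReLU pair should appear as a single
# quantized fused module (torch.ao.nn.intrinsic.quantized.LinearLeakyReLU).
print(m)
```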

Pull Request resolved: #88668
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Xia-Weiwen authored and pytorchmergebot committed Dec 19, 2022
1 parent 8004f93 commit 9ca41a9
Showing 5 changed files with 47 additions and 6 deletions.
35 changes: 34 additions & 1 deletion test/quantization/fx/test_quantize_fx.py
@@ -9,7 +9,7 @@
 import torch.ao.nn.quantized.reference as nnqr
 import torch.ao.nn.quantized.dynamic as nnqd
 import torch.ao.nn.intrinsic as nni
-import torch.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized as nniq
 import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.multiprocessing as mp
 
@@ -5619,6 +5619,39 @@ def forward(self, x):
         }
         self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
 
+    @skipIfNoONEDNN
+    def test_linear_leaky_relu_lowering(self):
+        """ Test fusion and lowering of Linear - (bn -) LeakyReLU
+            by FX. For onednn backend only.
+        """
+        from torch.ao.quantization.backend_config import get_onednn_backend_config
+        qconfig_mapping = get_default_qconfig_mapping('onednn')
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1,
+            ns.call_module(nniq.LinearLeakyReLU): 1,
+            ns.call_module(nn.Linear): 0,
+            ns.call_module(nn.LeakyReLU): 0,
+        }
+        node_occurrence_ref = {
+            ns.call_function(torch.quantize_per_tensor): 2,
+            ns.call_method("dequantize"): 2,
+        }
+        with override_quantized_engine('onednn'):
+            for with_bn in [True, False]:
+                # test eval mode
+                m = LinearBnLeakyReluModel(with_bn).eval()
+                example_x = m.get_example_inputs()
+                m = prepare_fx(m, qconfig_mapping,
+                               example_inputs=example_x,
+                               backend_config=get_onednn_backend_config())
+                m_copy = copy.deepcopy(m)
+                m = convert_fx(m, backend_config=get_onednn_backend_config())
+                m_ref = convert_to_reference_fx(m_copy)
+
+                self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+                self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
+                m(*example_x)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
6 changes: 4 additions & 2 deletions torch/ao/ns/fx/mappings.py
@@ -7,10 +7,10 @@
 
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.dynamic as nnqd
-import torch.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized as nniq
 import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.ao.nn.intrinsic.qat as nniqat
-import torch.nn.intrinsic as nni
+import torch.ao.nn.intrinsic as nni
 import torch.ao.nn.qat as nnqat
 import torch.ao.nn.qat.dynamic as nnqatd
 from torch.ao.quantization.backend_config import get_native_backend_config
@@ -601,6 +601,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.LinearReLU,
         nniqat.LinearBn1d,
         nniqd.LinearReLU,
+        nni.LinearLeakyReLU,
     ])
 
     MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
@@ -631,6 +632,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniq.ConvReLU2d,
         nniq.ConvReLU3d,
         nniq.LinearReLU,
+        nniq.LinearLeakyReLU,
     ])
 
     MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([
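These sets feed the mapping returned by `get_node_type_to_io_type_map()`, which Numeric Suite uses to tell whether a node consumes and produces fp32 or int8 tensors. A sketch of how the new entries would be classified (the dictionary key names here are assumptions):

```python
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map

io_map = get_node_type_to_io_type_map()
# Key names are assumptions about the returned dictionary's layout.
print(nni.LinearLeakyReLU in io_map["mods_io_type_fp32"])   # fused fp32 module
print(nniq.LinearLeakyReLU in io_map["mods_io_type_int8"])  # quantized fused module
```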
8 changes: 6 additions & 2 deletions torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -3,8 +3,8 @@
 from torch.fx.graph import Graph
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.nn.intrinsic as nni
-import torch.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.quantized as nniq
 import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.dynamic as nnqd
@@ -253,6 +253,10 @@ def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigA
 # 2) The replacement static quantized module class for lowering
 STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
     nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
+    # TODO: LinearLeakyReLU is registered as global but it is only fused and
+    # lowered when onednn's backend config is used. Maybe need to separate
+    # registration and lowering functions for different backends in the future.
+    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
     nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
     nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
     nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
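For context, each entry in `STATIC_LOWER_FUSED_MODULE_MAP` pairs a fused fp32 module type with its inner reference module class and the quantized fused module it is replaced with during lowering. A minimal sketch of how the new entry reads (illustrative only; the real lowering pass in `_lower_to_native_backend.py` does considerably more):

```python
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized.reference as nnqr

# Same shape as the entry added above: fused module -> (reference inner module,
# replacement static quantized module).
static_lower_fused_module_map = {
    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
}

ref_linear_cls, quantized_fused_cls = static_lower_fused_module_map[nni.LinearLeakyReLU]
# ref_linear_cls      -> the reference Linear the fused module is expected to wrap
# quantized_fused_cls -> what the pattern is lowered to for int8 inference
```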
3 changes: 2 additions & 1 deletion torch/ao/quantization/qconfig_mapping.py
@@ -93,7 +93,8 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC
         .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
 
     if backend == 'onednn':
-        qconfig_mapping.set_object_type(torch.nn.LeakyReLU, qconfig) \
+        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig) \
+            .set_object_type(torch.nn.LeakyReLU, qconfig) \
             .set_object_type(torch.nn.functional.leaky_relu, qconfig)
 
     # Use special observers for ops with fixed qparams
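A quick way to inspect the effect of this change (a sketch; `object_type_qconfigs` is assumed to be the dict-like store behind `set_object_type`):

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import get_default_qconfig_mapping

qm = get_default_qconfig_mapping('onednn')
# With this commit, nn.Linear carries an explicit qconfig alongside
# nn.LeakyReLU and F.leaky_relu, so the Linear + LeakyReLU pattern is
# observed and can later be fused and lowered for the onednn backend.
for op in (nn.Linear, nn.LeakyReLU, F.leaky_relu):
    print(op, op in qm.object_type_qconfigs)
```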
1 change: 1 addition & 0 deletions torch/ao/quantization/quantization_mappings.py
@@ -110,6 +110,7 @@
     nni.ConvReLU2d: nniq.ConvReLU2d,
     nni.ConvReLU3d: nniq.ConvReLU3d,
     nni.LinearReLU: nniq.LinearReLU,
+    nni.LinearLeakyReLU: nniq.LinearLeakyReLU,
     nniqat.ConvBn1d: nnq.Conv1d,
     nniqat.ConvBn2d: nnq.Conv2d,
     nniqat.ConvBn3d: nnq.Conv3d,
