[Inductor][Quant] Fix PT2E Dynamic Quant regression (#125207)
**Summary**
Fix two regression issues caused by the previous refactor:

- Fix the issue in the dequant promotion pass with dynamic quantization when the dequant node uses the `tensor` overload.
- Fix a numerical issue in dynamic quantization: with the previous implementation, the input was converted to the scales' dtype (`double`) before performing the quantization operation (see the sketch below).
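
As a rough illustration of the second point (plain `torch` below, not the actual `quantized_decomposed` ops; the scale value and the int8 clamp range are made-up assumptions): dynamic quantization computes the scale at runtime as a `float64` tensor, and promoting the activation to that dtype before quantizing can round differently than doing the arithmetic in the activation's own dtype. This is also why the `tensor` overload matters for the first point: dynamic quantization carries scale and zero point as tensors rather than Python scalars, so the dequant promotion pass has to match `dequantize_per_tensor.tensor` in addition to `dequantize_per_tensor.default`.

```
import torch

x = torch.randn(1024, dtype=torch.float32)          # fp32 activation
scale = torch.tensor(0.0173, dtype=torch.float64)   # runtime scale (double); value is hypothetical
zero_point = torch.tensor(0, dtype=torch.int64)

# Previous behavior (sketched): promote the input to the scale's dtype (double) first.
q_double = torch.clamp(
    torch.round(x.to(scale.dtype) / scale) + zero_point, -128, 127
).to(torch.int8)

# Fixed behavior (sketched): keep the arithmetic in the input's own dtype.
q_float = torch.clamp(
    torch.round(x / scale.to(x.dtype)) + zero_point, -128, 127
).to(torch.int8)

# Values near a rounding boundary can disagree by one quantized step.
print((q_double != q_float).sum().item())
```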

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_input_dim_exceeds_2
clear && python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_qlinear_dequant_promotion_dynamic_cpu
```
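
The second test covers dequant promotion with dynamic quantization. As a rough, self-contained sketch of what "promotion" means (plain `torch.fx` with a stand-in multiply in place of a real dequantize node; this is not the actual Inductor pass): when one dequant node feeds several linear consumers, it is cloned so that each dequant-linear pair can be pattern-matched and weight-prepacked independently.

```
import operator
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        d = x * 0.1                      # stand-in for a dequantize shared by two users
        return self.l1(d) + self.l2(d)

gm = fx.symbolic_trace(M())
graph = gm.graph
for node in list(graph.nodes):
    users = list(node.users)
    # "Promote" the shared stand-in dequant: give every extra user its own clone.
    if node.op == "call_function" and node.target is operator.mul and len(users) > 1:
        for user in users[1:]:
            with graph.inserting_before(user):
                clone = graph.node_copy(node)
            user.replace_input_with(node, clone)
graph.lint()
gm.recompile()
print(gm.graph)  # the multiply now appears once per linear consumer
```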

Pull Request resolved: #125207
Approved by: https://github.com/peterbell10, https://github.com/jgong5
ghstack dependencies: #124041, #124246
leslie-fang-intel authored and pytorchmergebot committed May 9, 2024
1 parent d474d79 commit 3da949b
Showing 2 changed files with 75 additions and 12 deletions.
58 changes: 55 additions & 3 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -1385,6 +1385,18 @@ def test_dynamic_qlinear_qat_cpu(self):
(torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_dynamic_qlinear_input_dim_exceeds_2(self):
r"""
This testcase will quantize a single Linear module.
"""
for bias in [True, False]:
self._qlinear_cpu_test_helper(
(torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@@ -1577,7 +1589,13 @@ def test_qlinear_gelu_int8_mixed_bf16(self):
(torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
)

def _qlinear_dequant_promotion_cpu_test_helper(self, inputs, int8_mixed_bf16=False):
def _qlinear_dequant_promotion_cpu_test_helper(
self,
inputs,
int8_mixed_bf16=False,
is_dynamic=False,
matcher_check_fn=None,
):
class M(torch.nn.Module):
def __init__(
self,
@@ -1595,7 +1613,7 @@ def forward(self, x):

mod = M().eval()

def matcher_check_fn():
def default_matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
# 2. dequant-linear pattern matched in quantization weight prepack * 3
@@ -1610,7 +1628,10 @@ def matcher_check_fn():
inputs,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
matcher_check_fn=matcher_check_fn
if matcher_check_fn is not None
else default_matcher_check_fn,
is_dynamic=is_dynamic,
)

@skipIfNoDynamoSupport
@@ -1693,6 +1714,37 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
(torch.randn((2, 3, 4)),), int8_mixed_bf16=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qlinear_dequant_promotion_dynamic_cpu(self):
r"""
This testcase tests if the dequant node before linear is promoted correctly:
X
|
Linear1(X)
/ \
Linear2(X) Linear3(X)
\ /
Add
|
Y
"""

def matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
# 2. dequant-linear pattern matched in quantization weight prepack * 3
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
)

self._qlinear_dequant_promotion_cpu_test_helper(
(torch.randn((2, 4)),),
matcher_check_fn=matcher_check_fn,
is_dynamic=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
29 changes: 20 additions & 9 deletions torch/_inductor/fx_passes/quantization.py
@@ -1246,6 +1246,7 @@ def _inner(match):
dequant_pattern_end_node = match.output_node()
if dequant_pattern_end_node.target not in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
prims.convert_element_type.default,
aten.reshape.default,
]:
@@ -1271,7 +1272,11 @@ def _inner(match):
)

if (
dequant_node.target is quantized_decomposed.dequantize_per_tensor.default
dequant_node.target
in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]
and len(list(dequant_pattern_end_node.users)) > 1
):
# If dequant pattern has more than 1 users, then do dequant promoted
@@ -1336,6 +1341,7 @@ def clone_to_new_node(graph, source_node, user_node):
dequant_pattern_end_node = match.output_node()
assert dequant_pattern_end_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
prims.convert_element_type.default,
aten.reshape.default,
]
@@ -1345,7 +1351,10 @@ def clone_to_new_node(graph, source_node, user_node):
# * OPT(prims.convert_element_type.default) (to_bf16)
# * dequantize_per_tensor
def _find_first_node_in_dequant_pattern(_node):
if _node.target is quantized_decomposed.dequantize_per_tensor.default:
if _node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]:
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
return _node
else:
@@ -1358,10 +1367,10 @@ def _find_first_node_in_dequant_pattern(_node):
dequant_pattern_end_node
)

assert (
dequant_pattern_start_node.target
is quantized_decomposed.dequantize_per_tensor.default
)
assert dequant_pattern_start_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]

# Clone the dequant pattern for each user node
graph = match.graph
@@ -2010,9 +2019,9 @@ def _generate_qlinear_weight_prepack_patterns(

def _register_dequant_promotion():
dequant_pattern_cases = itertools.product(
[torch.float32, torch.bfloat16], [True, False]
[torch.float32, torch.bfloat16], [True, False], [True, False]
)
for dtype, input_dim_exceeds_two in dequant_pattern_cases:
for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
# 4 dequantization patterns will be matched based on the dtype and input dimension size.
# Case 1: int8-mixed-fp32, input dim size is 2
# Case 2: int8-mixed-fp32, input dim size exceeds 2
@@ -2036,7 +2045,9 @@ def _register_dequant_promotion():
_register_dequant_promotion_pass(
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(),
get_dequantize_per_tensor_activation_pattern(
is_tensor_overload=is_tensor_overload
),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),

1 comment on commit 3da949b

@pytorchmergebot

Reverted #124041 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think there is a land race with the change https://hud.pytorch.org/pytorch/pytorch/commit/33e6791645b5950b0f39301f55b8a4a79c0ca847 (comment)
