From 75242065592441537b54bad96204096c028d517f Mon Sep 17 00:00:00 2001
From: peterjc123
Date: Tue, 15 Nov 2022 06:36:24 +0000
Subject: [PATCH] Fix get_default_qat_qconfig for PT 1.13 (#88876)

See https://github.com/pytorch/pytorch/pull/84329/files#r1019916766 for more context

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88876
Approved by: https://github.com/jgong5, https://github.com/vkuzo
---
 test/quantization/core/test_top_level_apis.py | 32 +++++++++++++++++++
 torch/ao/quantization/qconfig.py              |  2 +-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/test/quantization/core/test_top_level_apis.py b/test/quantization/core/test_top_level_apis.py
index 7343a16040d25b0..f76db1cd4139b72 100644
--- a/test/quantization/core/test_top_level_apis.py
+++ b/test/quantization/core/test_top_level_apis.py
@@ -59,3 +59,35 @@ def test_fake_quants(self) -> None:
         for observer in self.fake_quants:
             obs = self._get_observer_ins(observer)
             obs.forward(t)
+
+
+class TestQConfig(TestCase):
+
+    REDUCE_RANGE_DICT = {
+        'fbgemm': (True, False),
+        'qnnpack': (False, False),
+        'onednn': (False, False),
+        'x86': (True, False),
+    }
+
+    def test_reduce_range_qat(self) -> None:
+        for backend, reduce_ranges in self.REDUCE_RANGE_DICT.items():
+            for version in range(2):
+                qconfig = torch.ao.quantization.get_default_qat_qconfig(backend, version)
+
+                fake_quantize_activ = qconfig.activation()
+                self.assertEqual(fake_quantize_activ.activation_post_process.reduce_range, reduce_ranges[0])
+
+                fake_quantize_weight = qconfig.weight()
+                self.assertEqual(fake_quantize_weight.activation_post_process.reduce_range, reduce_ranges[1])
+
+    def test_reduce_range(self) -> None:
+        for backend, reduce_ranges in self.REDUCE_RANGE_DICT.items():
+            for version in range(1):
+                qconfig = torch.ao.quantization.get_default_qconfig(backend, version)
+
+                fake_quantize_activ = qconfig.activation()
+                self.assertEqual(fake_quantize_activ.reduce_range, reduce_ranges[0])
+
+                fake_quantize_weight = qconfig.weight()
+                self.assertEqual(fake_quantize_weight.reduce_range, reduce_ranges[1])
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index b75e16ef044f47f..f52bf713c6f9b5b 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -339,7 +339,7 @@ def get_default_qat_qconfig(backend='fbgemm', version=1):
                                                                quant_min=0,
                                                                quant_max=255),
                           weight=default_per_channel_weight_fake_quant)
-    if backend == 'x86':
+    elif backend == 'x86':
         qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                             quant_min=0,
                                                             quant_max=255,