[ao][ns] Replacing List[QConfigMapping] in PNP (pytorch#86922)
Summary: Added QConfigMultiMapping, which is essentially a
List[QConfigMapping] with set methods and dedicated handling to
avoid unwanted matches and improve UX.
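For context, a minimal sketch of the new container as exercised by the tests added in this diff; the specific qconfigs and the "fc2" module name are illustrative, not part of the commit:

import torch
import torch.ao.quantization  # for default_qconfig / default_dynamic_qconfig
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping

# set methods take lists of qconfigs instead of a single qconfig
multi = (
    QConfigMultiMapping()
    .set_global([
        torch.ao.quantization.default_qconfig,
        torch.ao.quantization.default_dynamic_qconfig,
    ])
    # a None entry means "no qconfig at this position", which is how
    # unwanted matches are avoided ("fc2" is an illustrative module name)
    .set_module_name("fc2", [None, torch.ao.quantization.default_qconfig])
)

# duplicate qconfigs are deduplicated on insertion
dedup = QConfigMultiMapping().set_global(
    [torch.ao.quantization.default_qconfig, torch.ao.quantization.default_qconfig]
)
assert len(dedup.qconfig_mappings_list) == 1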

Note: the from __future__ import annotations line caused unexpected errors when the
QConfigMultiMapping class was placed in _numeric_suite_fx.py, so the class was
moved into torch/ao/ns/fx/qconfig_multi_mapping.py.
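And a sketch of how a caller passes the new type to prepare_n_shadows_model after this change, assuming a PyTorch build that includes it; the TwoLinear model, its shapes, and the calibration call are illustrative stand-ins for the test setup, not part of this diff:

import torch
import torch.ao.quantization
import torch.nn as nn
from torch.ao.ns._numeric_suite_fx import prepare_n_shadows_model
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from torch.ao.quantization.backend_config import get_native_backend_config

class TwoLinear(nn.Module):
    # illustrative stand-in for the TwoLayerLinearModel used in the tests
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))

m = TwoLinear().eval()
example_input = (torch.randn(2, 2),)

# previously a List[QConfigMapping] was passed here; now it is a QConfigMultiMapping
qconfig_multi_mapping = QConfigMultiMapping().set_global(
    [torch.ao.quantization.default_qconfig, torch.ao.quantization.default_dynamic_qconfig]
)
msq = prepare_n_shadows_model(
    m, example_input, qconfig_multi_mapping, get_native_backend_config()
)
msq(*example_input)  # calibration pass over the prepared (shadowed) model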

Test Plan: python test/test_quantization.py TestFxNumericSuiteNShadows

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#86922
Approved by: https://github.com/vkuzo
HDCharles authored and kulinseth committed Dec 9, 2022
1 parent 9f0ad2d commit 7455967
Showing 3 changed files with 452 additions and 37 deletions.
229 changes: 203 additions & 26 deletions test/quantization/fx/test_numeric_suite_fx.py
@@ -31,6 +31,7 @@
LSTMwithHiddenDynamicModel,
SparseNNModel,
skip_if_no_torchvision,
TwoLayerLinearModel
)
from torch.ao.quantization.quantization_mappings import (
get_default_static_quant_module_mappings,
@@ -82,6 +83,7 @@
loggers_set_enabled,
loggers_set_save_activations,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers

@@ -2096,6 +2098,7 @@ def _test_impl(self, m, example_input, qconfig_mappings):

results = extract_results_n_shadows_model(msq)
print_comparisons_n_shadows_model(results)
return msq

def test_linear_mod(self):
class M(nn.Module):
@@ -2110,9 +2113,8 @@ def forward(self, x):
m = M().eval()
example_input = (torch.randn(2, 2),)

qconfig_mappings = [
QConfigMapping().set_global(torch.quantization.default_qconfig),
]
qconfig_mappings = \
QConfigMultiMapping().set_global([torch.quantization.default_qconfig])
self._test_impl(m, example_input, qconfig_mappings)

def test_linear_relu_mod(self):
@@ -2132,10 +2134,12 @@ def forward(self, x):
m = M().eval()
example_input = (torch.randn(2, 2),)

qconfig_mappings = [
QConfigMapping().set_global(torch.quantization.default_qconfig),
QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig),
]
qconfig_mappings = (
QConfigMultiMapping().set_global([
torch.quantization.default_qconfig,
torch.quantization.default_dynamic_qconfig
])
)
self._test_impl(m, example_input, qconfig_mappings)

def test_conv_bn_relu_mod(self):
@@ -2154,10 +2158,12 @@ def forward(self, x):

m = M().eval()
example_input = (torch.randn(32, 1, 16, 16),)
qconfig_mappings = [
QConfigMapping().set_global(torch.quantization.default_qconfig),
QConfigMapping().set_global(torch.quantization.default_per_channel_qconfig),
]

qconfig_mappings = QConfigMultiMapping() \
.set_global([
torch.quantization.default_qconfig,
torch.quantization.default_per_channel_qconfig
])
self._test_impl(m, example_input, qconfig_mappings)

def test_functions(self):
@@ -2194,10 +2200,8 @@ def forward(self, x):
m = M().eval()
example_input = (torch.randn(2, 2),)

qconfig_mappings = [
QConfigMapping().set_global(torch.quantization.default_qconfig),
# QConfigMapping().set_global(torch.quantization.default_per_channel_qconfig),
]
qconfig_mappings = QConfigMultiMapping() \
.set_global([torch.quantization.default_qconfig])
self._test_impl(m, example_input, qconfig_mappings)

def test_partial_qconfig_mapping(self):
@@ -2220,19 +2224,17 @@ def forward(self, x):
example_input = (torch.randn(2, 2),)
qconfig = torch.ao.quantization.default_qconfig

qconfig_mappings = [
QConfigMapping().set_global(None)
.set_object_type(F.linear, qconfig)
.set_object_type(F.relu, qconfig),
]
qconfig_mappings = QConfigMultiMapping() \
.set_object_type(F.linear, [qconfig]) \
.set_object_type(F.relu, [qconfig])
self._test_impl(m, example_input, qconfig_mappings)

def test_logger_enabled_and_save_activations_flags(self):
m = nn.Sequential(nn.Linear(1, 1)).eval()
example_input = (torch.randn(1, 1),)
qconfig_mappings = [
QConfigMapping().set_global(torch.quantization.default_qconfig),
]

qconfig_mappings = QConfigMultiMapping() \
.set_global([torch.quantization.default_qconfig])
backend_config = get_native_backend_config()

msp = prepare_n_shadows_model(
@@ -2281,12 +2283,187 @@ def test_mobilenet_v2(self):
pretrained=False, quantize=False).eval()
example_input = (torch.randn(1, 3, 224, 224),)

qconfig_mappings = [
qconfig_mappings = QConfigMultiMapping() \
.set_global([torch.quantization.default_qconfig, torch.quantization.default_dynamic_qconfig])

self._test_impl(m, example_input, qconfig_mappings)

def test_qconfig_multi_mapping_deduplication(self):
# check that insertion deduplicates qconfigs
qconfig_multi_mapping = QConfigMultiMapping().set_global(
[torch.quantization.default_qconfig, torch.quantization.default_qconfig]
)
self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 1)

def test_qconfig_multi_mapping_insert_padding(self):
# test that inserting a higher priority qconfig style with fewer elements than a lower priority qconfig will
# result in adding None to the extra QConfigMappings at that same style+key
qconfig_multi_mapping = (
QConfigMultiMapping()
.set_global(
[
torch.quantization.default_qconfig,
torch.quantization.default_dynamic_qconfig,
]
)
.set_object_type(torch.nn.Linear, [torch.quantization.default_qconfig])
.set_module_name_regex("fc", [torch.quantization.default_qconfig])
.set_module_name("fc2", [torch.quantization.default_qconfig])
.set_module_name_object_type_order(
"", nn.Linear, 0, [torch.quantization.default_qconfig]
)
)

self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
torch.nn.Linear
],
None,
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
"fc"
],
None,
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
None,
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[
1
].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
None,
)

def test_qconfig_multi_mapping_retroactive_padding(self):
# test that inserting a lower priority qconfig style with more elements than the
# existing higher priority qconfig styles will result in the new QConfigMapping
# having None at all previously existing styles+keys
qconfig_multi_mapping = (
QConfigMultiMapping()
.set_object_type(torch.nn.Linear, [torch.quantization.default_qconfig])
.set_module_name_regex("fc", [torch.quantization.default_qconfig])
.set_module_name("fc2", [torch.quantization.default_qconfig])
.set_module_name_object_type_order(
"", nn.Linear, 0, [torch.quantization.default_qconfig]
)
.set_global(
[
torch.quantization.default_qconfig,
torch.quantization.default_dynamic_qconfig,
]
)
)

self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
torch.nn.Linear
],
None,
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
"fc"
],
None,
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
None,
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[
1
].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
None,
)

def test_qconfig_multi_mapping_end_to_end(self):
# test that the prepare/convert_n_shadows_model works as expected
# with qconfig_multi_mapping and avoids unwanted matches

m = TwoLayerLinearModel().eval()
example_input = m.get_example_inputs()

qconfig_multi_mapping = (
QConfigMultiMapping()
.set_global(
[
torch.quantization.default_qconfig,
torch.quantization.default_dynamic_qconfig,
]
)
.set_module_name("fc2", [None, torch.quantization.default_qconfig])
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
None,
)
msq = self._test_impl(m, example_input, qconfig_multi_mapping)

self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0)
self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2)

def test_qconfig_multi_mapping_from_list(self):
# test QConfigMultiMapping.from_list_qconfig_mapping works as expected

m = TwoLayerLinearModel().eval()
example_input = m.get_example_inputs()

qconfig_mappings_list = [
QConfigMapping().set_global(torch.quantization.default_qconfig),
QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig),
QConfigMapping()
.set_global(torch.quantization.default_dynamic_qconfig)
.set_module_name("fc2", torch.quantization.default_qconfig),
]

qconfig_multi_mapping = QConfigMultiMapping().from_list_qconfig_mapping(
qconfig_mappings_list
)
self.assertEqual(
qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
None,
)

msq = self._test_impl(m, example_input, qconfig_multi_mapping)

self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0)
self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2)

def test_qconfig_multi_mapping_ordering(self):
# test that the module ordering ignores None

m = TwoLayerLinearModel().eval()
example_input = m.get_example_inputs()
qconfig_multi_mapping = (
QConfigMultiMapping()
.set_global(
[
torch.ao.quantization.default_qconfig,
torch.ao.quantization.default_dynamic_qconfig,
]
)
.set_module_name(
"fc2",
[
None,
torch.ao.quantization.default_dynamic_qconfig,
torch.ao.quantization.default_qat_qconfig_v2,
],
)
)
self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 2)
msq = self._test_impl(m, example_input, qconfig_multi_mapping)

self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
self.checkDynamicQuantizedLinear(msq.shadow_wrapper_1_1.mod_0, torch.qint8)
self.checkQuantizedLinear(msq.shadow_wrapper_1_2.mod_0)

class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
"""
18 changes: 7 additions & 11 deletions torch/ao/ns/_numeric_suite_fx.py
@@ -119,10 +119,6 @@
NSResultsType,
NSNodeTargetType,
)

from torch.ao.quantization import (
QConfigMapping,
)
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
@@ -138,6 +134,7 @@
print_n_shadows_summary,
handle_subgraph,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping

from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type

@@ -753,7 +750,7 @@ def extend_logger_results_with_comparison(
def prepare_n_shadows_model(
model: torch.nn.Module,
example_inputs: Any,
qconfig_mappings: List[QConfigMapping],
qconfig_multi_mapping: QConfigMultiMapping,
backend_config: BackendConfig,
) -> torch.nn.Module:
"""
@@ -770,9 +767,9 @@ def prepare_n_shadows_model(
args_kwargs_m -> op_m -> output_m
| |
|---------------------------> mod_with_op_m_transformed_with_qconfig_i
|---------------------------> mod_with_op_m_transformed_with_qconfig_n
Where mod_with_op_m_transformed_with_qconfig_i is a submodule, and its
Where mod_with_op_m_transformed_with_qconfig_n is a submodule, and its
inner graph looks like
.. code::
@@ -790,8 +787,7 @@
1. add deduplication for qconfigs per subgraph
2. figure out a better way to name the output structure
3. return a results data structure instead of printing it out
4. make specifying sets of QConfigMapping more user friendly
5. add examples to docblocks
4. add examples to docblocks
"""

tracer = quantize_fx.QuantizationTracer([], [])
@@ -822,7 +818,7 @@
# generate node to qconfig for each subgraph
# TODO(future PR): deduplicate repeating entries
list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
for qconfig_mapping in qconfig_mappings:
for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
node_name_to_qconfig = generate_node_name_to_qconfig(
mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
list_of_node_name_to_qconfig.append(node_name_to_qconfig)
@@ -838,7 +834,7 @@
enumerate(subgraphs_dedup.items()):
handle_subgraph(
mt, subgraph_idx, match_name, nodes_in_this_subgraph,
qconfig_mappings, list_of_node_name_to_qconfig)
qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig)

mt.recompile()
return mt
