diff --git a/tests/test_importhook.py b/tests/test_importhook.py index bbafbb57..f193882b 100644 --- a/tests/test_importhook.py +++ b/tests/test_importhook.py @@ -6,7 +6,7 @@ import pytest -from typeguard.importhook import install_import_hook +from typeguard.importhook import TypeguardFinder, install_import_hook this_dir = Path(__file__).parent dummy_module_path = this_dir / 'dummymodule.py' @@ -99,3 +99,20 @@ def test_inner_class_classmethod(dummymodule): def test_inner_class_staticmethod(dummymodule): retval = dummymodule.Outer.create_inner_staticmethod() assert retval.__class__.__qualname__ == 'Outer.Inner' + + +def test_package_name_matching(): + """ + The path finder only matches configured (sub)packages. + """ + packages = ["ham", "spam.eggs"] + dummy_original_pathfinder = None + finder = TypeguardFinder(packages, dummy_original_pathfinder) + + assert finder.should_instrument("ham") + assert finder.should_instrument("ham.eggs") + assert finder.should_instrument("spam.eggs") + + assert not finder.should_instrument("spam") + assert not finder.should_instrument("ha") + assert not finder.should_instrument("spam_eggs") diff --git a/typeguard/importhook.py b/typeguard/importhook.py index 20a1a623..3e5ce872 100644 --- a/typeguard/importhook.py +++ b/typeguard/importhook.py @@ -1,5 +1,4 @@ import ast -import re import sys from importlib.machinery import SourceFileLoader from importlib.abc import MetaPathFinder @@ -94,7 +93,7 @@ class TypeguardFinder(MetaPathFinder): """ def __init__(self, packages, original_pathfinder): - self._package_exprs = [re.compile(r'^%s\.?' % pkg) for pkg in packages] + self.packages = packages self._original_pathfinder = original_pathfinder def find_spec(self, fullname, path=None, target=None): @@ -113,8 +112,8 @@ def should_instrument(self, module_name: str) -> bool: :param module_name: full name of the module that is about to be imported (e.g. ``xyz.abc``) """ - for expr in self._package_exprs: - if expr.match(module_name): + for package in self.packages: + if module_name == package or module_name.startswith(package + '.'): return True return False