diff --git a/packaging/specifiers.py b/packaging/specifiers.py index 6e5d4d59..a2d51b04 100644 --- a/packaging/specifiers.py +++ b/packaging/specifiers.py @@ -57,7 +57,8 @@ def __eq__(self, other: object) -> bool: objects are equal. """ - @abc.abstractproperty + @property + @abc.abstractmethod def prereleases(self) -> Optional[bool]: """ Returns whether or not pre-releases as a whole are allowed by this @@ -724,7 +725,10 @@ def __contains__(self, item: UnparsedVersion) -> bool: return self.contains(item) def contains( - self, item: UnparsedVersion, prereleases: Optional[bool] = None + self, + item: UnparsedVersion, + prereleases: Optional[bool] = None, + installed: Optional[bool] = None, ) -> bool: # Ensure that our item is a Version or LegacyVersion instance. @@ -746,6 +750,9 @@ def contains( if not prereleases and item.is_prerelease: return False + if installed and item.is_prerelease: + item = parse(item.base_version) + # We simply dispatch to the underlying specs here to make sure that the # given version is contained within all of them. # Note: This use of all() here means that an empty set of specifiers diff --git a/tests/test_specifiers.py b/tests/test_specifiers.py index 41a3f829..5949ebf6 100644 --- a/tests/test_specifiers.py +++ b/tests/test_specifiers.py @@ -819,6 +819,15 @@ def test_specifier_contains_prereleases(self): assert spec.contains("1.0.dev1") assert not spec.contains("1.0.dev1", prereleases=False) + def test_specifier_contains_installed_prereleases(self): + spec = SpecifierSet("~=1.0") + assert not spec.contains("1.0.0.dev1", installed=True) + assert spec.contains("1.0.0.dev1", prereleases=True, installed=True) + + spec = SpecifierSet("~=1.0", prereleases=True) + assert spec.contains("1.0.0.dev1", installed=True) + assert not spec.contains("1.0.0.dev1", prereleases=False, installed=False) + @pytest.mark.parametrize( ("specifier", "specifier_prereleases", "prereleases", "input", "expected"), [