Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compatible version specifier incorrectly strip trailing '0' #493

Merged
merged 7 commits into from Jan 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion packaging/specifiers.py
Expand Up @@ -119,7 +119,11 @@ def __str__(self) -> str:

@property
def _canonical_spec(self) -> Tuple[str, str]:
return self._spec[0], canonicalize_version(self._spec[1])
canonical_version = canonicalize_version(
self._spec[1],
strip_trailing_zero=(self._spec[0] != "~="),
)
return self._spec[0], canonical_version

def __hash__(self) -> int:
return hash(self._canonical_spec)
Expand Down
11 changes: 8 additions & 3 deletions packaging/utils.py
Expand Up @@ -35,7 +35,9 @@ def canonicalize_name(name: str) -> NormalizedName:
return cast(NormalizedName, value)


def canonicalize_version(version: Union[Version, str]) -> str:
def canonicalize_version(
version: Union[Version, str], *, strip_trailing_zero: bool = True
) -> str:
"""
This is very similar to Version.__str__, but has one subtle difference
with the way it handles the release segment.
Expand All @@ -56,8 +58,11 @@ def canonicalize_version(version: Union[Version, str]) -> str:
parts.append(f"{parsed.epoch}!")

# Release segment
# NB: This strips trailing '.0's to normalize
parts.append(re.sub(r"(\.0)+$", "", ".".join(str(x) for x in parsed.release)))
release_segment = ".".join(str(x) for x in parsed.release)
if strip_trailing_zero:
# NB: This strips trailing '.0's to normalize
release_segment = re.sub(r"(\.0)+$", "", release_segment)
parts.append(release_segment)

# Pre-release
if parsed.pre is not None:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_specifiers.py
Expand Up @@ -630,6 +630,12 @@ def test_iteration(self, spec, expected_items):
items = {str(item) for item in spec}
assert items == set(expected_items)

def test_specifier_equal_for_compatible_operator(self):
assert Specifier("~=1.18.0") != Specifier("~=1.18")
kasium marked this conversation as resolved.
Show resolved Hide resolved

def test_specifier_hash_for_compatible_operator(self):
assert hash(Specifier("~=1.18.0")) != hash(Specifier("~=1.18"))


class TestLegacySpecifier:
def test_legacy_specifier_is_deprecated(self):
Expand Down Expand Up @@ -996,3 +1002,8 @@ def test_comparison_non_specifier(self):
)
def test_comparison_ignores_local(self, version, specifier, expected):
assert (Version(version) in SpecifierSet(specifier)) == expected

def test_contains_with_compatible_operator(self):
combination = SpecifierSet("~=1.18.0") & SpecifierSet("~=1.18")
assert "1.19.5" not in combination
assert "1.18.0" in combination
5 changes: 5 additions & 0 deletions tests/test_utils.py
Expand Up @@ -56,6 +56,11 @@ def test_canonicalize_version(version, expected):
assert canonicalize_version(version) == expected


@pytest.mark.parametrize(("version"), ["1.4.0", "1.0"])
def test_canonicalize_version_no_strip_trailing_zero(version):
assert canonicalize_version(version, strip_trailing_zero=False) == version


@pytest.mark.parametrize(
("filename", "name", "version", "build", "tags"),
[
Expand Down