Skip to content

Commit

Permalink
Add type hints for the resolver module (#1316)
Browse files Browse the repository at this point in the history
Additionally refactored `lookup_table()`:
- removed unused functionality
- made function type-safe
- moved doctests to tests/test_utils.py and add more tests

Co-authored-by: Sviatoslav Sydorenko <wk.cvs.github@sydorenko.org.ua>
Co-authored-by: Jon Dufresne <jon.dufresne@gmail.com>
  • Loading branch information
3 people committed Feb 9, 2021
1 parent 6629a8e commit f60fb00
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 108 deletions.
10 changes: 7 additions & 3 deletions piptools/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import platform
import sys
from typing import Dict, List, Optional, Tuple, cast
from typing import Dict, Iterable, List, Optional, Set, Tuple, cast

from pip._internal.req import InstallRequirement
from pip._vendor.packaging.requirements import Requirement
Expand Down Expand Up @@ -129,7 +129,9 @@ def __setitem__(self, ireq: InstallRequirement, values: List[str]) -> None:
self.cache[pkgname][pkgversion_and_extras] = values
self.write_cache()

def reverse_dependencies(self, ireqs):
def reverse_dependencies(
self, ireqs: Iterable[InstallRequirement]
) -> Dict[str, Set[str]]:
"""
Returns a lookup table of reverse dependencies for all the given ireqs.
Expand All @@ -141,7 +143,9 @@ def reverse_dependencies(self, ireqs):
ireqs_as_cache_values = [self.as_cache_key(ireq) for ireq in ireqs]
return self._reverse_dependencies(ireqs_as_cache_values)

def _reverse_dependencies(self, cache_keys):
def _reverse_dependencies(
self, cache_keys: Iterable[Tuple[str, str]]
) -> Dict[str, Set[str]]:
"""
Returns a lookup table of reverse dependencies for all the given cache keys.
Expand Down
13 changes: 7 additions & 6 deletions piptools/repositories/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from typing import Iterator, Set
from typing import Iterator, Optional, Set

from pip._internal.req import InstallRequirement
from pip._vendor.packaging.version import Version


class BaseRepository(metaclass=ABCMeta):
Expand All @@ -16,10 +15,12 @@ def freshen_build_caches(self) -> Iterator[None]:
"""Should start with fresh build/source caches."""

@abstractmethod
def find_best_match(self, ireq: InstallRequirement) -> Version:
def find_best_match(
self, ireq: InstallRequirement, prereleases: Optional[bool]
) -> InstallRequirement:
"""
Return a Version object that indicates the best match for the given
InstallRequirement according to the repository.
Returns a pinned InstallRequirement object that indicates the best match
for the given InstallRequirement according to the external repository.
"""

@abstractmethod
Expand All @@ -33,7 +34,7 @@ def get_dependencies(self, ireq: InstallRequirement) -> Set[InstallRequirement]:
@abstractmethod
def get_hashes(self, ireq: InstallRequirement) -> Set[str]:
"""
Given a pinned InstallRequire, returns a set of hashes that represent
Given a pinned InstallRequirement, returns a set of hashes that represent
all of the files for a given requirement. It is not acceptable for an
editable or unpinned requirement to be passed to this function.
"""
Expand Down
4 changes: 2 additions & 2 deletions piptools/repositories/pypi.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def find_all_candidates(self, req_name):

def find_best_match(self, ireq, prereleases=None):
"""
Returns a Version object that indicates the best match for the given
InstallRequirement according to the external repository.
Returns a pinned InstallRequirement object that indicates the best match
for the given InstallRequirement according to the external repository.
"""
if ireq.editable or is_url_requirement(ireq):
return ireq # return itself as the best match
Expand Down
59 changes: 37 additions & 22 deletions piptools/resolver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import copy
from functools import partial
from itertools import chain, count, groupby
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple

import click
from pip._internal.req import InstallRequirement
from pip._internal.req.constructors import install_req_from_line
from pip._internal.req.req_tracker import update_env_context_manager

from piptools.cache import DependencyCache
from piptools.repositories.base import BaseRepository

from .logging import log
from .utils import (
UNSAFE_PACKAGES,
Expand All @@ -26,15 +30,16 @@ class RequirementSummary:
Summary of a requirement's properties for comparison purposes.
"""

def __init__(self, ireq: InstallRequirement):
def __init__(self, ireq: InstallRequirement) -> None:
self.req = ireq.req
self.key = key_from_ireq(ireq)
self.extras = frozenset(ireq.extras)
self.specifier = ireq.specifier

def __eq__(self, other: object) -> bool:
if not isinstance(other, RequirementSummary):
if not isinstance(other, self.__class__):
return NotImplemented

return (
self.key == other.key
and self.specifier == other.specifier
Expand All @@ -48,7 +53,9 @@ def __str__(self) -> str:
return repr((self.key, str(self.specifier), sorted(self.extras)))


def combine_install_requirements(repository, ireqs):
def combine_install_requirements(
repository: BaseRepository, ireqs: Iterable[InstallRequirement]
) -> InstallRequirement:
"""
Return a single install requirement that reflects a combination of
all the inputs.
Expand Down Expand Up @@ -103,34 +110,36 @@ def combine_install_requirements(repository, ireqs):
class Resolver:
def __init__(
self,
constraints,
repository,
cache,
prereleases=False,
clear_caches=False,
allow_unsafe=False,
):
constraints: Iterable[InstallRequirement],
repository: BaseRepository,
cache: DependencyCache,
prereleases: Optional[bool] = False,
clear_caches: bool = False,
allow_unsafe: bool = False,
) -> None:
"""
This class resolves a given set of constraints (a collection of
InstallRequirement objects) by consulting the given Repository and the
DependencyCache.
"""
self.our_constraints = set(constraints)
self.their_constraints = set()
self.their_constraints: Set[InstallRequirement] = set()
self.repository = repository
self.dependency_cache = cache
self.prereleases = prereleases
self.clear_caches = clear_caches
self.allow_unsafe = allow_unsafe
self.unsafe_constraints = set()
self.unsafe_constraints: Set[InstallRequirement] = set()

@property
def constraints(self):
def constraints(self) -> Set[InstallRequirement]:
return set(
self._group_constraints(chain(self.our_constraints, self.their_constraints))
)

def resolve_hashes(self, ireqs):
def resolve_hashes(
self, ireqs: Set[InstallRequirement]
) -> Dict[InstallRequirement, Set[str]]:
"""
Finds acceptable hashes for all of the given InstallRequirements.
"""
Expand All @@ -139,7 +148,7 @@ def resolve_hashes(self, ireqs):
with self.repository.allow_all_wheels(), log.indentation():
return {ireq: self.repository.get_hashes(ireq) for ireq in ireqs}

def resolve(self, max_rounds=10):
def resolve(self, max_rounds: int = 10) -> Set[InstallRequirement]:
"""
Finds concrete package versions for all the given InstallRequirements
and their recursive dependencies. The end result is a flat list of
Expand Down Expand Up @@ -196,7 +205,7 @@ def resolve(self, max_rounds=10):
# sense for installation tools) so this seems sufficient.
reverse_dependencies = self.reverse_dependencies(results)
for req in results.copy():
required_by = reverse_dependencies.get(req.name.lower(), [])
required_by = reverse_dependencies.get(req.name.lower(), set())
if req.name in UNSAFE_PACKAGES or (
required_by and all(name in UNSAFE_PACKAGES for name in required_by)
):
Expand All @@ -205,7 +214,9 @@ def resolve(self, max_rounds=10):

return results

def _group_constraints(self, constraints):
def _group_constraints(
self, constraints: Iterable[InstallRequirement]
) -> Iterator[InstallRequirement]:
"""
Groups constraints (remember, InstallRequirements!) by their key name,
and combining their SpecifierSets into a single InstallRequirement per
Expand Down Expand Up @@ -238,7 +249,7 @@ def _group_constraints(self, constraints):
):
yield combine_install_requirements(self.repository, ireqs)

def _resolve_one_round(self):
def _resolve_one_round(self) -> Tuple[bool, Set[InstallRequirement]]:
"""
Resolves one level of the current constraints, by finding the best
match for each package in the repository and adding all requirements
Expand Down Expand Up @@ -266,7 +277,7 @@ def _resolve_one_round(self):
log.debug("")
log.debug("Finding secondary dependencies:")

their_constraints = []
their_constraints: List[InstallRequirement] = []
with log.indentation():
for best_match in best_matches:
their_constraints.extend(self._iter_dependencies(best_match))
Expand Down Expand Up @@ -298,7 +309,7 @@ def _resolve_one_round(self):
self.their_constraints = theirs
return has_changed, best_matches

def get_best_match(self, ireq):
def get_best_match(self, ireq: InstallRequirement) -> InstallRequirement:
"""
Returns a (pinned or editable) InstallRequirement, indicating the best
match to use for the given InstallRequirement (in the form of an
Expand Down Expand Up @@ -341,7 +352,9 @@ def get_best_match(self, ireq):
best_match._source_ireqs = ireq._source_ireqs
return best_match

def _iter_dependencies(self, ireq):
def _iter_dependencies(
self, ireq: InstallRequirement
) -> Iterator[InstallRequirement]:
"""
Given a pinned, url, or editable InstallRequirement, collects all the
secondary dependencies for them, either by looking them up in a local
Expand Down Expand Up @@ -392,7 +405,9 @@ def _iter_dependencies(self, ireq):
dependency_string, constraint=ireq.constraint, comes_from=ireq
)

def reverse_dependencies(self, ireqs):
def reverse_dependencies(
self, ireqs: Iterable[InstallRequirement]
) -> Dict[str, Set[str]]:
non_editable = [
ireq for ireq in ireqs if not (ireq.editable or is_url_requirement(ireq))
]
Expand Down

0 comments on commit f60fb00

Please sign in to comment.