diff --git a/piptools/repositories/local.py b/piptools/repositories/local.py index 99d990296..2c873f42d 100644 --- a/piptools/repositories/local.py +++ b/piptools/repositories/local.py @@ -5,9 +5,9 @@ from pip._internal.commands.install import InstallCommand from pip._internal.index.package_finder import PackageFinder from pip._internal.models.candidate import InstallationCandidate +from pip._internal.network.session import PipSession from pip._internal.req import InstallRequirement from pip._internal.utils.hashes import FAVORITE_HASH -from pip._vendor.requests import Session from piptools.utils import as_tuple, key_from_ireq, make_install_requirement @@ -59,7 +59,7 @@ def finder(self) -> PackageFinder: return self.repository.finder @property - def session(self) -> Session: + def session(self) -> PipSession: return self.repository.session @property diff --git a/piptools/resolver.py b/piptools/resolver.py index f9d3d1250..bb85a20e6 100644 --- a/piptools/resolver.py +++ b/piptools/resolver.py @@ -5,6 +5,7 @@ from itertools import chain, count, groupby from typing import ( Any, + Container, DefaultDict, Dict, Iterable, @@ -176,22 +177,13 @@ def resolve_hashes( def _filter_out_unsafe_constraints( self, ireqs: Set[InstallRequirement], - reverse_dependencies: Dict[str, Set[str]], + unsafe_packages: Container[str], ) -> None: """ Remove from a given set of ``InstallRequirement``'s unsafe constraints. - - Reverse_dependencies is used to filter out packages that are only - required by unsafe packages. This logic is incomplete, as it would - fail to filter sub-sub-dependencies of unsafe packages. None of the - UNSAFE_PACKAGES currently have any dependencies at all (which makes - sense for installation tools) so this seems sufficient. """ for req in ireqs.copy(): - required_by = reverse_dependencies.get(req.name.lower(), set()) - if req.name in UNSAFE_PACKAGES or ( - required_by and all(name in UNSAFE_PACKAGES for name in required_by) - ): + if req.name in unsafe_packages: self.unsafe_constraints.add(req) ireqs.remove(req) @@ -206,6 +198,7 @@ def __init__( prereleases: Optional[bool] = False, clear_caches: bool = False, allow_unsafe: bool = False, + unsafe_packages: Optional[Set[str]] = None, ) -> None: """ This class resolves a given set of constraints (a collection of @@ -220,6 +213,7 @@ def __init__( self.clear_caches = clear_caches self.allow_unsafe = allow_unsafe self.unsafe_constraints: Set[InstallRequirement] = set() + self.unsafe_packages = unsafe_packages or UNSAFE_PACKAGES options = self.repository.options if "legacy-resolver" not in options.deprecated_features_enabled: @@ -281,7 +275,7 @@ def resolve(self, max_rounds: int = 10) -> Set[InstallRequirement]: if not self.allow_unsafe: self._filter_out_unsafe_constraints( ireqs=results, - reverse_dependencies=self.reverse_dependencies(results), + unsafe_packages=self.unsafe_packages, ) return results @@ -490,14 +484,6 @@ def _ireqs_of_dependencies( dependency_string, constraint=ireq.constraint, comes_from=ireq ) - def reverse_dependencies( - self, ireqs: Iterable[InstallRequirement] - ) -> Dict[str, Set[str]]: - non_editable = [ - ireq for ireq in ireqs if not (ireq.editable or is_url_requirement(ireq)) - ] - return self.dependency_cache.reverse_dependencies(non_editable) - class BacktrackingResolver(BaseResolver): """A wrapper for backtracking resolver.""" @@ -508,11 +494,13 @@ def __init__( existing_constraints: Dict[str, InstallRequirement], repository: BaseRepository, allow_unsafe: bool = False, + unsafe_packages: Optional[Set[str]] = None, **kwargs: Any, ) -> None: self.constraints = list(constraints) self.repository = repository self.allow_unsafe = allow_unsafe + self.unsafe_packages = unsafe_packages or UNSAFE_PACKAGES options = self.options = self.repository.options self.session = self.repository.session @@ -634,7 +622,7 @@ def resolve(self, max_rounds: int = 10) -> Set[InstallRequirement]: if not self.allow_unsafe: self._filter_out_unsafe_constraints( ireqs=result_ireqs, - reverse_dependencies=reverse_dependencies, + unsafe_packages=self.unsafe_packages, ) return result_ireqs diff --git a/piptools/scripts/compile.py b/piptools/scripts/compile.py index 0c98c6b2d..28e16abf2 100755 --- a/piptools/scripts/compile.py +++ b/piptools/scripts/compile.py @@ -245,6 +245,12 @@ def _get_default_option(option_name: str) -> Any: default=True, help="Add options to generated file", ) +@click.option( + "--unsafe-package", + multiple=True, + help="Specify a package to consider unsafe; may be used more than once. " + f"Replaces default unsafe packages: {', '.join(sorted(UNSAFE_PACKAGES))}", +) def cli( ctx: click.Context, verbose: int, @@ -279,6 +285,7 @@ def cli( resolver_name: str, emit_index_url: bool, emit_options: bool, + unsafe_package: Tuple[str, ...], ) -> None: """Compiles requirements.txt from requirements.in specs.""" log.verbosity = verbose - quiet @@ -483,6 +490,7 @@ def cli( cache=DependencyCache(cache_dir), clear_caches=rebuild, allow_unsafe=allow_unsafe, + unsafe_packages=set(unsafe_package), ) results = resolver.resolve(max_rounds=max_rounds) hashes = resolver.resolve_hashes(results) if generate_hashes else None diff --git a/tests/test_cli_compile.py b/tests/test_cli_compile.py index 359ce3f1a..fc2f76c71 100644 --- a/tests/test_cli_compile.py +++ b/tests/test_cli_compile.py @@ -1274,10 +1274,10 @@ def test_annotate_option(pip_conf, runner, options, expected): "--no-allow-unsafe", dedent( """\ + small-fake-a==0.1 small-fake-b==0.3 # The following packages are considered to be unsafe in a requirements file: - # small-fake-a # small-fake-with-deps """ ), @@ -1287,10 +1287,10 @@ def test_annotate_option(pip_conf, runner, options, expected): None, dedent( """\ + small-fake-a==0.1 small-fake-b==0.3 # The following packages are considered to be unsafe in a requirements file: - # small-fake-a # small-fake-with-deps """ ), diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 0e9b2fef1..bdb949d99 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -152,21 +152,19 @@ ), ( ["fake-piptools-test-with-unsafe-deps==0.1"], - ["fake-piptools-test-with-unsafe-deps==0.1"], + [ + "appdirs==1.4.9 (from " + "setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)", + "fake-piptools-test-with-unsafe-deps==0.1", + "packaging==16.8 (from " + "setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)", + ], False, { ( "setuptools==34.0.0 (from " "fake-piptools-test-with-unsafe-deps==0.1)" ), - ( - "appdirs==1.4.9 (from " - "setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)" - ), - ( - "packaging==16.8 (from " - "setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)" - ), }, ), # Git URL requirement @@ -250,6 +248,55 @@ def test_resolver__allows_unsafe_deps( assert output == {str(line) for line in expected} +@pytest.mark.parametrize( + ( + "input", + "expected", + "unsafe_packages", + "unsafe_constraints", + ), + ( + ( + ["fake-piptools-test-with-pinned-deps==0.1"], + { + "fake-piptools-test-with-pinned-deps==0.1", + "pytz==2016.4 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)", + "billiard==3.3.0.23 (from " + "celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)", + "celery==3.1.18 (from fake-piptools-test-with-pinned-deps==0.1)", + "anyjson==0.3.3 (from " + "kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)", + "amqp==1.4.9 (from " + "kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)", + }, + {"kombu"}, + { + "kombu==3.0.35 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)", + }, + ), + ), +) +def test_resolver__custom_unsafe_deps( + resolver, + from_line, + input, + expected, + unsafe_packages, + unsafe_constraints, +): + input = [line if isinstance(line, tuple) else (line, False) for line in input] + input = [from_line(req[0], constraint=req[1]) for req in input] + resolver = resolver( + input, + unsafe_packages=unsafe_packages, + ) + output = resolver.resolve() + output = {str(line) for line in output} + + assert output == expected + assert {str(line) for line in resolver.unsafe_constraints} == unsafe_constraints + + def test_resolver__max_number_rounds_reached(resolver, from_line): """ Resolver should raise an exception if max round has been reached.