Skip to content

Commit

Permalink
Support Exclude Package with custom unsafe packages (#1509)
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Tugushev <albert@tugushev.ru>
Co-authored-by: Mehdi Drissi <mdrissi@snapchat.com>
Co-authored-by: Sorin Sbarnea <ssbarnea@redhat.com>
  • Loading branch information
4 people committed Jul 17, 2022
1 parent 1e05d00 commit 86aa4bd
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 34 deletions.
4 changes: 2 additions & 2 deletions piptools/repositories/local.py
Expand Up @@ -5,9 +5,9 @@
from pip._internal.commands.install import InstallCommand
from pip._internal.index.package_finder import PackageFinder
from pip._internal.models.candidate import InstallationCandidate
from pip._internal.network.session import PipSession
from pip._internal.req import InstallRequirement
from pip._internal.utils.hashes import FAVORITE_HASH
from pip._vendor.requests import Session

from piptools.utils import as_tuple, key_from_ireq, make_install_requirement

Expand Down Expand Up @@ -59,7 +59,7 @@ def finder(self) -> PackageFinder:
return self.repository.finder

@property
def session(self) -> Session:
def session(self) -> PipSession:
return self.repository.session

@property
Expand Down
30 changes: 9 additions & 21 deletions piptools/resolver.py
Expand Up @@ -5,6 +5,7 @@
from itertools import chain, count, groupby
from typing import (
Any,
Container,
DefaultDict,
Dict,
Iterable,
Expand Down Expand Up @@ -176,22 +177,13 @@ def resolve_hashes(
def _filter_out_unsafe_constraints(
self,
ireqs: Set[InstallRequirement],
reverse_dependencies: Dict[str, Set[str]],
unsafe_packages: Container[str],
) -> None:
"""
Remove from a given set of ``InstallRequirement``'s unsafe constraints.
Reverse_dependencies is used to filter out packages that are only
required by unsafe packages. This logic is incomplete, as it would
fail to filter sub-sub-dependencies of unsafe packages. None of the
UNSAFE_PACKAGES currently have any dependencies at all (which makes
sense for installation tools) so this seems sufficient.
"""
for req in ireqs.copy():
required_by = reverse_dependencies.get(req.name.lower(), set())
if req.name in UNSAFE_PACKAGES or (
required_by and all(name in UNSAFE_PACKAGES for name in required_by)
):
if req.name in unsafe_packages:
self.unsafe_constraints.add(req)
ireqs.remove(req)

Expand All @@ -206,6 +198,7 @@ def __init__(
prereleases: Optional[bool] = False,
clear_caches: bool = False,
allow_unsafe: bool = False,
unsafe_packages: Optional[Set[str]] = None,
) -> None:
"""
This class resolves a given set of constraints (a collection of
Expand All @@ -220,6 +213,7 @@ def __init__(
self.clear_caches = clear_caches
self.allow_unsafe = allow_unsafe
self.unsafe_constraints: Set[InstallRequirement] = set()
self.unsafe_packages = unsafe_packages or UNSAFE_PACKAGES

options = self.repository.options
if "legacy-resolver" not in options.deprecated_features_enabled:
Expand Down Expand Up @@ -281,7 +275,7 @@ def resolve(self, max_rounds: int = 10) -> Set[InstallRequirement]:
if not self.allow_unsafe:
self._filter_out_unsafe_constraints(
ireqs=results,
reverse_dependencies=self.reverse_dependencies(results),
unsafe_packages=self.unsafe_packages,
)

return results
Expand Down Expand Up @@ -490,14 +484,6 @@ def _ireqs_of_dependencies(
dependency_string, constraint=ireq.constraint, comes_from=ireq
)

def reverse_dependencies(
self, ireqs: Iterable[InstallRequirement]
) -> Dict[str, Set[str]]:
non_editable = [
ireq for ireq in ireqs if not (ireq.editable or is_url_requirement(ireq))
]
return self.dependency_cache.reverse_dependencies(non_editable)


class BacktrackingResolver(BaseResolver):
"""A wrapper for backtracking resolver."""
Expand All @@ -508,11 +494,13 @@ def __init__(
existing_constraints: Dict[str, InstallRequirement],
repository: BaseRepository,
allow_unsafe: bool = False,
unsafe_packages: Optional[Set[str]] = None,
**kwargs: Any,
) -> None:
self.constraints = list(constraints)
self.repository = repository
self.allow_unsafe = allow_unsafe
self.unsafe_packages = unsafe_packages or UNSAFE_PACKAGES

options = self.options = self.repository.options
self.session = self.repository.session
Expand Down Expand Up @@ -634,7 +622,7 @@ def resolve(self, max_rounds: int = 10) -> Set[InstallRequirement]:
if not self.allow_unsafe:
self._filter_out_unsafe_constraints(
ireqs=result_ireqs,
reverse_dependencies=reverse_dependencies,
unsafe_packages=self.unsafe_packages,
)

return result_ireqs
Expand Down
8 changes: 8 additions & 0 deletions piptools/scripts/compile.py
Expand Up @@ -245,6 +245,12 @@ def _get_default_option(option_name: str) -> Any:
default=True,
help="Add options to generated file",
)
@click.option(
"--unsafe-package",
multiple=True,
help="Specify a package to consider unsafe; may be used more than once. "
f"Replaces default unsafe packages: {', '.join(sorted(UNSAFE_PACKAGES))}",
)
def cli(
ctx: click.Context,
verbose: int,
Expand Down Expand Up @@ -279,6 +285,7 @@ def cli(
resolver_name: str,
emit_index_url: bool,
emit_options: bool,
unsafe_package: Tuple[str, ...],
) -> None:
"""Compiles requirements.txt from requirements.in specs."""
log.verbosity = verbose - quiet
Expand Down Expand Up @@ -483,6 +490,7 @@ def cli(
cache=DependencyCache(cache_dir),
clear_caches=rebuild,
allow_unsafe=allow_unsafe,
unsafe_packages=set(unsafe_package),
)
results = resolver.resolve(max_rounds=max_rounds)
hashes = resolver.resolve_hashes(results) if generate_hashes else None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli_compile.py
Expand Up @@ -1274,10 +1274,10 @@ def test_annotate_option(pip_conf, runner, options, expected):
"--no-allow-unsafe",
dedent(
"""\
small-fake-a==0.1
small-fake-b==0.3
# The following packages are considered to be unsafe in a requirements file:
# small-fake-a
# small-fake-with-deps
"""
),
Expand All @@ -1287,10 +1287,10 @@ def test_annotate_option(pip_conf, runner, options, expected):
None,
dedent(
"""\
small-fake-a==0.1
small-fake-b==0.3
# The following packages are considered to be unsafe in a requirements file:
# small-fake-a
# small-fake-with-deps
"""
),
Expand Down
65 changes: 56 additions & 9 deletions tests/test_resolver.py
Expand Up @@ -152,21 +152,19 @@
),
(
["fake-piptools-test-with-unsafe-deps==0.1"],
["fake-piptools-test-with-unsafe-deps==0.1"],
[
"appdirs==1.4.9 (from "
"setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)",
"fake-piptools-test-with-unsafe-deps==0.1",
"packaging==16.8 (from "
"setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)",
],
False,
{
(
"setuptools==34.0.0 (from "
"fake-piptools-test-with-unsafe-deps==0.1)"
),
(
"appdirs==1.4.9 (from "
"setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)"
),
(
"packaging==16.8 (from "
"setuptools==34.0.0->fake-piptools-test-with-unsafe-deps==0.1)"
),
},
),
# Git URL requirement
Expand Down Expand Up @@ -250,6 +248,55 @@ def test_resolver__allows_unsafe_deps(
assert output == {str(line) for line in expected}


@pytest.mark.parametrize(
(
"input",
"expected",
"unsafe_packages",
"unsafe_constraints",
),
(
(
["fake-piptools-test-with-pinned-deps==0.1"],
{
"fake-piptools-test-with-pinned-deps==0.1",
"pytz==2016.4 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"billiard==3.3.0.23 (from "
"celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"celery==3.1.18 (from fake-piptools-test-with-pinned-deps==0.1)",
"anyjson==0.3.3 (from "
"kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"amqp==1.4.9 (from "
"kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
},
{"kombu"},
{
"kombu==3.0.35 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
},
),
),
)
def test_resolver__custom_unsafe_deps(
resolver,
from_line,
input,
expected,
unsafe_packages,
unsafe_constraints,
):
input = [line if isinstance(line, tuple) else (line, False) for line in input]
input = [from_line(req[0], constraint=req[1]) for req in input]
resolver = resolver(
input,
unsafe_packages=unsafe_packages,
)
output = resolver.resolve()
output = {str(line) for line in output}

assert output == expected
assert {str(line) for line in resolver.unsafe_constraints} == unsafe_constraints


def test_resolver__max_number_rounds_reached(resolver, from_line):
"""
Resolver should raise an exception if max round has been reached.
Expand Down

0 comments on commit 86aa4bd

Please sign in to comment.