Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Exclude Package with custom unsafe packages #1509

Merged
merged 19 commits into from Jul 17, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions piptools/_compat/pip_compat.py
Expand Up @@ -3,18 +3,18 @@

import pip
from pip._internal.index.package_finder import PackageFinder
from pip._internal.network.session import PipSession
from pip._internal.req import InstallRequirement
from pip._internal.req import parse_requirements as _parse_requirements
from pip._internal.req.constructors import install_req_from_parsed_requirement
from pip._vendor.packaging.version import parse as parse_version
from pip._vendor.requests.sessions import Session

PIP_VERSION = tuple(map(int, parse_version(pip.__version__).base_version.split(".")))


def parse_requirements(
filename: str,
session: PipSession,
ssbarnea marked this conversation as resolved.
Show resolved Hide resolved
session: Session,
finder: Optional[PackageFinder] = None,
options: Optional[optparse.Values] = None,
constraint: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions piptools/repositories/local.py
Expand Up @@ -4,9 +4,9 @@

from pip._internal.index.package_finder import PackageFinder
from pip._internal.models.candidate import InstallationCandidate
from pip._internal.network.session import PipSession
from pip._internal.req import InstallRequirement
from pip._internal.utils.hashes import FAVORITE_HASH
from pip._vendor.requests import Session

from piptools.utils import as_tuple, key_from_ireq, make_install_requirement

Expand Down Expand Up @@ -58,7 +58,7 @@ def finder(self) -> PackageFinder:
return self.repository.finder

@property
def session(self) -> Session:
def session(self) -> PipSession:
return self.repository.session

def clear_caches(self) -> None:
Expand Down
31 changes: 21 additions & 10 deletions piptools/resolver.py
Expand Up @@ -113,6 +113,8 @@ def __init__(
prereleases: Optional[bool] = False,
clear_caches: bool = False,
allow_unsafe: bool = False,
unsafe_packages: Optional[Set[str]] = None,
allow_unsafe_recursive: bool = False,
) -> None:
"""
This class resolves a given set of constraints (a collection of
Expand All @@ -127,6 +129,8 @@ def __init__(
self.clear_caches = clear_caches
self.allow_unsafe = allow_unsafe
self.unsafe_constraints: Set[InstallRequirement] = set()
self.unsafe_packages = unsafe_packages or UNSAFE_PACKAGES
self.allow_unsafe_recursive = allow_unsafe_recursive

@property
def constraints(self) -> Set[InstallRequirement]:
Expand Down Expand Up @@ -189,17 +193,24 @@ def resolve(self, max_rounds: int = 10) -> Set[InstallRequirement]:
# Filter out unsafe requirements.
self.unsafe_constraints = set()
if not self.allow_unsafe:
# reverse_dependencies is used to filter out packages that are only
# required by unsafe packages. This logic is incomplete, as it would
# fail to filter sub-sub-dependencies of unsafe packages. None of the
# UNSAFE_PACKAGES currently have any dependencies at all (which makes
# sense for installation tools) so this seems sufficient.
reverse_dependencies = self.reverse_dependencies(results)
if not self.allow_unsafe_recursive:
# reverse_dependencies is used to filter out packages that are only
# required by unsafe packages. This logic is incomplete, as it would
# fail to filter sub-sub-dependencies of unsafe packages. None of the
# UNSAFE_PACKAGES currently have any dependencies at all (which makes
# sense for installation tools) so this seems sufficient.
reverse_dependencies = self.reverse_dependencies(results)

for req in results.copy():
required_by = reverse_dependencies.get(req.name.lower(), set())
if required_by and all(
name in self.unsafe_packages for name in required_by
):
self.unsafe_constraints.add(req)
results.remove(req)

for req in results.copy():
required_by = reverse_dependencies.get(req.name.lower(), set())
if req.name in UNSAFE_PACKAGES or (
required_by and all(name in UNSAFE_PACKAGES for name in required_by)
):
if req.name in self.unsafe_packages:
self.unsafe_constraints.add(req)
results.remove(req)

Expand Down
18 changes: 18 additions & 0 deletions piptools/scripts/compile.py
Expand Up @@ -236,6 +236,20 @@ def _get_default_option(option_name: str) -> Any:
default=True,
help="Add options to generated file",
)
@click.option(
"--unsafe-packages",
multiple=True,
help="List of packages to consider unsafe. Replaces "
ssbarnea marked this conversation as resolved.
Show resolved Hide resolved
f"{', '.join(sorted(UNSAFE_PACKAGES))}; may be used more than once",
)
@click.option(
"--allow-unsafe-recursive/--no-allow-unsafe-recursive",
is_flag=True,
default=False,
help="Determines whether dependencies that solely belong to unsafe packages "
"should be treated as unsafe. Default is false. This only covers direct dependencies. "
"Dependencies of dependencies of unsafe packages will not be marked unsafe.",
ssbarnea marked this conversation as resolved.
Show resolved Hide resolved
)
def cli(
ctx: click.Context,
verbose: int,
Expand Down Expand Up @@ -269,6 +283,8 @@ def cli(
pip_args_str: Optional[str],
emit_index_url: bool,
emit_options: bool,
unsafe_packages: Tuple[str, ...],
allow_unsafe_recursive: bool,
) -> None:
"""Compiles requirements.txt from requirements.in specs."""
log.verbosity = verbose - quiet
Expand Down Expand Up @@ -462,6 +478,8 @@ def cli(
cache=DependencyCache(cache_dir),
clear_caches=rebuild,
allow_unsafe=allow_unsafe,
unsafe_packages=set(unsafe_packages),
ssbarnea marked this conversation as resolved.
Show resolved Hide resolved
allow_unsafe_recursive=allow_unsafe_recursive,
)
results = resolver.resolve(max_rounds=max_rounds)
hashes = resolver.resolve_hashes(results) if generate_hashes else None
Expand Down
72 changes: 72 additions & 0 deletions tests/test_resolver.py
Expand Up @@ -249,6 +249,78 @@ def test_resolver__allows_unsafe_deps(
assert output == {str(line) for line in expected}


@pytest.mark.parametrize(
(
"input",
"expected",
"unsafe_packages",
"allow_unsafe_recursive",
"unsafe_constraints",
),
(
(
["fake-piptools-test-with-pinned-deps==0.1"],
{
"fake-piptools-test-with-pinned-deps==0.1",
"pytz==2016.4 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"billiard==3.3.0.23 (from "
"celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"celery==3.1.18 (from fake-piptools-test-with-pinned-deps==0.1)",
},
{"kombu"},
False,
{
"kombu==3.0.35 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"anyjson==0.3.3 (from "
"kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"amqp==1.4.9 (from "
"kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
},
),
(
["fake-piptools-test-with-pinned-deps==0.1"],
{
"fake-piptools-test-with-pinned-deps==0.1",
"pytz==2016.4 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"billiard==3.3.0.23 (from "
"celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"celery==3.1.18 (from fake-piptools-test-with-pinned-deps==0.1)",
"anyjson==0.3.3 (from "
"kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
"amqp==1.4.9 (from "
"kombu==3.0.35->celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
},
{"kombu"},
True,
{
"kombu==3.0.35 (from celery==3.1.18->fake-piptools-test-with-pinned-deps==0.1)",
},
),
),
)
def test_resolver__custom_unsafe_deps(
resolver,
from_line,
input,
expected,
unsafe_packages,
allow_unsafe_recursive,
unsafe_constraints,
):
input = [line if isinstance(line, tuple) else (line, False) for line in input]
input = [from_line(req[0], constraint=req[1]) for req in input]
resolver = resolver(
input,
unsafe_packages=unsafe_packages,
allow_unsafe_recursive=allow_unsafe_recursive,
)
output = resolver.resolve()
output = {str(line) for line in output}

assert output == expected
assert {str(line) for line in resolver.unsafe_constraints} == unsafe_constraints


def test_resolver__max_number_rounds_reached(resolver, from_line):
"""
Resolver should raise an exception if max round has been reached.
Expand Down