Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints for the writer module #1310

Merged
merged 1 commit into from
Feb 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .bandit
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[bandit]
exclude: tests,.tox,.eggs,.venv,.git
skips: B101
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ repos:
rev: 1.7.0
hooks:
- id: bandit
args: [--ini, .bandit]
exclude: ^tests/
7 changes: 4 additions & 3 deletions piptools/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import logging
import sys
from typing import Any

import click

Expand All @@ -17,7 +18,7 @@ def __init__(self, verbosity: int = 0, indent_width: int = 2):
self.current_indent = 0
self._indent_width = indent_width

def log(self, message, *args, **kwargs):
def log(self, message: str, *args: Any, **kwargs: Any) -> None:
kwargs.setdefault("err", True)
prefix = " " * self.current_indent
click.secho(prefix + message, *args, **kwargs)
Expand All @@ -26,11 +27,11 @@ def debug(self, *args, **kwargs):
if self.verbosity >= 1:
self.log(*args, **kwargs)

def info(self, *args, **kwargs):
def info(self, *args: Any, **kwargs: Any) -> None:
if self.verbosity >= 0:
self.log(*args, **kwargs)

def warning(self, *args, **kwargs):
def warning(self, *args: Any, **kwargs: Any) -> None:
kwargs.setdefault("fg", "yellow")
self.log(*args, **kwargs)

Expand Down
27 changes: 18 additions & 9 deletions piptools/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import shlex
from collections import OrderedDict
from itertools import chain
from typing import Any, Iterable, Iterator, Optional, Set

from click import style
import click
from click.utils import LazyFile
from pip._internal.req import InstallRequirement
from pip._internal.req.constructors import install_req_from_line
from pip._internal.utils.misc import redact_auth_from_url
from pip._internal.vcs import is_url
from pip._vendor.packaging.markers import Marker

UNSAFE_PACKAGES = {"setuptools", "distribute", "pip"}
COMPILE_EXCLUDE_OPTIONS = {
Expand All @@ -21,29 +24,29 @@
}


def key_from_ireq(ireq):
def key_from_ireq(ireq: InstallRequirement) -> str:
"""Get a standardized key for an InstallRequirement."""
if ireq.req is None and ireq.link is not None:
return str(ireq.link)
else:
return key_from_req(ireq.req)


def key_from_req(req):
def key_from_req(req: InstallRequirement) -> str:
"""Get an all-lowercase version of the requirement's name."""
if hasattr(req, "key"):
# from pkg_resources, such as installed dists for pip-sync
key = req.key
else:
# from packaging, such as install requirements from requirements.txt
key = req.name

assert isinstance(key, str)
key = key.replace("_", "-").lower()
return key


def comment(text: str) -> str:
return style(text, fg="green")
return click.style(text, fg="green")


def make_install_requirement(name, version, extras, constraint=False):
Expand All @@ -58,15 +61,19 @@ def make_install_requirement(name, version, extras, constraint=False):
)


def is_url_requirement(ireq):
def is_url_requirement(ireq: InstallRequirement) -> bool:
"""
Return True if requirement was specified as a path or URL.
ireq.original_link will have been set by InstallRequirement.__init__
"""
return bool(ireq.original_link)


def format_requirement(ireq, marker=None, hashes=None):
def format_requirement(
ireq: InstallRequirement,
marker: Optional[Marker] = None,
hashes: Optional[Set[str]] = None,
) -> str:
"""
Generic formatter for pretty printing InstallRequirements to the terminal
in a less verbose way than using its `__str__` method.
Expand Down Expand Up @@ -223,7 +230,7 @@ def keyval(v):
return dict(lut)


def dedup(iterable):
def dedup(iterable: Iterable[Any]) -> Iterator[Any]:
"""Deduplicate an iterable object like iter(set(iterable)) but
order-preserved.
"""
Expand Down Expand Up @@ -253,7 +260,7 @@ def get_hashes_from_ireq(ireq):
return result


def get_compile_command(click_ctx):
def get_compile_command(click_ctx: click.Context) -> str:
"""
Returns a normalized compile command depending on cli context.

Expand Down Expand Up @@ -285,6 +292,8 @@ def get_compile_command(click_ctx):
right_args.extend([shlex.quote(val) for val in value])
continue

assert isinstance(option, click.Option)

# Get the latest option name (usually it'll be a long name)
option_long_name = option.opts[-1]

Expand Down
94 changes: 58 additions & 36 deletions piptools/writer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import re
from itertools import chain
from typing import BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, Tuple

from click import unstyle
from click.core import Context
from pip._internal.models.format_control import FormatControl
from pip._internal.req.req_install import InstallRequirement
from pip._vendor.packaging.markers import Marker

from .logging import log
from .utils import (
Expand Down Expand Up @@ -39,7 +44,7 @@
strip_comes_from_line_re = re.compile(r" \(line \d+\)$")


def _comes_from_as_string(ireq):
def _comes_from_as_string(ireq: InstallRequirement) -> str:
if isinstance(ireq.comes_from, str):
return strip_comes_from_line_re.sub("", ireq.comes_from)
return key_from_ireq(ireq.comes_from)
Expand All @@ -48,22 +53,22 @@ def _comes_from_as_string(ireq):
class OutputWriter:
def __init__(
self,
dst_file,
click_ctx,
dry_run,
emit_header,
emit_index_url,
emit_trusted_host,
annotate,
generate_hashes,
default_index_url,
index_urls,
trusted_hosts,
format_control,
allow_unsafe,
find_links,
emit_find_links,
):
dst_file: BinaryIO,
click_ctx: Context,
dry_run: bool,
emit_header: bool,
emit_index_url: bool,
emit_trusted_host: bool,
annotate: bool,
generate_hashes: bool,
default_index_url: str,
index_urls: Sequence[str],
trusted_hosts: Sequence[str],
format_control: FormatControl,
allow_unsafe: bool,
find_links: List[str],
emit_find_links: bool,
) -> None:
self.dst_file = dst_file
self.click_ctx = click_ctx
self.dry_run = dry_run
Expand All @@ -80,10 +85,10 @@ def __init__(
self.find_links = find_links
self.emit_find_links = emit_find_links

def _sort_key(self, ireq):
def _sort_key(self, ireq: InstallRequirement) -> Tuple[bool, str]:
return (not ireq.editable, str(ireq.req).lower())

def write_header(self):
def write_header(self) -> Iterator[str]:
if self.emit_header:
yield comment("#")
yield comment("# This file is autogenerated by pip-compile")
Expand All @@ -95,31 +100,31 @@ def write_header(self):
yield comment(f"# {compile_command}")
yield comment("#")

def write_index_options(self):
def write_index_options(self) -> Iterator[str]:
if self.emit_index_url:
for index, index_url in enumerate(dedup(self.index_urls)):
if index_url.rstrip("/") == self.default_index_url:
continue
flag = "--index-url" if index == 0 else "--extra-index-url"
yield f"{flag} {index_url}"

def write_trusted_hosts(self):
def write_trusted_hosts(self) -> Iterator[str]:
if self.emit_trusted_host:
for trusted_host in dedup(self.trusted_hosts):
yield f"--trusted-host {trusted_host}"

def write_format_controls(self):
def write_format_controls(self) -> Iterator[str]:
for nb in dedup(sorted(self.format_control.no_binary)):
yield f"--no-binary {nb}"
for ob in dedup(sorted(self.format_control.only_binary)):
yield f"--only-binary {ob}"

def write_find_links(self):
def write_find_links(self) -> Iterator[str]:
if self.emit_find_links:
for find_link in dedup(self.find_links):
yield f"--find-links {find_link}"

def write_flags(self):
def write_flags(self) -> Iterator[str]:
emitted = False
for line in chain(
self.write_index_options(),
Expand All @@ -132,9 +137,15 @@ def write_flags(self):
if emitted:
yield ""

def _iter_lines(self, results, unsafe_requirements=None, markers=None, hashes=None):
def _iter_lines(
self,
results: Set[InstallRequirement],
unsafe_requirements: Optional[Set[InstallRequirement]] = None,
markers: Optional[Dict[str, Marker]] = None,
hashes: Optional[Dict[InstallRequirement, Set[str]]] = None,
) -> Iterator[str]:
# default values
unsafe_requirements = unsafe_requirements or []
unsafe_requirements = unsafe_requirements or set()
markers = markers or {}
hashes = hashes or {}

Expand All @@ -160,8 +171,7 @@ def _iter_lines(self, results, unsafe_requirements=None, markers=None, hashes=No
packages = {r for r in results if r.name not in UNSAFE_PACKAGES}

if packages:
packages = sorted(packages, key=self._sort_key)
for ireq in packages:
for ireq in sorted(packages, key=self._sort_key):
if has_hashes and not hashes.get(ireq):
yield MESSAGE_UNHASHED_PACKAGE
warn_uninstallable = True
Expand All @@ -172,7 +182,6 @@ def _iter_lines(self, results, unsafe_requirements=None, markers=None, hashes=No
yielded = True

if unsafe_requirements:
unsafe_requirements = sorted(unsafe_requirements, key=self._sort_key)
yield ""
yielded = True
if has_hashes and not self.allow_unsafe:
Expand All @@ -181,7 +190,7 @@ def _iter_lines(self, results, unsafe_requirements=None, markers=None, hashes=No
else:
yield MESSAGE_UNSAFE_PACKAGES

for ireq in unsafe_requirements:
for ireq in sorted(unsafe_requirements, key=self._sort_key):
ireq_key = key_from_ireq(ireq)
if not self.allow_unsafe:
yield comment(f"# {ireq_key}")
Expand All @@ -198,15 +207,26 @@ def _iter_lines(self, results, unsafe_requirements=None, markers=None, hashes=No
if warn_uninstallable:
log.warning(MESSAGE_UNINSTALLABLE)

def write(self, results, unsafe_requirements, markers, hashes):
def write(
self,
results: Set[InstallRequirement],
unsafe_requirements: Set[InstallRequirement],
markers: Dict[str, Marker],
hashes: Optional[Dict[InstallRequirement, Set[str]]],
) -> None:

for line in self._iter_lines(results, unsafe_requirements, markers, hashes):
log.info(line)
if not self.dry_run:
self.dst_file.write(unstyle(line).encode())
self.dst_file.write(os.linesep.encode())

def _format_requirement(self, ireq, marker=None, hashes=None):
def _format_requirement(
self,
ireq: InstallRequirement,
marker: Optional[Marker] = None,
hashes: Optional[Dict[InstallRequirement, Set[str]]] = None,
) -> str:
ireq_hashes = (hashes if hashes is not None else {}).get(ireq)

line = format_requirement(ireq, marker=marker, hashes=ireq_hashes)
Expand All @@ -224,15 +244,17 @@ def _format_requirement(self, ireq, marker=None, hashes=None):
}
elif ireq.comes_from:
required_by.add(_comes_from_as_string(ireq))

if required_by:
required_by = sorted(required_by)
if len(required_by) == 1:
source = required_by[0]
sorted_required_by = sorted(required_by)
if len(sorted_required_by) == 1:
source = sorted_required_by[0]
annotation = " # via " + source
else:
annotation_lines = [" # via"]
for source in required_by:
for source in sorted_required_by:
annotation_lines.append(" # " + source)
annotation = "\n".join(annotation_lines)
line = f"{line}\n{comment(annotation)}"

return line