diff --git a/pre_commit/all_languages.py b/pre_commit/all_languages.py
index 476bad9da..e4c8a929f 100644
--- a/pre_commit/all_languages.py
+++ b/pre_commit/all_languages.py
@@ -7,6 +7,7 @@
 from pre_commit.languages import docker
 from pre_commit.languages import docker_image
 from pre_commit.languages import dotnet
+from pre_commit.languages import download
 from pre_commit.languages import fail
 from pre_commit.languages import golang
 from pre_commit.languages import haskell
@@ -30,6 +31,7 @@
     'docker': docker,
     'docker_image': docker_image,
     'dotnet': dotnet,
+    'download': download,
     'fail': fail,
     'golang': golang,
     'haskell': haskell,
diff --git a/pre_commit/languages/download.py b/pre_commit/languages/download.py
new file mode 100644
--- /dev/null
+++ b/pre_commit/languages/download.py
@@ -0,0 +1,345 @@
+"""Download a prebuilt executable and verify it against an SRI checksum.
+
+Each entry of ``additional_dependencies`` holds four newline-separated fields:
+
+    os/cpu             e.g. linux/amd64
+    algorithm-digest   e.g. sha256-GpEi... (the digest is base64)
+    url                where the file is fetched from
+    filename           path relative to the environment directory
+"""
+from __future__ import annotations
+
+import contextlib
+import hashlib
+import os
+import platform
+import stat
+from base64 import standard_b64decode as b64decode
+from base64 import standard_b64encode as b64encode
+from dataclasses import dataclass
+from pathlib import Path
+from pathlib import PurePath
+from typing import AnyStr
+from typing import Generator
+from typing import IO
+from typing import Iterator
+from typing import Sequence
+from urllib.parse import urlparse
+from urllib.request import urlopen
+
+from pre_commit import lang_base
+from pre_commit.envcontext import envcontext
+from pre_commit.envcontext import PatchesT
+from pre_commit.envcontext import Var
+from pre_commit.prefix import Prefix
+
+ENVIRONMENT_DIR = 'download'
+# NOTE(review): the other language modules also define these two names;
+# assumes `lang_base` ships the `basic_*` helpers -- confirm.
+get_default_version = lang_base.basic_get_default_version
+run_hook = lang_base.basic_run_hook
+
+# File inside the environment directory recording what was installed.
+_HEALTH_FILE = 'health.srisum'
+
+# `platform.system()` -> operating system name used in the metadata.
+_HOST_OS = {
+    'Linux': 'linux',
+    'Darwin': 'darwin',
+    'Windows': 'windows',
+    'DragonFly': 'dragonfly',
+    'FreeBSD': 'freebsd',
+}
+# Lower-cased `platform.machine()` -> CPU name used in the metadata.
+_HOST_CPU = {
+    'aarch64': 'arm64',
+    'arm64': 'arm64',
+    'aarch64_be': 'arm64be',
+    'arm': 'arm',
+    'i386': '386',
+    'i686': '386',
+    'x86': '386',
+    'x86_64': 'amd64',
+    'amd64': 'amd64',
+    'ppc': 'ppc',
+    'ppc64': 'ppc64',
+    'ppc64le': 'ppc64le',
+}
+# Digest size in bytes for each algorithm accepted in an SRI string.
+_DIGEST_SIZES = {'sha256': 32, 'sha384': 48, 'sha512': 64}
+
+
+def get_env_patch(target_dir: str) -> PatchesT:
+    """Return the patch that prepends `target_dir` to `PATH`."""
+    return (
+        ('PATH', (target_dir, os.pathsep, Var('PATH'))),
+    )
+
+
+@contextlib.contextmanager
+def in_env(prefix: Prefix, version: str) -> Generator[None, None, None]:
+    """Put the download environment directory on `PATH` for the body."""
+    envdir = lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version)
+    with envcontext(get_env_patch(envdir)):
+        yield
+
+
+@dataclass(frozen=True)
+class Platform:
+    """An `os/cpu` pair such as `linux/amd64`.
+
+    Validated on construction; instances compare equal when their `value`
+    strings are equal (the dataclass default).
+    """
+    value: str
+
+    def __post_init__(self) -> None:
+        os_name, cpu = self.parts
+        # accept exactly what `host()` can produce so the two always agree
+        if os_name not in _HOST_OS.values():
+            raise ValueError(f'invalid operating system `{os_name}`')
+        if cpu not in _HOST_CPU.values():
+            raise ValueError(f'invalid CPU `{cpu}`')
+
+    def __str__(self) -> str:
+        return self.value
+
+    @property
+    def parts(self) -> tuple[str, str]:
+        """The `(os, cpu)` pair; `ValueError` if there is no `/`."""
+        os_name, sep, cpu = self.value.partition('/')
+        if not sep:
+            raise ValueError(f'invalid platform `{self.value}`')
+        return os_name, cpu
+
+    @property
+    def os(self) -> str:
+        return self.parts[0]
+
+    @property
+    def cpu(self) -> str:
+        return self.parts[1]
+
+    @staticmethod
+    def host() -> Platform:
+        """Return the platform this interpreter is running on.
+
+        :raises ValueError: if the host OS or CPU has no known mapping
+        """
+        system, machine = platform.system(), platform.machine()
+        try:
+            os_name = _HOST_OS[system]
+            cpu = _HOST_CPU[machine.lower()]
+        except KeyError:
+            raise ValueError(
+                f'unsupported host platform `{system}/{machine}`',
+            ) from None
+        return Platform(f'{os_name}/{cpu}')
+
+
+class ContentCorruptionError(Exception):
+    """Raised when content does not match its expected checksum."""
+
+
+class SRI:
+    """A subresource-integrity string: `algorithm-base64digest`.
+
+    The instance carries a running hash: feed the content to `update()`
+    and then call `check_content()`.
+    """
+
+    def __init__(self, value: str) -> None:
+        algorithm, sep, checksum = value.partition('-')
+        if not sep or algorithm not in _DIGEST_SIZES:
+            raise ValueError(f'Invalid SRI string: {value!r}')
+        try:
+            digest = b64decode(checksum)
+        except ValueError:  # `binascii.Error` is a `ValueError`
+            raise ValueError('Invalid checksum string.') from None
+        # decoding skips stray characters; the round trip rejects them
+        if b64encode(digest).decode('ascii') != checksum:
+            raise ValueError('Invalid checksum string.')
+        if len(digest) != _DIGEST_SIZES[algorithm]:
+            raise ValueError(
+                f'Invalid checksum string length for {algorithm}',
+            )
+
+        self._value = value
+        self._algorithm = algorithm
+        self._checksum = checksum
+        self._hasher = hashlib.new(algorithm)
+
+    def __str__(self) -> str:
+        return self._value
+
+    @property
+    def value(self) -> str:
+        return self._value
+
+    @property
+    def algorithm(self) -> str:
+        return self._algorithm
+
+    @property
+    def checksum(self) -> str:
+        return self._checksum
+
+    @property
+    def hexdigest(self) -> str:
+        return self._hasher.hexdigest()
+
+    def update(self, data: bytes) -> None:
+        """Feed another chunk of content into the running hash."""
+        self._hasher.update(data)
+
+    def check_content(self) -> None:
+        """Raise `ContentCorruptionError` unless the hash matches."""
+        # SRI checksums are base64 encoded, so compare in that encoding
+        actual = b64encode(self._hasher.digest()).decode('ascii')
+        if actual != self._checksum:
+            raise ContentCorruptionError(
+                f'expected {self._value}, got {self._algorithm}-{actual}',
+            )
+
+
+@dataclass(frozen=True)
+class URI:
+    """A URI that has both a scheme and a network location."""
+    value: str
+
+    def __post_init__(self) -> None:
+        url = urlparse(self.value)
+        if not (url.scheme and url.netloc):
+            raise ValueError(f'Invalid URI: {self.value}')
+
+    def __str__(self) -> str:
+        return self.value
+
+
+@dataclass(frozen=True)
+class Metadata:
+    """One `additional_dependencies` entry; see the module docstring."""
+    value: str
+
+    @property
+    def parts(self) -> tuple[str, str, str, str]:
+        """The `(platform, sri, url, filename)` lines, unparsed."""
+        lines = self.value.splitlines()
+        if len(lines) != 4:
+            raise ValueError(
+                'expected 4 lines (platform, sri, url, filename), '
+                f'got {len(lines)}: {self.value!r}',
+            )
+        platform_line, sri, url, filename = lines
+        return platform_line, sri, url, filename
+
+    @property
+    def platform(self) -> Platform:
+        return Platform(self.parts[0])
+
+    @property
+    def sri(self) -> SRI:
+        """A new `SRI`, with its own running hash, on every access."""
+        return SRI(self.parts[1])
+
+    @property
+    def url(self) -> URI:
+        return URI(self.parts[2])
+
+    @property
+    def filename(self) -> PurePath:
+        """The install path; must not escape the environment directory."""
+        path = PurePath(self.parts[3])
+        if path.is_absolute() or '..' in path.parts:
+            raise ValueError(f'Invalid filename: {path}')
+        return path
+
+
+def io_buffer(io: IO[AnyStr], size: int = 4096) -> Iterator[AnyStr]:
+    """Yield successive chunks read from a stream until it is exhausted.
+
+    :param io: the stream to read chunks from
+    :param size: the size of chunks to read
+    :returns: iterator of the chunks
+    """
+    buffer = io.read(size)
+    while buffer:
+        yield buffer
+        buffer = io.read(size)
+
+
+def download(uri: URI, sri: SRI, filename: Path) -> None:
+    """Fetch `uri` into `filename`, verify it and mark it executable.
+
+    :param uri: where to download from
+    :param sri: the expected checksum; its running hash is consumed
+    :param filename: destination; parent directories are created
+    :raises ContentCorruptionError: if the content does not match `sri`
+    """
+    filename.parent.mkdir(parents=True, exist_ok=True)
+    try:
+        with urlopen(str(uri)) as ws, filename.open('wb') as fp:
+            for buf in io_buffer(ws):
+                fp.write(buf)
+                sri.update(buf)
+        sri.check_content()
+    except BaseException:
+        # never leave a partial or unverified file behind
+        with contextlib.suppress(FileNotFoundError):
+            filename.unlink()
+        raise
+    executable = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
+    filename.chmod(filename.stat().st_mode | executable)
+
+
+def install_environment(
+        prefix: Prefix,
+        version: str,
+        additional_dependencies: Sequence[str],
+) -> None:
+    """Download the first dependency that targets the host platform.
+
+    :raises KeyError: if no dependency matches the host platform
+    """
+    host = Platform.host()
+    envdir = Path(lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version))
+    for dep in additional_dependencies:
+        m = Metadata(dep)
+        if m.platform != host:
+            continue
+        sri = m.sri
+        download(m.url, sri, envdir / m.filename)
+        # one `filename sri` line, read back by `health_check`
+        with (envdir / _HEALTH_FILE).open('w', encoding='utf8') as stream:
+            stream.write(f'{m.filename} {sri}\n')
+        return
+    raise KeyError(
+        f'Failed to find platform `{host}` in `additional_dependencies`: '
+        f'{additional_dependencies}',
+    )
+
+
+def health_check(prefix: Prefix, version: str) -> str | None:
+    """Re-verify the installed file against its recorded checksum.
+
+    :returns: `None` when healthy, otherwise a description of the problem
+    """
+    envdir = Path(lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version))
+    try:
+        with (envdir / _HEALTH_FILE).open(encoding='utf8') as stream:
+            entries = [line.rstrip('\n') for line in stream]
+    except OSError as e:
+        return f'unable to read `{_HEALTH_FILE}`: {e}'
+
+    for entry in entries:
+        # the checksum never contains a space, the filename might
+        name, _, checksum = entry.rpartition(' ')
+        try:
+            sri = SRI(checksum)
+            with (envdir / name).open('rb') as fp:
+                for buf in io_buffer(fp):
+                    sri.update(buf)
+            sri.check_content()
+        except (OSError, ValueError, ContentCorruptionError) as e:
+            return f'`{name}` failed its integrity check: {e}'
+    return None
diff --git a/tests/languages/download_test.py b/tests/languages/download_test.py
new file mode 100644
--- /dev/null
+++ b/tests/languages/download_test.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+import hashlib
+import io
+from base64 import standard_b64encode
+from pathlib import Path
+from unittest import mock
+
+import pytest
+
+from pre_commit import lang_base
+from pre_commit.languages import download
+from pre_commit.prefix import Prefix
+
+PAYLOAD = b'#!/bin/sh\necho hello\n'
+VERSION = 'default'
+# what a `|-` block scalar under `additional_dependencies:` parses to
+SHFMT = (
+    'linux/amd64\n'
+    'sha256-GpEiFJdPTtSzE3lfKsyRYP6Llc9wbYOzOcFgxn1xOoU=\n'
+    'https://gitlab.com/api/v4/projects/109/packages/generic/shfmt/3.6.0/'
+    'shfmt.linux-amd64\n'
+    'bin/shfmt'
+)
+
+
+def _sri(data: bytes) -> str:
+    digest = hashlib.sha256(data).digest()
+    return f'sha256-{standard_b64encode(digest).decode()}'
+
+
+def _dep(platform: str, data: bytes = PAYLOAD) -> str:
+    return f'{platform}\n{_sri(data)}\nhttps://example.com/hello\nbin/hello'
+
+
+def _install(tmp_path, dep, served=PAYLOAD):
+    """Install `dep` with the network stubbed out to return `served`."""
+    prefix = Prefix(str(tmp_path))
+    with mock.patch.object(download, 'urlopen') as urlopen:
+        urlopen.return_value = io.BytesIO(served)
+        download.install_environment(prefix, VERSION, (dep,))
+    envdir = lang_base.environment_dir(
+        prefix, download.ENVIRONMENT_DIR, VERSION,
+    )
+    return prefix, Path(envdir)
+
+
+def test_platform_host_round_trips():
+    host = download.Platform.host()
+    assert download.Platform(str(host)) == host
+    assert str(host) == f'{host.os}/{host.cpu}'
+
+
+@pytest.mark.parametrize('value', ('linux', 'plan9/amd64', 'linux/z80'))
+def test_platform_invalid(value):
+    with pytest.raises(ValueError):
+        download.Platform(value)
+
+
+def test_sri_check_content():
+    sri = download.SRI(_sri(PAYLOAD))
+    sri.update(PAYLOAD)
+    sri.check_content()
+
+    sri = download.SRI(_sri(PAYLOAD))
+    sri.update(b'tampered')
+    with pytest.raises(download.ContentCorruptionError):
+        sri.check_content()
+
+
+@pytest.mark.parametrize(
+    'value',
+    (
+        'GpEiFJdPTtSzE3lfKsyRYP6Llc9wbYOzOcFgxn1xOoU=',
+        'md5-GpEiFJdPTtSzE3lfKsyRYP6Llc9wbYOzOcFgxn1xOoU=',
+        'sha256-not base64',
+        'sha512-GpEiFJdPTtSzE3lfKsyRYP6Llc9wbYOzOcFgxn1xOoU=',
+    ),
+)
+def test_sri_invalid(value):
+    with pytest.raises(ValueError):
+        download.SRI(value)
+
+
+def test_metadata():
+    m = download.Metadata(SHFMT)
+    assert str(m.platform) == 'linux/amd64'
+    assert m.sri.algorithm == 'sha256'
+    assert str(m.url).endswith('/shfmt.linux-amd64')
+    assert m.filename.as_posix() == 'bin/shfmt'
+
+
+def test_metadata_requires_four_lines():
+    three_lines, _, _ = SHFMT.rpartition('\n')
+    with pytest.raises(ValueError):
+        download.Metadata(three_lines).parts
+
+
+def test_download_dependencies(tmp_path):
+    host = download.Platform.host()
+    prefix, envdir = _install(tmp_path, _dep(str(host)))
+    installed = envdir / 'bin' / 'hello'
+
+    assert installed.read_bytes() == PAYLOAD
+    assert download.health_check(prefix, VERSION) is None
+
+    installed.write_bytes(b'tampered')
+    assert download.health_check(prefix, VERSION) is not None
+
+
+def test_download_corrupt_content(tmp_path):
+    host = download.Platform.host()
+    with pytest.raises(download.ContentCorruptionError):
+        _install(tmp_path, _dep(str(host)), served=b'tampered')
+    assert not list(tmp_path.rglob('hello'))
+
+
+def test_download_no_matching_platform(tmp_path):
+    host = download.Platform.host()
+    other = 'darwin/arm64' if str(host) == 'linux/amd64' else 'linux/amd64'
+    with pytest.raises(KeyError):
+        _install(tmp_path, _dep(other))