Skip to content

Commit

Permalink
This implements pre-commit#1453 (comment)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangfufu committed Sep 8, 2023
1 parent e2c6a82 commit 161699e
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pre_commit/all_languages.py
Expand Up @@ -7,6 +7,7 @@
from pre_commit.languages import docker
from pre_commit.languages import docker_image
from pre_commit.languages import dotnet
from pre_commit.languages import download
from pre_commit.languages import fail
from pre_commit.languages import golang
from pre_commit.languages import haskell
Expand All @@ -30,6 +31,7 @@
'docker': docker,
'docker_image': docker_image,
'dotnet': dotnet,
'download': download,
'fail': fail,
'golang': golang,
'haskell': haskell,
Expand Down
239 changes: 239 additions & 0 deletions pre_commit/languages/download.py
@@ -0,0 +1,239 @@
from __future__ import annotations

from base64 import standard_b64decode as b64decode
from base64 import standard_b64encode as b64encode
import contextlib
from dataclasses import dataclass
import hashlib
import os.path
from os import chmod
from pathlib import Path, PurePath
import platform
from typing import (
AnyStr,
Bool,
Byte,
Generator,
IO,
Iterator,
Literal,
Sequence,
Tuple,
)
from urllib.parse import urlparse
from urllib.request import urlopen

from pre_commit import lang_base
from pre_commit.envcontext import Var
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import envcontext

from pre_commit.prefix import Prefix

# Name passed to ``lang_base.environment_dir()`` when locating the directory
# that holds this language's installed environment.
ENVIRONMENT_DIR = 'download'

def get_env_patch(target_dir: str) -> PatchesT:
    """Build the environment patch that puts *target_dir* first on ``PATH``.

    :param target_dir: directory to prepend to the ``PATH`` variable
    :returns: a one-entry patch tuple suitable for ``envcontext``
    """
    path_entry = ('PATH', (target_dir, os.pathsep, Var('PATH')))
    return (path_entry,)

@contextlib.contextmanager
def in_env(prefix: Prefix, version: str) -> Generator[None, None, None]:
    """Run the enclosed block with the download environment on ``PATH``.

    :param prefix: prefix used to locate the environment directory
    :param version: language version used to locate the environment directory
    """
    env_patch = get_env_patch(
        lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version),
    )
    with envcontext(env_patch):
        yield

@dataclass(frozen=True)
class Platform:
    """An ``<os>/<cpu>`` platform identifier such as ``linux/amd64``.

    The value is validated on construction against the same OS and CPU names
    that :meth:`host` can produce, so a dependency spec can be written for
    every platform the host detection supports.

    :raises ValueError: if the value is not ``<os>/<cpu>`` or names an
        unsupported operating system or CPU
    """

    value: str

    # ``platform.system()`` result -> OS name used in platform strings.
    _SYSTEMS = {
        "Linux": "linux",
        "Darwin": "darwin",
        "Windows": "windows",
        "DragonFly": "dragonfly",
        "FreeBSD": "freebsd",
    }
    # Lower-cased ``platform.machine()`` result -> CPU name used in platform
    # strings.  Windows reports ``AMD64``/``ARM64`` and macOS reports
    # ``arm64``, hence the case-insensitive lookup and the extra spellings.
    _MACHINES = {
        "aarch64": "arm64",
        "arm64": "arm64",
        "aarch64_be": "arm64be",
        "arm": "arm",
        "i386": "386",
        "i686": "386",
        "x86_64": "amd64",
        "amd64": "amd64",
        "ppc": "ppc",
        "ppc64": "ppc64",
        "ppc64le": "ppc64le",
    }
    # Spellings that earlier validation accepted; kept so existing specs
    # continue to load and compare equal to the canonical name.
    _CPU_ALIASES = {"aarch64": "arm64"}

    def __post_init__(self) -> None:
        os_name, cpu = self.parts

        if os_name not in self._SYSTEMS.values():
            raise ValueError(f"invalid operating system `{os_name}`")

        if cpu not in self._MACHINES.values() and cpu not in self._CPU_ALIASES:
            raise ValueError(f"invalid CPU `{cpu}`")

    @property
    def parts(self) -> Tuple[str, str]:
        """The raw ``(os, cpu)`` pair, split on the first ``/``."""
        os_name, sep, cpu = self.value.partition("/")
        if not sep:
            raise ValueError(
                f"invalid platform `{self.value}`, expected `<os>/<cpu>`",
            )
        return (os_name, cpu)

    @property
    def os(self) -> Literal["linux", "darwin", "windows", "dragonfly", "freebsd"]:
        """The operating-system component."""
        os_name, _ = self.parts
        return os_name  # type: ignore

    @property
    def cpu(self) -> Literal["arm64", "arm64be", "arm", "386", "amd64", "ppc", "ppc64", "ppc64le"]:
        """The CPU component, with accepted aliases mapped to one name."""
        _, cpu = self.parts
        return self._CPU_ALIASES.get(cpu, cpu)  # type: ignore

    @staticmethod
    def host() -> "Platform":
        """Return the platform of the running interpreter.

        :raises KeyError: if the operating system or CPU is not supported
        """
        system = platform.system()
        machine = platform.machine()
        try:
            os_name = Platform._SYSTEMS[system]
        except KeyError:
            raise KeyError(f"unsupported operating system `{system}`") from None
        try:
            cpu = Platform._MACHINES[machine.lower()]
        except KeyError:
            raise KeyError(f"unsupported CPU `{machine}`") from None
        return Platform(f"{os_name}/{cpu}")

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand instead of failing on foreign types.
        if not isinstance(other, Platform):
            return NotImplemented
        return (self.os, self.cpu) == (other.os, other.cpu)

    def __hash__(self) -> int:
        # Must agree with __eq__, which compares the normalized components.
        return hash((self.os, self.cpu))

    def __str__(self) -> str:
        return self.value

class ContentCorruptionError(Exception):
    """Raised when hashed content does not match its expected checksum."""


class SRI:
    """A Subresource-Integrity style checksum, ``<algorithm>-<base64 digest>``.

    An instance both describes the expected digest and accumulates a running
    hash of the data passed to :meth:`update`, so the content can be verified
    with :meth:`check_content` once all of it has been fed in.

    :param value: the integrity string, e.g. ``sha256-<44 base64 chars>``
    :raises ValueError: if the string has no ``-`` separator, names an
        unsupported algorithm, is not canonical base64, or has the wrong
        length for its algorithm
    """

    # Length of the padded base64 digest for each supported algorithm.
    _CHECKSUM_LENGTHS = {"sha256": 44, "sha384": 64, "sha512": 88}

    def __init__(self, value: str):
        self._value = value
        algorithm, separator, checksum = value.partition("-")
        if not separator:
            raise ValueError(f"Invalid SRI string: {value}")
        self._algorithm, self._checksum = algorithm, checksum

        # Raise rather than assert: asserts are stripped under ``-O``.
        if algorithm not in self._CHECKSUM_LENGTHS:
            raise ValueError(f"Unsupported SRI algorithm `{algorithm}`")

        # Round-trip through base64 to reject non-canonical encodings.  The
        # encoder returns bytes, so decode before comparing with the str.
        if b64encode(b64decode(checksum)).decode("ascii") != checksum:
            raise ValueError("Invalid checksum string.")
        if len(checksum) != self._CHECKSUM_LENGTHS[algorithm]:
            raise ValueError(f"Invalid checksum string length for {algorithm}")

        self._hasher = hashlib.new(algorithm)

    def __str__(self) -> str:
        return self._value

    @property
    def value(self) -> str:
        """The full integrity string as given to the constructor."""
        return self._value

    @property
    def algorithm(self) -> Literal["sha256", "sha384", "sha512"]:
        """The hash algorithm name."""
        return self._algorithm  # type: ignore

    @property
    def checksum(self) -> str:
        """The expected digest, base64 encoded."""
        return self._checksum

    @property
    def hexdigest(self) -> str:
        """Hex digest of the data hashed so far."""
        return self._hasher.hexdigest()

    def update(self, data: bytes) -> None:
        """Feed another chunk of content into the running hash."""
        return self._hasher.update(data)

    def check_content(self) -> None:
        """Verify that the data hashed so far matches the expected digest.

        :raises ContentCorruptionError: if the digests differ
        """
        # The checksum is base64, so compare against the base64 digest.
        actual = b64encode(self._hasher.digest()).decode("ascii")
        if actual != self._checksum:
            raise ContentCorruptionError(
                f"SRI check failed: expected {self._value}, "
                f"got {self._algorithm}-{actual}",
            )


@dataclass(frozen=True)
class URI:
    """A URI string that must carry both a scheme and a network location.

    :raises ValueError: if either the scheme or the network location is empty
    """

    value: str

    def __post_init__(self) -> None:
        parsed = urlparse(self.value)
        has_scheme_and_host = bool(parsed.scheme and parsed.netloc)
        if not has_scheme_and_host:
            raise ValueError(f"Invalid URI: {self.value}")

    def __str__(self) -> str:
        return self.value

@dataclass(frozen=True)
class Metadata:
    """One ``additional_dependencies`` entry describing a file to download.

    The value is four lines, in order: platform (``<os>/<cpu>``), integrity
    string (``<algorithm>-<base64 digest>``), download URL, and the file name
    to store the download under.
    """

    value: str

    @property
    def parts(self) -> Tuple[str, str, str, str]:
        """The four lines of the entry.

        :raises ValueError: if the entry does not have exactly four lines
        """
        lines = self.value.splitlines()
        if len(lines) != 4:
            raise ValueError(
                f"expected 4 lines (platform, SRI, URL, filename) "
                f"but got {len(lines)}: {self.value!r}",
            )
        first, second, third, fourth = lines
        return (first, second, third, fourth)

    @property
    def platform(self) -> "Platform":
        """The platform this entry applies to."""
        platform, _, _, _ = self.parts
        return Platform(platform)

    @property
    def sri(self) -> "SRI":
        """A fresh integrity checker for the expected content."""
        _, sri, _, _ = self.parts
        return SRI(sri)

    @property
    def url(self) -> "URI":
        """The validated download location."""
        _, _, url, _ = self.parts
        # Wrap in URI so the value is validated and matches the annotation.
        return URI(url)

    @property
    def uri(self) -> "URI":
        """Alias of :attr:`url`."""
        return self.url

    @property
    def filename(self) -> PurePath:
        """The path the download is stored under."""
        _, _, _, path = self.parts
        return PurePath(path)

def io_buffer(io: IO[AnyStr], size: int = 4096) -> Iterator[AnyStr]:
    """Yield successive chunks from a stream until a read comes back empty.

    Stands in for a ``while chunk := io.read(size)`` loop.

    :param io: the stream to read chunks from
    :param size: the size of chunks to read
    :returns: iterator of the chunks
    """
    while True:
        chunk = io.read(size)
        if not chunk:
            return
        yield chunk

def download(uri: URI, sri: SRI, filename: Path) -> None:
    """Download *uri* to *filename*, verify it, and make it executable.

    The content is hashed while it is written; if the download fails or the
    digest does not match, the partial file is removed so that unverified
    content is never left behind.

    :param uri: location to download from
    :param sri: integrity checker fed with the downloaded content
    :param filename: destination path; missing parent directories are created
    :raises ContentCorruptionError: if the content does not match *sri*
    """
    filename.parent.mkdir(parents=True, exist_ok=True)
    try:
        with urlopen(str(uri)) as response, filename.open("wb") as fp:
            for chunk in io_buffer(response):
                fp.write(chunk)
                sri.update(chunk)
        # Verify only once the whole body has been hashed.
        sri.check_content()
    except BaseException:
        # Best-effort cleanup; the original error is what matters.
        with contextlib.suppress(OSError):
            filename.unlink()
        raise
    # rwxr-xr-x: the downloaded tool has to be runnable from PATH.
    chmod(filename, 0o755)


def install_environment(
        prefix: Prefix,
        version: str,
        additional_dependencies: Sequence[str],
) -> None:
    """Download the dependency that matches the host platform.

    Each entry of *additional_dependencies* is parsed as a ``Metadata`` spec.
    The first one whose platform equals the host's is downloaded into the
    environment directory, and its name and integrity string are recorded in
    ``health.srisum`` for later health checks.

    :param prefix: prefix used to locate the environment directory
    :param version: language version used to locate the environment directory
    :param additional_dependencies: dependency specs, one per platform
    :raises KeyError: if no entry matches the host platform
    """
    host = Platform.host()
    envdir = Path(lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version))
    for dep in additional_dependencies:
        m = Metadata(dep)
        if host == m.platform:
            # ``m.sri`` builds a new checker on every access; keep a single
            # instance for both the download and the record.
            sri = m.sri
            download(m.url, sri, envdir / m.filename)
            srisum = envdir / "health.srisum"
            with srisum.open("w", encoding="utf8") as stream:
                # Record the path relative to the environment directory and
                # the integrity string itself (not the object's repr).
                stream.write(f"{m.filename} {sri.value}\n")
            return
    raise KeyError(f"Failed to find platform `{host}` in `additional_dependencies`: {additional_dependencies}")


def health_check(prefix: Prefix, version: str) -> str | None:
    """Re-verify the installed files against ``health.srisum``.

    :param prefix: prefix used to locate the environment directory
    :param version: language version used to locate the environment directory
    :returns: ``None`` when every recorded file matches its integrity string,
        otherwise a message describing the first problem found
    """
    envdir = Path(lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version))
    srisum = envdir / "health.srisum"
    try:
        with srisum.open(encoding="utf8") as stream:
            entries = [line.rstrip("\r\n") for line in stream if line.strip()]
    except OSError as e:
        return f"unable to read `{srisum}`: {e}"

    for entry in entries:
        # The integrity string never contains a space, so split on the last
        # one; the recorded path may contain spaces.
        name, _, expected = entry.rpartition(" ")
        filename = envdir / name
        try:
            sri = SRI(expected)
            with filename.open("rb") as fp:
                for chunk in io_buffer(fp):
                    sri.update(chunk)
            sri.check_content()
        except (OSError, ValueError, ContentCorruptionError) as e:
            return f"`{filename}` failed its integrity check: {e}"
    return None
30 changes: 30 additions & 0 deletions tests/languages/download_test.py
@@ -0,0 +1,30 @@
from __future__ import annotations

from pre_commit.languages import download
from testing.language_helpers import run_language


def test_download_hooks(tmp_path):
    """Run the download language with an entry and one file argument."""
    # NOTE(review): this expects exit code 1 with the entry and file name
    # echoed back -- confirm that is the intended behaviour for `download`.
    expected = (1, b'watch out for\n\nbunnies\n')
    result = run_language(
        tmp_path,
        download,
        'watch out for',
        file_args=('bunnies',),
    )
    assert result == expected

def test_download_dependencies(tmp_path):
    """Install a dependency spec; the test passes if nothing raises.

    NOTE(review): the spec only covers ``linux/amd64`` and points at a remote
    URL, so this presumably needs a linux/amd64 host with network access --
    confirm, and add an assertion on the result once the expected output of
    the hook is settled.
    """
    # A spec is four lines: platform, integrity string, URL and file name.
    spec = '\n'.join((
        'linux/amd64',
        'sha256-GpEiFJdPTtSzE3lfKsyRYP6Llc9wbYOzOcFgxn1xOoU=',
        'https://gitlab.com/api/v4/projects/109/packages/generic/shfmt/3.6.0/shfmt.linux-amd64',
        'bin/shfmt',
    ))
    run_language(
        tmp_path,
        download,
        'hello',
        deps=(spec,),
    )

# additional_dependencies:
# - |-
# linux/amd64
# sha256-GpEiFJdPTtSzE3lfKsyRYP6Llc9wbYOzOcFgxn1xOoU=
# https://gitlab.com/api/v4/projects/109/packages/generic/shfmt/3.6.0/shfmt.linux-amd64
# bin/shfmt

0 comments on commit 161699e

Please sign in to comment.