Skip to content

Commit

Permalink
Implementing mypyc support pt. 2 (#2431)
Browse files Browse the repository at this point in the history
  • Loading branch information
ichard26 committed Nov 16, 2021
1 parent 1d72600 commit 1178918
Show file tree
Hide file tree
Showing 22 changed files with 310 additions and 168 deletions.
10 changes: 9 additions & 1 deletion mypy.ini
Expand Up @@ -3,7 +3,6 @@
# free to run mypy on Windows, Linux, or macOS and get consistent
# results.
python_version=3.6
platform=linux

mypy_path=src

Expand All @@ -24,6 +23,10 @@ warn_redundant_casts=True
warn_unused_ignores=True
disallow_any_generics=True

# Unreachable blocks have been an issue when compiling mypyc, let's try
# to avoid 'em in the first place.
warn_unreachable=True

# The following are off by default. Flip them on if you feel
# adventurous.
disallow_untyped_defs=True
Expand All @@ -32,6 +35,11 @@ check_untyped_defs=True
# No incremental mode
cache_dir=/dev/null

[mypy-black]
# The following is because of `patch_click()`. Remove when
# we drop Python 3.6 support.
warn_unused_ignores=False

[mypy-black_primer.*]
# Until we're not supporting 3.6 primer needs this
disallow_any_generics=False
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Expand Up @@ -33,3 +33,6 @@ optional-tests = [
"no_blackd: run when `d` extra NOT installed",
"no_jupyter: run when `jupyter` extra NOT installed",
]
markers = [
"incompatible_with_mypyc: run when testing mypyc compiled black"
]
47 changes: 36 additions & 11 deletions setup.py
Expand Up @@ -5,6 +5,7 @@

assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+"
from pathlib import Path # noqa E402
from typing import List # noqa: E402

CURRENT_DIR = Path(__file__).parent
sys.path.insert(0, str(CURRENT_DIR)) # for setuptools.build_meta
Expand All @@ -18,6 +19,17 @@ def get_long_description() -> str:
)


def find_python_files(base: Path) -> List[Path]:
files = []
for entry in base.iterdir():
if entry.is_file() and entry.suffix == ".py":
files.append(entry)
elif entry.is_dir():
files.extend(find_python_files(entry))

return files


USE_MYPYC = False
# To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH
if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc":
Expand All @@ -27,21 +39,34 @@ def get_long_description() -> str:
USE_MYPYC = True

if USE_MYPYC:
from mypyc.build import mypycify

src = CURRENT_DIR / "src"
# TIP: filepaths are normalized to use forward slashes and are relative to ./src/
# before being checked against.
blocklist = [
# Not performance sensitive, so save bytes + compilation time:
"blib2to3/__init__.py",
"blib2to3/pgen2/__init__.py",
"black/output.py",
"black/concurrency.py",
"black/files.py",
"black/report.py",
# Breaks the test suite when compiled (and is also useless):
"black/debug.py",
# Compiled modules can't be run directly and that's a problem here:
"black/__main__.py",
]
discovered = []
# black-primer and blackd have no good reason to be compiled.
discovered.extend(find_python_files(src / "black"))
discovered.extend(find_python_files(src / "blib2to3"))
mypyc_targets = [
"src/black/__init__.py",
"src/blib2to3/pytree.py",
"src/blib2to3/pygram.py",
"src/blib2to3/pgen2/parse.py",
"src/blib2to3/pgen2/grammar.py",
"src/blib2to3/pgen2/token.py",
"src/blib2to3/pgen2/driver.py",
"src/blib2to3/pgen2/pgen.py",
str(p) for p in discovered if p.relative_to(src).as_posix() not in blocklist
]

from mypyc.build import mypycify

opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
ext_modules = mypycify(mypyc_targets, opt_level=opt_level)
ext_modules = mypycify(mypyc_targets, opt_level=opt_level, verbose=True)
else:
ext_modules = []

Expand Down
23 changes: 19 additions & 4 deletions src/black/__init__.py
Expand Up @@ -30,8 +30,9 @@
Union,
)

from dataclasses import replace
import click
from dataclasses import replace
from mypy_extensions import mypyc_attr

from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
Expand Down Expand Up @@ -66,6 +67,8 @@

from _black_version import version as __version__

COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
FileContent = str
Encoding = str
Expand Down Expand Up @@ -177,7 +180,12 @@ def validate_regex(
raise click.BadParameter("Not a valid regular expression") from None


@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.command(
context_settings=dict(help_option_names=["-h", "--help"]),
# While Click does set this field automatically using the docstring, mypyc
# (annoyingly) strips 'em so we need to set it here too.
help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-l",
Expand Down Expand Up @@ -346,7 +354,10 @@ def validate_regex(
" due to exclusion patterns."
),
)
@click.version_option(version=__version__)
@click.version_option(
version=__version__,
message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
"src",
nargs=-1,
Expand Down Expand Up @@ -387,7 +398,7 @@ def main(
experimental_string_processing: bool,
quiet: bool,
verbose: bool,
required_version: str,
required_version: Optional[str],
include: Pattern[str],
exclude: Optional[Pattern[str]],
extend_exclude: Optional[Pattern[str]],
Expand Down Expand Up @@ -655,6 +666,9 @@ def reformat_one(
report.failed(src, str(exc))


# diff-shades depends on being to monkeypatch this function to operate. I know it's
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
sources: Set[Path],
fast: bool,
Expand All @@ -669,6 +683,7 @@ def reformat_many(
worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
assert worker_count is not None
worker_count = min(worker_count, 60)
try:
executor = ProcessPoolExecutor(max_workers=worker_count)
Expand Down
2 changes: 1 addition & 1 deletion src/black/brackets.py
Expand Up @@ -49,7 +49,7 @@
DOT_PRIORITY: Final = 1


class BracketMatchError(KeyError):
class BracketMatchError(Exception):
"""Raised when an opening bracket is unable to be matched to a closing bracket."""


Expand Down
15 changes: 10 additions & 5 deletions src/black/comments.py
@@ -1,8 +1,14 @@
import sys
from dataclasses import dataclass
from functools import lru_cache
import regex as re
from typing import Iterator, List, Optional, Union

if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final

from blib2to3.pytree import Node, Leaf
from blib2to3.pgen2 import token

Expand All @@ -12,11 +18,10 @@
# types
LN = Union[Leaf, Node]


FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
FMT_PASS = {*FMT_OFF, *FMT_SKIP}
FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
FMT_OFF: Final = {"# fmt: off", "# fmt:off", "# yapf: disable"}
FMT_SKIP: Final = {"# fmt: skip", "# fmt:skip"}
FMT_PASS: Final = {*FMT_OFF, *FMT_SKIP}
FMT_ON: Final = {"# fmt: on", "# fmt:on", "# yapf: enable"}


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion src/black/files.py
Expand Up @@ -17,6 +17,7 @@
TYPE_CHECKING,
)

from mypy_extensions import mypyc_attr
from pathspec import PathSpec
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import tomli
Expand Down Expand Up @@ -88,13 +89,14 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
return None


@mypyc_attr(patchable=True)
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
"""Parse a pyproject toml file, pulling out relevant parts for Black
If parsing fails, will raise a tomli.TOMLDecodeError
"""
with open(path_config, encoding="utf8") as f:
pyproject_toml = tomli.load(f) # type: ignore # due to deprecated API usage
pyproject_toml = tomli.loads(f.read())
config = pyproject_toml.get("tool", {}).get("black", {})
return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}

Expand Down
13 changes: 7 additions & 6 deletions src/black/handle_ipynb_magics.py
Expand Up @@ -333,7 +333,7 @@ def header(self) -> str:
return f"%%{self.name}"


@dataclasses.dataclass
# ast.NodeVisitor + dataclass = breakage under mypyc.
class CellMagicFinder(ast.NodeVisitor):
"""Find cell magics.
Expand All @@ -352,7 +352,8 @@ class CellMagicFinder(ast.NodeVisitor):
and we look for instances of the latter.
"""

cell_magic: Optional[CellMagic] = None
def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
self.cell_magic = cell_magic

def visit_Expr(self, node: ast.Expr) -> None:
"""Find cell magic, extract header and body."""
Expand All @@ -372,7 +373,8 @@ class OffsetAndMagic:
magic: str


@dataclasses.dataclass
# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
# as mypyc will generate broken code.
class MagicFinder(ast.NodeVisitor):
"""Visit cell to look for get_ipython calls.
Expand All @@ -392,9 +394,8 @@ class MagicFinder(ast.NodeVisitor):
types of magics).
"""

magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field(
default_factory=lambda: collections.defaultdict(list)
)
def __init__(self) -> None:
self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)

def visit_Assign(self, node: ast.Assign) -> None:
"""Look for system assign magics.
Expand Down
25 changes: 17 additions & 8 deletions src/black/linegen.py
Expand Up @@ -5,8 +5,6 @@
import sys
from typing import Collection, Iterator, List, Optional, Set, Union

from dataclasses import dataclass, field

from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
Expand Down Expand Up @@ -40,17 +38,20 @@ class CannotSplit(CannotTransform):
"""A readable split that fits the allotted line length is impossible."""


@dataclass
# This isn't a dataclass because @dataclass + Generic breaks mypyc.
# See also https://github.com/mypyc/mypyc/issues/827.
class LineGenerator(Visitor[Line]):
"""Generates reformatted Line objects. Empty lines are not emitted.
Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree.
"""

mode: Mode
remove_u_prefix: bool = False
current_line: Line = field(init=False)
def __init__(self, mode: Mode, remove_u_prefix: bool = False) -> None:
self.mode = mode
self.remove_u_prefix = remove_u_prefix
self.current_line: Line
self.__post_init__()

def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
Expand Down Expand Up @@ -339,7 +340,9 @@ def transform_line(
transformers = [left_hand_split]
else:

def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
def _rhs(
self: object, line: Line, features: Collection[Feature]
) -> Iterator[Line]:
"""Wraps calls to `right_hand_split`.
The calls increasingly `omit` right-hand trailers (bracket pairs with
Expand All @@ -366,6 +369,12 @@ def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
line, line_length=mode.line_length, features=features
)

# HACK: nested functions (like _rhs) compiled by mypyc don't retain their
# __name__ attribute which is needed in `run_transformer` further down.
# Unfortunately a nested class breaks mypyc too. So a class must be created
# via type ... https://github.com/mypyc/mypyc/issues/884
rhs = type("rhs", (), {"__call__": _rhs})()

if mode.experimental_string_processing:
if line.inside_brackets:
transformers = [
Expand Down Expand Up @@ -980,7 +989,7 @@ def run_transformer(
result.extend(transform_line(transformed_line, mode=mode, features=features))

if (
transform.__name__ != "rhs"
transform.__class__.__name__ != "rhs"
or not line.bracket_tracker.invisible
or any(bracket.value for bracket in line.bracket_tracker.invisible)
or line.contains_multiline_strings()
Expand Down
3 changes: 2 additions & 1 deletion src/black/mode.py
Expand Up @@ -6,6 +6,7 @@

from dataclasses import dataclass, field
from enum import Enum
from operator import attrgetter
from typing import Dict, Set

from black.const import DEFAULT_LINE_LENGTH
Expand Down Expand Up @@ -134,7 +135,7 @@ def get_cache_key(self) -> str:
if self.target_versions:
version_str = ",".join(
str(version.value)
for version in sorted(self.target_versions, key=lambda v: v.value)
for version in sorted(self.target_versions, key=attrgetter("value"))
)
else:
version_str = "-"
Expand Down

0 comments on commit 1178918

Please sign in to comment.