Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jul 3, 2021
1 parent 017aafe commit 20872fa
Show file tree
Hide file tree
Showing 7 changed files with 556 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -18,3 +18,4 @@ src/_black_version.py
*.swp
.hypothesis/
venv/
.ipynb_checkpoints/
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -87,6 +87,7 @@ def get_long_description() -> str:
"colorama": ["colorama>=0.4.3"],
"python2": ["typed-ast>=1.4.2"],
"uvloop": ["uvloop>=0.15.2"],
"jupyter": ["ipython>=7.8.0", "tokenize-rt>=3.2.0"],
},
test_suite="tests.test_black",
classifiers=[
Expand Down
70 changes: 70 additions & 0 deletions src/black/__init__.py
@@ -1,4 +1,5 @@
import asyncio
import json
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from contextlib import contextmanager
from datetime import datetime
Expand Down Expand Up @@ -46,6 +47,12 @@
from black.files import wrap_stream_for_windows
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.handle_ipynb_magics import (
mask_cell,
unmask_cell,
remove_trailing_semicolon,
put_trailing_semicolon_back,
)


# lib2to3 fork
Expand Down Expand Up @@ -196,6 +203,14 @@ def validate_regex(
" when piping source on standard input)."
),
)
@click.option(
"--ipynb",
is_flag=True,
help=(
"Format all input files like ipynb notebooks regardless of file extension "
"(useful when piping source on standard input)."
),
)
@click.option(
"-S",
"--skip-string-normalization",
Expand Down Expand Up @@ -354,6 +369,7 @@ def main(
color: bool,
fast: bool,
pyi: bool,
ipynb: bool,
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
Expand Down Expand Up @@ -390,6 +406,7 @@ def main(
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
is_ipynb=ipynb,
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
Expand Down Expand Up @@ -584,6 +601,8 @@ def reformat_one(
if is_stdin:
if src.suffix == ".pyi":
mode = replace(mode, is_pyi=True)
elif src.suffix == ".ipynb":
mode = replace(mode, is_ipynb=True)
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
changed = Changed.YES
else:
Expand Down Expand Up @@ -732,6 +751,8 @@ def format_file_in_place(
"""
if src.suffix == ".pyi":
mode = replace(mode, is_pyi=True)
elif src.suffix == ".ipynb":
mode = replace(mode, is_ipynb=True)

then = datetime.utcfromtimestamp(src.stat().st_mtime)
with open(src, "rb") as buf:
Expand Down Expand Up @@ -825,6 +846,9 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
`mode` is passed to :func:`format_str`.
"""
if mode.is_ipynb:
return format_ipynb_string(src_contents, mode=mode, fast=fast)

if not src_contents.strip():
raise NothingChanged

Expand All @@ -848,6 +872,52 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
return dst_contents


def format_cell(src: str, *, mode: Mode) -> str:
src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
src
)
try:
masked_cell, replacements = mask_cell(src_without_trailing_semicolon)
except SyntaxError:
# Don't format, might be automagic or multi-line magic.
raise NothingChanged
formatted_masked_cell = format_str(masked_cell, mode=mode)
formatted_cell = unmask_cell(formatted_masked_cell, replacements)
new_src = put_trailing_semicolon_back(formatted_cell, has_trailing_semicolon)
new_src = new_src.rstrip("\n")
if new_src == src:
raise NothingChanged
return new_src


def format_ipynb_string(
src_contents: str, *, mode: Mode, fast: bool = False
) -> FileContent:
nb = json.loads(src_contents)
trailing_newline = src_contents[-1] == "\n"
modified = False
for _, cell in enumerate(nb["cells"]):
if cell.get("cell_type", None) == "code":
try:
src = "".join(cell["source"])
new_src = format_cell(src, mode=mode)
except NothingChanged:
pass
else:
cell["source"] = new_src.splitlines(keepends=True)
modified = True

if modified:
res = json.dumps(nb, indent=1, ensure_ascii=False)
if trailing_newline:
res = res + "\n"
if res == src_contents:
raise NothingChanged
return res
else:
raise NothingChanged


def format_str(src_contents: str, *, mode: Mode) -> FileContent:
"""Reformat a string and return new contents.
Expand Down

0 comments on commit 20872fa

Please sign in to comment.