From 20872fa9fb3ed4fc993794da0cc7b50b4f95df2c Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Fri, 2 Jul 2021 10:09:21 +0100 Subject: [PATCH] wip --- .gitignore | 1 + setup.py | 1 + src/black/__init__.py | 70 +++++++ src/black/handle_ipynb_magics.py | 262 ++++++++++++++++++++++++++ src/black/mode.py | 2 + tests/data/notebook_for_testing.ipynb | 121 ++++++++++++ tests/test_ipynb.py | 99 ++++++++++ 7 files changed, 556 insertions(+) create mode 100644 src/black/handle_ipynb_magics.py create mode 100644 tests/data/notebook_for_testing.ipynb create mode 100644 tests/test_ipynb.py diff --git a/.gitignore b/.gitignore index ab796ce4cd0..f81bce8fd4e 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ src/_black_version.py *.swp .hypothesis/ venv/ +.ipynb_checkpoints/ diff --git a/setup.py b/setup.py index 5549ae35342..4e024fdc2ec 100644 --- a/setup.py +++ b/setup.py @@ -87,6 +87,7 @@ def get_long_description() -> str: "colorama": ["colorama>=0.4.3"], "python2": ["typed-ast>=1.4.2"], "uvloop": ["uvloop>=0.15.2"], + "jupyter": ["ipython>=7.8.0", "tokenize-rt>=3.2.0"], }, test_suite="tests.test_black", classifiers=[ diff --git a/src/black/__init__.py b/src/black/__init__.py index 8e2123d50cc..f442ff85002 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1,4 +1,5 @@ import asyncio +import json from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from contextlib import contextmanager from datetime import datetime @@ -46,6 +47,12 @@ from black.files import wrap_stream_for_windows from black.parsing import InvalidInput # noqa F401 from black.parsing import lib2to3_parse, parse_ast, stringify_ast +from black.handle_ipynb_magics import ( + mask_cell, + unmask_cell, + remove_trailing_semicolon, + put_trailing_semicolon_back, +) # lib2to3 fork @@ -196,6 +203,14 @@ def validate_regex( " when piping source on standard input)." ), ) +@click.option( + "--ipynb", + is_flag=True, + help=( + "Format all input files like ipynb notebooks regardless of file extension " + "(useful when piping source on standard input)." + ), +) @click.option( "-S", "--skip-string-normalization", @@ -354,6 +369,7 @@ def main( color: bool, fast: bool, pyi: bool, + ipynb: bool, skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, @@ -390,6 +406,7 @@ def main( target_versions=versions, line_length=line_length, is_pyi=pyi, + is_ipynb=ipynb, string_normalization=not skip_string_normalization, magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, @@ -584,6 +601,8 @@ def reformat_one( if is_stdin: if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): changed = Changed.YES else: @@ -732,6 +751,8 @@ def format_file_in_place( """ if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -825,6 +846,9 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. `mode` is passed to :func:`format_str`. """ + if mode.is_ipynb: + return format_ipynb_string(src_contents, mode=mode, fast=fast) + if not src_contents.strip(): raise NothingChanged @@ -848,6 +872,52 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo return dst_contents +def format_cell(src: str, *, mode: Mode) -> str: + src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon( + src + ) + try: + masked_cell, replacements = mask_cell(src_without_trailing_semicolon) + except SyntaxError: + # Don't format, might be automagic or multi-line magic. + raise NothingChanged + formatted_masked_cell = format_str(masked_cell, mode=mode) + formatted_cell = unmask_cell(formatted_masked_cell, replacements) + new_src = put_trailing_semicolon_back(formatted_cell, has_trailing_semicolon) + new_src = new_src.rstrip("\n") + if new_src == src: + raise NothingChanged + return new_src + + +def format_ipynb_string( + src_contents: str, *, mode: Mode, fast: bool = False +) -> FileContent: + nb = json.loads(src_contents) + trailing_newline = src_contents[-1] == "\n" + modified = False + for _, cell in enumerate(nb["cells"]): + if cell.get("cell_type", None) == "code": + try: + src = "".join(cell["source"]) + new_src = format_cell(src, mode=mode) + except NothingChanged: + pass + else: + cell["source"] = new_src.splitlines(keepends=True) + modified = True + + if modified: + res = json.dumps(nb, indent=1, ensure_ascii=False) + if trailing_newline: + res = res + "\n" + if res == src_contents: + raise NothingChanged + return res + else: + raise NothingChanged + + def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. diff --git a/src/black/handle_ipynb_magics.py b/src/black/handle_ipynb_magics.py new file mode 100644 index 00000000000..2579719dd15 --- /dev/null +++ b/src/black/handle_ipynb_magics.py @@ -0,0 +1,262 @@ +import ast +from typing import Dict + +import secrets +from tokenize_rt import ( + src_to_tokens, + tokens_to_src, + NON_CODING_TOKENS, + reversed_enumerate, +) +from typing import NamedTuple, List, Tuple +import collections + +from typing import Optional + + +class Replacement(NamedTuple): + mask: str + src: str + + +class UnsupportedMagic(UserWarning): + """Raise when Magic (e.g. `a = b??`) is not supported.""" + + +def remove_trailing_semicolon(src: str) -> Tuple[str, bool]: + # ok, let's do this one first + tokens = src_to_tokens(src) + trailing_semicolon = False + for idx, token in reversed_enumerate(tokens): + if token.name in NON_CODING_TOKENS or token.name == "NEWLINE" or not token.src: + continue + if token.name == "OP" and token.src == ";": + del tokens[idx] + trailing_semicolon = True + break + if not trailing_semicolon: + return src, False + return tokens_to_src(tokens), True + + +def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str: + if not has_trailing_semicolon: + return src + tokens = src_to_tokens(src) + for idx, token in reversed_enumerate(tokens): + if token.name in NON_CODING_TOKENS or token.name == "NEWLINE" or not token.src: + continue + tokens[idx] = token._replace(src=token.src + ";") + break + else: # pragma: nocover + raise AssertionError("Unreachable code") + return str(tokens_to_src(tokens)) + + +def mask_cell(src: str) -> Tuple[str, List[Replacement]]: + replacements: List[Replacement] = [] + try: + ast.parse(src) + except SyntaxError: + # Might be able to parse it with IPython + pass + else: + # Syntax is fine, nothing to mask + return src, replacements + + from IPython.core.inputtransformer2 import TransformerManager + + transformer_manager = TransformerManager() + transformed = transformer_manager.transform_cell(src) + + transformed, cell_magic_replacements = replace_cell_magics(transformed) + replacements += cell_magic_replacements + + transformed = transformer_manager.transform_cell(transformed) + try: + transformed, magic_replacements = replace_magics(transformed) + except UnsupportedMagic: + # will be ignored upstream + raise SyntaxError + + replacements += magic_replacements + + return transformed, replacements + + +def get_token(src: str, *, is_cell_magic: bool = False) -> str: + token = secrets.token_hex(3) + while token in src: # pragma: nocover + token = secrets.token_hex(3) + if is_cell_magic: + return f"# {token}" + return f'str("{token}")' + + +def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]: + replacements: List[Replacement] = [] + + tree = ast.parse(src) + + cell_magic_finder = CellMagicFinder() + cell_magic_finder.visit(tree) + if not cell_magic_finder.header: + return src, replacements + mask = get_token(src, is_cell_magic=True) + replacements.append(Replacement(mask=mask, src=cell_magic_finder.header)) + return f"{mask}\n{cell_magic_finder.body}", replacements + + +def replace_magics(src: str) -> Tuple[str, List[Replacement]]: + replacements = [] + + tree = ast.parse(src) + + magic_finder = MagicFinder() + magic_finder.visit(tree) + new_srcs = [] + for i, line in enumerate(src.splitlines(), start=1): + if i in magic_finder.magics: + magics = magic_finder.magics[i] + if len(magics) != 1: + raise UnsupportedMagic + col_offset, magic = magic_finder.magics[i][0] + mask = get_token(src) + replacements.append(Replacement(mask=mask, src=magic)) + line = line[:col_offset] + mask + new_srcs.append(line) + return "\n".join(new_srcs), replacements + + +def unmask_cell(src: str, replacements: List[Replacement]) -> str: + for replacement in replacements: + src = src.replace(replacement.mask, replacement.src) + return src + + +def _is_ipython_magic(node: ast.expr) -> bool: + """Check if attribute is IPython magic.""" + return ( + isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "get_ipython" + ) + + +class CellMagicFinder(ast.NodeVisitor): + """Find cell magics.""" + + def __init__(self) -> None: + """Record where cell magics occur.""" + self.header: Optional[str] = None + self.body: Optional[str] = None + + def visit_Expr(self, node: ast.Expr) -> None: # pylint: disable=C0103 + """ + Find cell magic, extract header and body. + Raises + ------ + AssertionError + Defensive check. + """ + if ( + isinstance(node.value, ast.Call) + and _is_ipython_magic(node.value.func) + and isinstance(node.value.func, ast.Attribute) + and node.value.func.attr == "run_cell_magic" + ): + args = [] + for arg in node.value.args: + assert isinstance(arg, ast.Str) + args.append(arg.s) + header: Optional[str] = f"%%{args[0]}" + if args[1]: + assert header is not None + header += f" {args[1]}" + self.header = header + self.body = args[2] + self.generic_visit(node) + + +class MagicFinder(ast.NodeVisitor): + """Visit cell to look for get_ipython calls.""" + + def __init__(self) -> None: + """Magics will record where magics occur.""" + self.magics: Dict[int, List[Tuple[int, str]]] = collections.defaultdict(list) + + def visit_Assign(self, node: ast.Assign) -> None: # pylint: disable=C0103,R0912 + """ + Get source to replace ipython magic with. + Parameters + ---------- + node + Function call. + Raises + ------ + AssertionError + Defensive check. + """ + if ( + isinstance(node.value, ast.Call) + and _is_ipython_magic(node.value.func) + and isinstance(node.value.func, ast.Attribute) + and node.value.func.attr == "getoutput" + ): + args = [] + for arg in node.value.args: + assert isinstance(arg, ast.Str) + args.append(arg.s) + assert args + src = f"!{args[0]}" + self.magics[node.value.lineno].append( + ( + node.value.col_offset, + src, + ) + ) + self.generic_visit(node) + + def visit_Expr(self, node: ast.Expr) -> None: # pylint: disable=C0103,R0912 + """ + Get source to replace ipython magic with. + Parameters + ---------- + node + Function call. + Raises + ------ + AssertionError + Defensive check. + """ + if isinstance(node.value, ast.Call) and _is_ipython_magic(node.value.func): + assert isinstance(node.value.func, ast.Attribute) # help mypy + args = [] + for arg in node.value.args: + assert isinstance(arg, ast.Str) + args.append(arg.s) + assert args + if node.value.func.attr == "run_line_magic": + if args[0] == "pinfo": + src = f"?{args[1]}" + elif args[0] == "pinfo2": + src = f"??{args[1]}" + else: + src = f"%{args[0]}" + if args[1]: + assert src is not None + src += f" {args[1]}" + elif node.value.func.attr == "system": + src = f"!{args[0]}" + elif node.value.func.attr == "getoutput": + src = f"!!{args[0]}" + else: + raise UnsupportedMagic + self.magics[node.value.lineno].append( + ( + node.value.col_offset, + src, + ) + ) + self.generic_visit(node) diff --git a/src/black/mode.py b/src/black/mode.py index e2ce322da5c..0b7624eaf8a 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -101,6 +101,7 @@ class Mode: line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True is_pyi: bool = False + is_ipynb: bool = False magic_trailing_comma: bool = True experimental_string_processing: bool = False @@ -117,6 +118,7 @@ def get_cache_key(self) -> str: str(self.line_length), str(int(self.string_normalization)), str(int(self.is_pyi)), + str(int(self.is_ipynb)), str(int(self.magic_trailing_comma)), str(int(self.experimental_string_processing)), ] diff --git a/tests/data/notebook_for_testing.ipynb b/tests/data/notebook_for_testing.ipynb new file mode 100644 index 00000000000..3d0572f3b41 --- /dev/null +++ b/tests/data/notebook_for_testing.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "skip-flake8" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import glob\n", + "\n", + "import nbqa" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Some markdown cell containing \\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "flake8-skip" + ] + }, + "outputs": [], + "source": [ + "%%time foo\n", + "def hello(name: str = \"world\\n\"):\n", + " \"\"\"\n", + " Greet user.\n", + "\n", + " Examples\n", + " --------\n", + " >>> hello()\n", + " 'hello world\\\\n'\n", + "\n", + " >>> hello(\"goodbye\")\n", + " 'hello goodbye'\n", + " \"\"\"\n", + "\n", + " return 'hello {}'.format(name)\n", + "\n", + "\n", + "!ls\n", + "hello(3) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " %%bash\n", + "\n", + " pwd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from random import randint\n", + "\n", + "if __debug__:\n", + " %time randint(5,10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pprint\n", + "import sys\n", + "\n", + "if __debug__:\n", + " pretty_print_object = pprint.PrettyPrinter(\n", + " indent=4, width=80, stream=sys.stdout, compact=True, depth=5\n", + " )\n", + "\n", + "pretty_print_object.isreadable([\"Hello\", \"World\"])" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/test_ipynb.py b/tests/test_ipynb.py new file mode 100644 index 00000000000..a41c76a6a60 --- /dev/null +++ b/tests/test_ipynb.py @@ -0,0 +1,99 @@ +from black import NothingChanged, format_cell +from tests.util import DEFAULT_MODE +import pytest + + +def test_noop() -> None: + src = 'foo = "a"' + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_trailing_semicolon() -> None: + src = 'foo = "a" ;' + result = format_cell(src, mode=DEFAULT_MODE) + expected = 'foo = "a";' + assert result == expected + + +def test_trailing_semicolon_noop() -> None: + src = 'foo = "a";' + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_cell_magic() -> None: + src = "%%time\nfoo =bar" + result = format_cell(src, mode=DEFAULT_MODE) + expected = "%%time\nfoo = bar" + assert result == expected + + +def test_cell_magic_noop() -> None: + src = "%%time\n2 + 2" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +@pytest.mark.parametrize( + "src, expected", + ( + pytest.param("ls =!ls", "ls = !ls", id="System assignment"), + pytest.param("!ls\n'foo'", '!ls\n"foo"', id="System call"), + pytest.param("!!ls\n'foo'", '!!ls\n"foo"', id="Other system call"), + pytest.param("?str\n'foo'", '?str\n"foo"', id="Help"), + pytest.param("??str\n'foo'", '??str\n"foo"', id="Other help"), + pytest.param( + "%matplotlib inline\n'foo'", + '%matplotlib inline\n"foo"', + id="Line magic with argument", + ), + pytest.param("%time\n'foo'", '%time\n"foo"', id="Line magic without argument"), + ), +) +def test_magic(src: str, expected: str) -> None: + result = format_cell(src, mode=DEFAULT_MODE) + assert result == expected + + +def test_set_input() -> None: + src = "a = b??" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_magic_noop() -> None: + src = "ls = !ls" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_cell_magic_with_magic() -> None: + src = "%%t -n1\nls =!ls" + result = format_cell(src, mode=DEFAULT_MODE) + expected = "%%t -n1\nls = !ls" + assert result == expected + + +def test_cell_magic_with_magic_noop() -> None: + src = "%%t -n1\nls = !ls" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_automagic() -> None: + src = "pip install black" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_cell_magic_with_invalid_body() -> None: + src = "%%time\nif True" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE) + + +def test_empty_cell() -> None: + src = "" + with pytest.raises(NothingChanged): + format_cell(src, mode=DEFAULT_MODE)