Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to skip the first line of source code #3299

Merged
merged 12 commits into from Oct 6, 2022
2 changes: 2 additions & 0 deletions CHANGES.md
Expand Up @@ -28,6 +28,8 @@
<!-- Changes to how Black can be configured -->

- `.ipynb_checkpoints` directories are now excluded by default (#3293)
- Add `--skip-source-first-line` / `-x` option to ignore the first line of source code
while formatting (#3299)

### Packaging

Expand Down
3 changes: 3 additions & 0 deletions docs/usage_and_configuration/black_as_a_server.md
Expand Up @@ -50,6 +50,9 @@ is rejected with `HTTP 501` (Not Implemented).
The headers controlling how source code is formatted are:

- `X-Line-Length`: corresponds to the `--line-length` command line flag.
- `X-Skip-Source-First-Line`: corresponds to the `--skip-source-first-line` command line
flag. If present and its value is not an empty string, the first line of the source
code will be ignored.
- `X-Skip-String-Normalization`: corresponds to the `--skip-string-normalization`
command line flag. If present and its value is not the empty string, no string
normalization will be performed.
Expand Down
13 changes: 13 additions & 0 deletions src/black/__init__.py
Expand Up @@ -248,6 +248,12 @@ def validate_regex(
),
default=[],
)
@click.option(
"-x",
"--skip-source-first-line",
is_flag=True,
help="Skip the first line of the source code.",
)
@click.option(
"-S",
"--skip-string-normalization",
Expand Down Expand Up @@ -428,6 +434,7 @@ def main( # noqa: C901
pyi: bool,
ipynb: bool,
python_cell_magics: Sequence[str],
skip_source_first_line: bool,
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
Expand Down Expand Up @@ -528,6 +535,7 @@ def main( # noqa: C901
line_length=line_length,
is_pyi=pyi,
is_ipynb=ipynb,
skip_source_first_line=skip_source_first_line,
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
Expand Down Expand Up @@ -790,7 +798,10 @@ def format_file_in_place(
mode = replace(mode, is_ipynb=True)

then = datetime.utcfromtimestamp(src.stat().st_mtime)
header = b""
with open(src, "rb") as buf:
if mode.skip_source_first_line:
header = buf.readline()
src_contents, encoding, newline = decode_bytes(buf.read())
try:
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
Expand All @@ -800,6 +811,8 @@ def format_file_in_place(
raise ValueError(
f"File '{src}' cannot be parsed as valid Jupyter notebook."
) from None
src_contents = header.decode(encoding) + src_contents
dst_contents = header.decode(encoding) + dst_contents

if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
Expand Down
2 changes: 2 additions & 0 deletions src/black/mode.py
Expand Up @@ -170,6 +170,7 @@ class Mode:
string_normalization: bool = True
is_pyi: bool = False
is_ipynb: bool = False
skip_source_first_line: bool = False
magic_trailing_comma: bool = True
experimental_string_processing: bool = False
python_cell_magics: Set[str] = field(default_factory=set)
Expand Down Expand Up @@ -208,6 +209,7 @@ def get_cache_key(self) -> str:
str(int(self.string_normalization)),
str(int(self.is_pyi)),
str(int(self.is_ipynb)),
str(int(self.skip_source_first_line)),
str(int(self.magic_trailing_comma)),
str(int(self.experimental_string_processing)),
str(int(self.preview)),
Expand Down
16 changes: 16 additions & 0 deletions src/blackd/__init__.py
Expand Up @@ -30,6 +30,7 @@
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
Expand All @@ -40,6 +41,7 @@
PROTOCOL_VERSION_HEADER,
LINE_LENGTH_HEADER,
PYTHON_VARIANT_HEADER,
SKIP_SOURCE_FIRST_LINE,
SKIP_STRING_NORMALIZATION_HEADER,
SKIP_MAGIC_TRAILING_COMMA,
PREVIEW,
Expand Down Expand Up @@ -111,6 +113,9 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
skip_magic_trailing_comma = bool(
request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
)
skip_source_first_line = bool(
request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
)
preview = bool(request.headers.get(PREVIEW, False))
fast = False
if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
Expand All @@ -119,6 +124,7 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
target_versions=versions,
is_pyi=pyi,
line_length=line_length,
skip_source_first_line=skip_source_first_line,
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
preview=preview,
Expand All @@ -128,6 +134,12 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
req_str = req_bytes.decode(charset)
then = datetime.utcnow()

header = ""
if skip_source_first_line:
first_newline_position: int = req_str.find("\n") + 1
header = req_str[:first_newline_position]
req_str = req_str[first_newline_position:]

loop = asyncio.get_event_loop()
formatted_str = await loop.run_in_executor(
executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
Expand All @@ -140,6 +152,10 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
if formatted_str == req_str:
raise black.NothingChanged

# Put the source first line back
req_str = header + req_str
formatted_str = header + formatted_str

# Only output the diff in the HTTP response
only_diff = bool(request.headers.get(DIFF_HEADER, False))
if only_diff:
Expand Down
2 changes: 2 additions & 0 deletions tests/data/miscellaneous/invalid_header.py
@@ -0,0 +1,2 @@
This is not valid Python syntax
y = "This is valid syntax"
24 changes: 24 additions & 0 deletions tests/test_black.py
Expand Up @@ -341,6 +341,30 @@ def test_string_quotes(self) -> None:
black.assert_equivalent(source, not_normalized)
black.assert_stable(source, not_normalized, mode=mode)

def test_skip_source_first_line(self) -> None:
aaossa marked this conversation as resolved.
Show resolved Hide resolved
source, _ = read_data("miscellaneous", "invalid_header")
tmp_file = Path(black.dump_to_file(source))
# Full source should fail (invalid syntax at header)
self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
# So, skipping the first line should work
result = BlackRunner().invoke(
black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
)
self.assertEqual(result.exit_code, 0)
with open(tmp_file, encoding="utf8") as f:
actual = f.read()
self.assertFormatEqual(source, actual)

def test_skip_source_first_line_when_mixing_newlines(self) -> None:
code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
with TemporaryDirectory() as workspace:
test_file = Path(workspace) / "skip_header.py"
test_file.write_bytes(code_mixing_newlines)
mode = replace(DEFAULT_MODE, skip_source_first_line=True)
ff(test_file, mode=mode, write_back=black.WriteBack.YES)
self.assertEqual(test_file.read_bytes(), expected)

def test_skip_magic_trailing_comma(self) -> None:
source, _ = read_data("simple_cases", "expression")
expected, _ = read_data(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_blackd.py
Expand Up @@ -177,6 +177,20 @@ async def test_blackd_invalid_line_length(self) -> None:
)
self.assertEqual(response.status, 400)

@unittest_run_loop
async def test_blackd_skip_first_source_line(self) -> None:
invalid_first_line = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
expected_result = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
response = await self.client.post("/", data=invalid_first_line)
self.assertEqual(response.status, 400)
response = await self.client.post(
"/",
data=invalid_first_line,
headers={blackd.SKIP_SOURCE_FIRST_LINE: "true"},
)
self.assertEqual(response.status, 200)
self.assertEqual(await response.read(), expected_result)

@unittest_run_loop
async def test_blackd_preview(self) -> None:
response = await self.client.post(
Expand Down