Skip to content

Commit

Permalink
Merge pull request #2017 from XuehaiPan/black-pyi
Browse files Browse the repository at this point in the history
Fix `black` compatibility for `.pyi` type stub files
  • Loading branch information
staticdev committed Dec 10, 2022
2 parents 7ea25f0 + 1398e40 commit fc906c0
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 8 deletions.
12 changes: 8 additions & 4 deletions isort/output.py
Expand Up @@ -209,16 +209,20 @@ def sorted_imports(
break

if config.lines_after_imports != -1:
formatted_output[imports_tail:0] = [
"" for line in range(config.lines_after_imports)
]
lines_after_imports = config.lines_after_imports
if config.profile == "black" and extension == "pyi": # special case for black
lines_after_imports = 1
formatted_output[imports_tail:0] = ["" for line in range(lines_after_imports)]
elif extension != "pyi" and next_construct.startswith(STATEMENT_DECLARATIONS):
formatted_output[imports_tail:0] = ["", ""]
else:
formatted_output[imports_tail:0] = [""]

if config.lines_before_imports != -1:
formatted_output[:0] = ["" for line in range(config.lines_before_imports)]
lines_before_imports = config.lines_before_imports
if config.profile == "black" and extension == "pyi": # special case for black
lines_before_imports = 1
formatted_output[:0] = ["" for line in range(lines_before_imports)]

if parsed.place_imports:
new_out_lines = []
Expand Down
86 changes: 82 additions & 4 deletions tests/unit/profiles/test_black.py
Expand Up @@ -19,20 +19,25 @@ def black_format(code: str, is_pyi: bool = False, line_length: int = 88) -> str:
return code


def black_test(code: str, expected_output: str = ""):
def black_test(code: str, expected_output: str = "", *, is_pyi: bool = False, **config_kwargs):
"""Tests that the given code:
- Behaves the same when formatted multiple times with isort.
- Agrees with black formatting.
- Matches the desired output or itself if none is provided.
"""
expected_output = expected_output or code
config_kwargs = {
"extension": "pyi" if is_pyi else None,
"profile": "black",
**config_kwargs,
}

# output should stay consistent over multiple runs
output = isort.code(code, profile="black")
assert output == isort.code(code, profile="black")
output = isort.code(code, **config_kwargs)
assert output == isort.code(code, **config_kwargs)

# output should agree with black
black_output = black_format(output)
black_output = black_format(output, is_pyi=is_pyi)
assert output == black_output

# output should match expected output
Expand Down Expand Up @@ -369,3 +374,76 @@ def test_black_snippet_three():
DEFAULT_LINE_LENGTH = 88
""",
)


def test_black_pyi_file():
"""Test consistent code formatting between isort and black for `.pyi` files.
black only allows no more than two consecutive blank lines in a `.pyi` file.
"""

black_test(
"""# comment
import math
from typing import Sequence
import numpy as np
def add(a: np.ndarray, b: np.ndarray) -> np.ndarray:
...
def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray:
...
""",
"""# comment
import math
from typing import Sequence
import numpy as np
def add(a: np.ndarray, b: np.ndarray) -> np.ndarray:
...
def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray:
...
""",
is_pyi=False,
lines_before_imports=2,
lines_after_imports=2,
)

black_test(
"""# comment
import math
from typing import Sequence
import numpy as np
def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
""",
"""# comment
import math
from typing import Sequence
import numpy as np
def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
""",
is_pyi=True,
lines_before_imports=2, # will be ignored
lines_after_imports=2, # will be ignored
)

0 comments on commit fc906c0

Please sign in to comment.