diff --git a/isort/output.py b/isort/output.py index c59be936d..3cb3c08b0 100644 --- a/isort/output.py +++ b/isort/output.py @@ -209,16 +209,20 @@ def sorted_imports( break if config.lines_after_imports != -1: - formatted_output[imports_tail:0] = [ - "" for line in range(config.lines_after_imports) - ] + lines_after_imports = config.lines_after_imports + if config.profile == "black" and extension == "pyi": # special case for black + lines_after_imports = 1 + formatted_output[imports_tail:0] = ["" for line in range(lines_after_imports)] elif extension != "pyi" and next_construct.startswith(STATEMENT_DECLARATIONS): formatted_output[imports_tail:0] = ["", ""] else: formatted_output[imports_tail:0] = [""] if config.lines_before_imports != -1: - formatted_output[:0] = ["" for line in range(config.lines_before_imports)] + lines_before_imports = config.lines_before_imports + if config.profile == "black" and extension == "pyi": # special case for black + lines_before_imports = 1 + formatted_output[:0] = ["" for line in range(lines_before_imports)] if parsed.place_imports: new_out_lines = [] diff --git a/tests/unit/profiles/test_black.py b/tests/unit/profiles/test_black.py index 05444b8cc..a4bb2d52a 100644 --- a/tests/unit/profiles/test_black.py +++ b/tests/unit/profiles/test_black.py @@ -19,20 +19,25 @@ def black_format(code: str, is_pyi: bool = False, line_length: int = 88) -> str: return code -def black_test(code: str, expected_output: str = ""): +def black_test(code: str, expected_output: str = "", *, is_pyi: bool = False, **config_kwargs): """Tests that the given code: - Behaves the same when formatted multiple times with isort. - Agrees with black formatting. - Matches the desired output or itself if none is provided. """ expected_output = expected_output or code + config_kwargs = { + "extension": "pyi" if is_pyi else None, + "profile": "black", + **config_kwargs, + } # output should stay consistent over multiple runs - output = isort.code(code, profile="black") - assert output == isort.code(code, profile="black") + output = isort.code(code, **config_kwargs) + assert output == isort.code(code, **config_kwargs) # output should agree with black - black_output = black_format(output) + black_output = black_format(output, is_pyi=is_pyi) assert output == black_output # output should match expected output @@ -369,3 +374,76 @@ def test_black_snippet_three(): DEFAULT_LINE_LENGTH = 88 """, ) + + +def test_black_pyi_file(): + """Test consistent code formatting between isort and black for `.pyi` files. + + black only allows no more than two consecutive blank lines in a `.pyi` file. + """ + + black_test( + """# comment + + +import math + +from typing import Sequence +import numpy as np + + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... + + +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... +""", + """# comment + + +import math +from typing import Sequence + +import numpy as np + + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... + + +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... +""", + is_pyi=False, + lines_before_imports=2, + lines_after_imports=2, + ) + + black_test( + """# comment + + +import math + +from typing import Sequence +import numpy as np + + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +""", + """# comment + +import math +from typing import Sequence + +import numpy as np + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +""", + is_pyi=True, + lines_before_imports=2, # will be ignored + lines_after_imports=2, # will be ignored + )