diff --git a/CHANGES.md b/CHANGES.md index 30c00566b3c..f66451e2141 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -14,6 +14,7 @@ +- Parentheses around return annotations are now managed (#2990) - Remove unnecessary parentheses from `with` statements (#2926) ### _Blackd_ diff --git a/src/black/linegen.py b/src/black/linegen.py index 2cf9cf3130a..c2b0616d02f 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -144,6 +144,33 @@ def visit_stmt( yield from self.visit(child) + def visit_funcdef(self, node: Node) -> Iterator[Line]: + """Visit function definition.""" + if Preview.annotation_parens not in self.mode: + yield from self.visit_stmt(node, keywords={"def"}, parens=set()) + else: + yield from self.line() + + # Remove redundant brackets around return type annotation. + is_return_annotation = False + for child in node.children: + if child.type == token.RARROW: + is_return_annotation = True + elif is_return_annotation: + if child.type == syms.atom and child.children[0].type == token.LPAR: + if maybe_make_parens_invisible_in_atom( + child, + parent=node, + remove_brackets_around_comma=False, + ): + wrap_in_parentheses(node, child, visible=False) + else: + wrap_in_parentheses(node, child, visible=False) + is_return_annotation = False + + for child in node.children: + yield from self.visit(child) + def visit_match_case(self, node: Node) -> Iterator[Line]: """Visit either a match or case statement.""" normalize_invisible_parens(node, parens_after=set(), preview=self.mode.preview) @@ -326,7 +353,6 @@ def __post_init__(self) -> None: else: self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø) self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) - self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø) self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) @@ -478,7 +504,10 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator current_leaves is body_leaves and leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is matching_bracket + and isinstance(matching_bracket, Leaf) ): + ensure_visible(leaf) + ensure_visible(matching_bracket) current_leaves = tail_leaves if body_leaves else head_leaves current_leaves.append(leaf) if current_leaves is head_leaves: diff --git a/src/black/mode.py b/src/black/mode.py index 6b74c14b6de..34905702a54 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -129,6 +129,7 @@ class Preview(Enum): string_processing = auto() remove_redundant_parens = auto() one_element_subscript = auto() + annotation_parens = auto() class Deprecated(UserWarning): diff --git a/tests/data/return_annotation_brackets.py b/tests/data/return_annotation_brackets.py new file mode 100644 index 00000000000..27760bd51d7 --- /dev/null +++ b/tests/data/return_annotation_brackets.py @@ -0,0 +1,222 @@ +# Control +def double(a: int) -> int: + return 2*a + +# Remove the brackets +def double(a: int) -> (int): + return 2*a + +# Some newline variations +def double(a: int) -> ( + int): + return 2*a + +def double(a: int) -> (int +): + return 2*a + +def double(a: int) -> ( + int +): + return 2*a + +# Don't lose the comments +def double(a: int) -> ( # Hello + int +): + return 2*a + +def double(a: int) -> ( + int # Hello +): + return 2*a + +# Really long annotations +def foo() -> ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +): + return 2 + +def foo() -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds: + return 2 + +def foo() -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds | intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds: + return 2 + +def foo(a: int, b: int, c: int,) -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds: + return 2 + +def foo(a: int, b: int, c: int,) -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds | intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds: + return 2 + +# Split args but no need to split return +def foo(a: int, b: int, c: int,) -> int: + return 2 + +# Deeply nested brackets +# with *interesting* spacing +def double(a: int) -> (((((int))))): + return 2*a + +def double(a: int) -> ( + ( ( + ((int) + ) + ) + ) + ): + return 2*a + +def foo() -> ( + ( ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +) +)): + return 2 + +# Return type with commas +def foo() -> ( + tuple[int, int, int] +): + return 2 + +def foo() -> tuple[loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong]: + return 2 + +# Magic trailing comma example +def foo() -> tuple[int, int, int,]: + return 2 + +# Long string example +def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]": + pass + +# output +# Control +def double(a: int) -> int: + return 2 * a + + +# Remove the brackets +def double(a: int) -> int: + return 2 * a + + +# Some newline variations +def double(a: int) -> int: + return 2 * a + + +def double(a: int) -> int: + return 2 * a + + +def double(a: int) -> int: + return 2 * a + + +# Don't lose the comments +def double(a: int) -> int: # Hello + return 2 * a + + +def double(a: int) -> int: # Hello + return 2 * a + + +# Really long annotations +def foo() -> ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +): + return 2 + + +def foo() -> ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +): + return 2 + + +def foo() -> ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds + | intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +): + return 2 + + +def foo( + a: int, + b: int, + c: int, +) -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds: + return 2 + + +def foo( + a: int, + b: int, + c: int, +) -> ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds + | intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +): + return 2 + + +# Split args but no need to split return +def foo( + a: int, + b: int, + c: int, +) -> int: + return 2 + + +# Deeply nested brackets +# with *interesting* spacing +def double(a: int) -> int: + return 2 * a + + +def double(a: int) -> int: + return 2 * a + + +def foo() -> ( + intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds +): + return 2 + + +# Return type with commas +def foo() -> tuple[int, int, int]: + return 2 + + +def foo() -> ( + tuple[ + loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, + loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, + loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, + ] +): + return 2 + + +# Magic trailing comma example +def foo() -> ( + tuple[ + int, + int, + int, + ] +): + return 2 + + +# Long string example +def frobnicate() -> ( + "ThisIsTrulyUnreasonablyExtremelyLongClassName |" + " list[ThisIsTrulyUnreasonablyExtremelyLongClassName]" +): + pass diff --git a/tests/test_format.py b/tests/test_format.py index d80eaa730cd..6f71617eee6 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -83,6 +83,7 @@ "remove_except_parens", "remove_for_brackets", "one_element_subscript", + "return_annotation_brackets", ] SOURCES: List[str] = [