Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return annotation brackets #2990

Merged
merged 12 commits into from Apr 9, 2022
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -14,6 +14,7 @@

<!-- Changes that affect Black's preview style -->

- Parentheses around return annotations are now managed (#2990)
- Remove unnecessary parentheses from `with` statements (#2926)

### _Blackd_
Expand Down
31 changes: 30 additions & 1 deletion src/black/linegen.py
Expand Up @@ -144,6 +144,33 @@ def visit_stmt(

yield from self.visit(child)

def visit_funcdef(self, node: Node) -> Iterator[Line]:
"""Visit function definition."""
if Preview.annotation_parens not in self.mode:
yield from self.visit_stmt(node, keywords={"def"}, parens=set())
else:
yield from self.line()

# Remove redundant brackets around return type annotation.
is_return_annotation = False
for child in node.children:
if child.type == token.RARROW:
is_return_annotation = True
elif is_return_annotation:
if child.type == syms.atom and child.children[0].type == token.LPAR:
if maybe_make_parens_invisible_in_atom(
child,
parent=node,
remove_brackets_around_comma=False,
):
wrap_in_parentheses(node, child, visible=False)
else:
wrap_in_parentheses(node, child, visible=False)
is_return_annotation = False

for child in node.children:
yield from self.visit(child)

def visit_match_case(self, node: Node) -> Iterator[Line]:
"""Visit either a match or case statement."""
normalize_invisible_parens(node, parens_after=set(), preview=self.mode.preview)
Expand Down Expand Up @@ -326,7 +353,6 @@ def __post_init__(self) -> None:
else:
self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
Expand Down Expand Up @@ -478,7 +504,10 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator
current_leaves is body_leaves
and leaf.type in CLOSING_BRACKETS
and leaf.opening_bracket is matching_bracket
and isinstance(matching_bracket, Leaf)
):
ensure_visible(leaf)
ensure_visible(matching_bracket)
current_leaves = tail_leaves if body_leaves else head_leaves
current_leaves.append(leaf)
if current_leaves is head_leaves:
Expand Down
1 change: 1 addition & 0 deletions src/black/mode.py
Expand Up @@ -129,6 +129,7 @@ class Preview(Enum):
string_processing = auto()
remove_redundant_parens = auto()
one_element_subscript = auto()
annotation_parens = auto()


class Deprecated(UserWarning):
Expand Down
222 changes: 222 additions & 0 deletions tests/data/return_annotation_brackets.py
@@ -0,0 +1,222 @@
# Control
def double(a: int) -> int:
return 2*a

# Remove the brackets
def double(a: int) -> (int):
jpy-git marked this conversation as resolved.
Show resolved Hide resolved
return 2*a

# Some newline variations
def double(a: int) -> (
int):
return 2*a

def double(a: int) -> (int
):
return 2*a

def double(a: int) -> (
int
):
return 2*a

# Don't lose the comments
def double(a: int) -> ( # Hello
int
):
return 2*a

def double(a: int) -> (
int # Hello
):
return 2*a

# Really long annotations
def foo() -> (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
):
return 2

def foo() -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds:
return 2

def foo() -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds | intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds:
return 2

def foo(a: int, b: int, c: int,) -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds:
return 2

def foo(a: int, b: int, c: int,) -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds | intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds:
return 2

# Split args but no need to split return
def foo(a: int, b: int, c: int,) -> int:
return 2

# Deeply nested brackets
# with *interesting* spacing
jpy-git marked this conversation as resolved.
Show resolved Hide resolved
def double(a: int) -> (((((int))))):
return 2*a

def double(a: int) -> (
( (
((int)
)
)
)
):
return 2*a

def foo() -> (
( (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
)
)):
return 2

# Return type with commas
def foo() -> (
tuple[int, int, int]
):
return 2

def foo() -> tuple[loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong, loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong]:
return 2

# Magic trailing comma example
def foo() -> tuple[int, int, int,]:
return 2

# Long string example
def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass

# output
# Control
def double(a: int) -> int:
return 2 * a


# Remove the brackets
def double(a: int) -> int:
return 2 * a


# Some newline variations
def double(a: int) -> int:
return 2 * a


def double(a: int) -> int:
return 2 * a


def double(a: int) -> int:
return 2 * a


# Don't lose the comments
def double(a: int) -> int: # Hello
return 2 * a


def double(a: int) -> int: # Hello
return 2 * a


# Really long annotations
def foo() -> (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
):
return 2


def foo() -> (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
):
return 2


def foo() -> (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
| intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
):
return 2


def foo(
a: int,
b: int,
c: int,
) -> intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds:
return 2


def foo(
a: int,
b: int,
c: int,
) -> (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
| intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
):
return 2


# Split args but no need to split return
def foo(
a: int,
b: int,
c: int,
) -> int:
return 2


# Deeply nested brackets
# with *interesting* spacing
def double(a: int) -> int:
return 2 * a


def double(a: int) -> int:
return 2 * a


def foo() -> (
intsdfsafafafdfdsasdfsfsdfasdfafdsafdfdsfasdskdsdsfdsafdsafsdfdasfffsfdsfdsafafhdskfhdsfjdslkfdlfsdkjhsdfjkdshfkljds
):
return 2


# Return type with commas
def foo() -> tuple[int, int, int]:
return 2


def foo() -> (
tuple[
loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong,
loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong,
loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong,
]
):
return 2


# Magic trailing comma example
def foo() -> (
tuple[
int,
int,
int,
]
):
return 2


# Long string example
def frobnicate() -> (
"ThisIsTrulyUnreasonablyExtremelyLongClassName |"
" list[ThisIsTrulyUnreasonablyExtremelyLongClassName]"
):
pass
1 change: 1 addition & 0 deletions tests/test_format.py
Expand Up @@ -83,6 +83,7 @@
"remove_except_parens",
"remove_for_brackets",
"one_element_subscript",
"return_annotation_brackets",
]

SOURCES: List[str] = [
Expand Down