Skip to content

Commit

Permalink
Correctly handle trailing commas that are inside a line's leading non…
Browse files Browse the repository at this point in the history
…-nested parens (#3370)

- Fixes #1671
- Fixes #3229
  • Loading branch information
yilei committed Nov 9, 2022
1 parent ffaaf48 commit 8091b25
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Expand Up @@ -17,6 +17,8 @@
- Enforce empty lines before classes and functions with sticky leading comments (#3302)
- Implicitly concatenated strings used as function args are now wrapped inside
parentheses (#3307)
- Correctly handle trailing commas that are inside a line's leading non-nested parens
(#3370)

### Configuration

Expand Down
34 changes: 33 additions & 1 deletion src/black/brackets.py
Expand Up @@ -2,7 +2,7 @@

import sys
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union

if sys.version_info < (3, 8):
from typing_extensions import Final
Expand Down Expand Up @@ -340,3 +340,35 @@ def max_delimiter_priority_in_atom(node: LN) -> Priority:

except ValueError:
return 0


def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]:
"""Return leaves that are inside matching brackets.
The input `leaves` can have non-matching brackets at the head or tail parts.
Matching brackets are included.
"""
try:
# Only track brackets from the first opening bracket to the last closing
# bracket.
start_index = next(
i for i, l in enumerate(leaves) if l.type in OPENING_BRACKETS
)
end_index = next(
len(leaves) - i
for i, l in enumerate(reversed(leaves))
if l.type in CLOSING_BRACKETS
)
except StopIteration:
return set()
ids = set()
depth = 0
for i in range(end_index, start_index - 1, -1):
leaf = leaves[i]
if leaf.type in CLOSING_BRACKETS:
depth += 1
if depth > 0:
ids.add(id(leaf))
if leaf.type in OPENING_BRACKETS:
depth -= 1
return ids
70 changes: 57 additions & 13 deletions src/black/linegen.py
Expand Up @@ -2,10 +2,16 @@
Generating lines of code.
"""
import sys
from enum import Enum, auto
from functools import partial, wraps
from typing import Collection, Iterator, List, Optional, Set, Union, cast

from black.brackets import COMMA_PRIORITY, DOT_PRIORITY, max_delimiter_priority_in_atom
from black.brackets import (
COMMA_PRIORITY,
DOT_PRIORITY,
get_leaves_inside_matching_brackets,
max_delimiter_priority_in_atom,
)
from black.comments import FMT_OFF, generate_comments, list_comments
from black.lines import (
Line,
Expand Down Expand Up @@ -561,6 +567,12 @@ def _rhs(
yield line


class _BracketSplitComponent(Enum):
head = auto()
body = auto()
tail = auto()


def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
Expand Down Expand Up @@ -591,9 +603,15 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator
if not matching_bracket:
raise CannotSplit("No brackets found")

head = bracket_split_build_line(head_leaves, line, matching_bracket)
body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
head = bracket_split_build_line(
head_leaves, line, matching_bracket, component=_BracketSplitComponent.head
)
body = bracket_split_build_line(
body_leaves, line, matching_bracket, component=_BracketSplitComponent.body
)
tail = bracket_split_build_line(
tail_leaves, line, matching_bracket, component=_BracketSplitComponent.tail
)
bracket_split_succeeded_or_raise(head, body, tail)
for result in (head, body, tail):
if result:
Expand Down Expand Up @@ -639,9 +657,15 @@ def right_hand_split(
tail_leaves.reverse()
body_leaves.reverse()
head_leaves.reverse()
head = bracket_split_build_line(head_leaves, line, opening_bracket)
body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
head = bracket_split_build_line(
head_leaves, line, opening_bracket, component=_BracketSplitComponent.head
)
body = bracket_split_build_line(
body_leaves, line, opening_bracket, component=_BracketSplitComponent.body
)
tail = bracket_split_build_line(
tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail
)
bracket_split_succeeded_or_raise(head, body, tail)
if (
Feature.FORCE_OPTIONAL_PARENTHESES not in features
Expand Down Expand Up @@ -715,15 +739,23 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None


def bracket_split_build_line(
leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
leaves: List[Leaf],
original: Line,
opening_bracket: Leaf,
*,
component: _BracketSplitComponent,
) -> Line:
"""Return a new line with given `leaves` and respective comments from `original`.
If `is_body` is True, the result line is one-indented inside brackets and as such
has its first leaf's prefix normalized and a trailing comma added when expected.
If it's the head component, brackets will be tracked so trailing commas are
respected.
If it's the body component, the result line is one-indented inside brackets and as
such has its first leaf's prefix normalized and a trailing comma added when
expected.
"""
result = Line(mode=original.mode, depth=original.depth)
if is_body:
if component is _BracketSplitComponent.body:
result.inside_brackets = True
result.depth += 1
if leaves:
Expand Down Expand Up @@ -761,12 +793,24 @@ def bracket_split_build_line(
leaves.insert(i + 1, new_comma)
break

leaves_to_track: Set[LeafID] = set()
if (
Preview.handle_trailing_commas_in_head in original.mode
and component is _BracketSplitComponent.head
):
leaves_to_track = get_leaves_inside_matching_brackets(leaves)
# Populate the line
for leaf in leaves:
result.append(leaf, preformatted=True)
result.append(
leaf,
preformatted=True,
track_bracket=id(leaf) in leaves_to_track,
)
for comment_after in original.comments_after(leaf):
result.append(comment_after, preformatted=True)
if is_body and should_split_line(result, opening_bracket):
if component is _BracketSplitComponent.body and should_split_line(
result, opening_bracket
):
result.should_split_rhs = True
return result

Expand Down
6 changes: 4 additions & 2 deletions src/black/lines.py
Expand Up @@ -53,7 +53,9 @@ class Line:
should_split_rhs: bool = False
magic_trailing_comma: Optional[Leaf] = None

def append(self, leaf: Leaf, preformatted: bool = False) -> None:
def append(
self, leaf: Leaf, preformatted: bool = False, track_bracket: bool = False
) -> None:
"""Add a new `leaf` to the end of the line.
Unless `preformatted` is True, the `leaf` will receive a new consistent
Expand All @@ -75,7 +77,7 @@ def append(self, leaf: Leaf, preformatted: bool = False) -> None:
leaf.prefix += whitespace(
leaf, complex_subscript=self.is_complex_subscript(leaf)
)
if self.inside_brackets or not preformatted:
if self.inside_brackets or not preformatted or track_bracket:
self.bracket_tracker.mark(leaf)
if self.mode.magic_trailing_comma:
if self.has_magic_trailing_comma(leaf):
Expand Down
1 change: 1 addition & 0 deletions src/black/mode.py
Expand Up @@ -151,6 +151,7 @@ class Preview(Enum):

annotation_parens = auto()
empty_lines_before_class_or_def_with_leading_comments = auto()
handle_trailing_commas_in_head = auto()
long_docstring_quotes_on_newline = auto()
normalize_docstring_quotes_and_prefixes_properly = auto()
one_element_subscript = auto()
Expand Down
40 changes: 40 additions & 0 deletions tests/data/preview/skip_magic_trailing_comma.py
Expand Up @@ -15,6 +15,37 @@
# Except single element tuples
small_tuple = (1,)

# Trailing commas in multiple chained non-nested parens.
zero(
one,
).two(
three,
).four(
five,
)

func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5)

(
a,
b,
c,
d,
) = func1(
arg1
) and func2(arg2)

func(
argument1,
(
one,
two,
),
argument4,
argument5,
argument6,
)

# output
# We should not remove the trailing comma in a single-element subscript.
a: tuple[int,]
Expand All @@ -32,3 +63,12 @@

# Except single element tuples
small_tuple = (1,)

# Trailing commas in multiple chained non-nested parens.
zero(one).two(three).four(five)

func1(arg1).func2(arg2).func3(arg3).func4(arg4).func5(arg5)

(a, b, c, d) = func1(arg1) and func2(arg2)

func(argument1, (one, two), argument4, argument5, argument6)
74 changes: 74 additions & 0 deletions tests/data/preview/trailing_commas_in_leading_parts.py
@@ -0,0 +1,74 @@
zero(one,).two(three,).four(five,)

func1(arg1).func2(arg2,).func3(arg3).func4(arg4,).func5(arg5)

# Inner one-element tuple shouldn't explode
func1(arg1).func2(arg1, (one_tuple,)).func3(arg3)

(a, b, c, d,) = func1(arg1) and func2(arg2)


# Example from https://github.com/psf/black/issues/3229
def refresh_token(self, device_family, refresh_token, api_key):
return self.orchestration.refresh_token(
data={
"refreshToken": refresh_token,
},
api_key=api_key,
)["extensions"]["sdk"]["token"]


# Edge case where a bug in a working-in-progress version of
# https://github.com/psf/black/pull/3370 causes an infinite recursion.
assert (
long_module.long_class.long_func().another_func()
== long_module.long_class.long_func()["some_key"].another_func(arg1)
)


# output


zero(
one,
).two(
three,
).four(
five,
)

func1(arg1).func2(
arg2,
).func3(arg3).func4(
arg4,
).func5(arg5)

# Inner one-element tuple shouldn't explode
func1(arg1).func2(arg1, (one_tuple,)).func3(arg3)

(
a,
b,
c,
d,
) = func1(
arg1
) and func2(arg2)


# Example from https://github.com/psf/black/issues/3229
def refresh_token(self, device_family, refresh_token, api_key):
return self.orchestration.refresh_token(
data={
"refreshToken": refresh_token,
},
api_key=api_key,
)["extensions"]["sdk"]["token"]


# Edge case where a bug in a working-in-progress version of
# https://github.com/psf/black/pull/3370 causes an infinite recursion.
assert (
long_module.long_class.long_func().another_func()
== long_module.long_class.long_func()["some_key"].another_func(arg1)
)
29 changes: 29 additions & 0 deletions tests/data/simple_cases/function_trailing_comma.py
Expand Up @@ -49,6 +49,17 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
):
pass


# Make sure inner one-element tuple won't explode
some_module.some_function(
argument1, (one_element_tuple,), argument4, argument5, argument6
)

# Inner trailing comma causes outer to explode
some_module.some_function(
argument1, (one, two,), argument4, argument5, argument6
)

# output

def f(
Expand Down Expand Up @@ -151,3 +162,21 @@ def func() -> (
)
):
pass


# Make sure inner one-element tuple won't explode
some_module.some_function(
argument1, (one_element_tuple,), argument4, argument5, argument6
)

# Inner trailing comma causes outer to explode
some_module.some_function(
argument1,
(
one,
two,
),
argument4,
argument5,
argument6,
)

0 comments on commit 8091b25

Please sign in to comment.