diff --git a/CHANGES.md b/CHANGES.md index 4a8ee0e692c..85feb1a7600 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,8 @@ - Fixed Python 3.10 support on platforms without ProcessPoolExecutor (#2631) - Fixed `match` statements with open sequence subjects, like `match a, b:` or `match a, *b:` (#2639) (#2659) +- Fixed `match`/`case` statements that contain `match`/`case` soft keywords multiple + times, like `match re.match()` (#2661) - Fixed assignment to environment variables in Jupyter Notebooks (#2642) - Add `flake8-simplify` and `flake8-comprehensions` plugins (#2653) diff --git a/src/black/linegen.py b/src/black/linegen.py index 4cba4164fb3..f234913a161 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -127,7 +127,7 @@ def visit_stmt( """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, - `def`, `with`, `class`, `assert`, `match`, `case` and assignments. + `def`, `with`, `class`, `assert`, and assignments. The relevant Python language `keywords` for a given statement will be NAME leaves within it. This methods puts those on a separate line. @@ -142,6 +142,14 @@ def visit_stmt( yield from self.visit(child) + def visit_match_case(self, node: Node) -> Iterator[Line]: + """Visit either a match or case statement.""" + normalize_invisible_parens(node, parens_after=set()) + + yield from self.line() + for child in node.children: + yield from self.visit(child) + def visit_suite(self, node: Node) -> Iterator[Line]: """Visit a suite.""" if self.mode.is_pyi and is_stub_suite(node): @@ -294,8 +302,8 @@ def __post_init__(self) -> None: self.visit_decorated = self.visit_decorators # PEP 634 - self.visit_match_stmt = partial(v, keywords={"match"}, parens=Ø) - self.visit_case_block = partial(v, keywords={"case"}, parens=Ø) + self.visit_match_stmt = self.visit_match_case + self.visit_case_block = self.visit_match_case def transform_line( diff --git a/tests/data/pattern_matching_extras.py b/tests/data/pattern_matching_extras.py index 706148561a2..095c1a2b3bb 100644 --- a/tests/data/pattern_matching_extras.py +++ b/tests/data/pattern_matching_extras.py @@ -23,10 +23,10 @@ def func(match: case, case: match) -> case: match Something(): - case another: - ... case func(match, case): ... + case another: + ... match maybe, multiple: @@ -47,6 +47,33 @@ def func(match: case, case: match) -> case: match a, *b, c: case [*_]: - return "seq" + assert "seq" == _ case {}: - return "map" + assert "map" == b + + +match match( + case, + match( + match, case, match, looooooooooooooooooooooooooooooooooooong, match, case, match + ), + case, +): + case case( + match=case, + case=re.match( + loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong + ), + ): + pass + + case [a as match]: + pass + + case case: + pass + + +match match: + case case: + pass