From d6e573f2311fb76bf8fb17d42424d65261222e85 Mon Sep 17 00:00:00 2001 From: Lorenzo Gasparini Date: Thu, 23 Jan 2020 13:33:40 +0100 Subject: [PATCH] Fix #1202: remove trailing comma from function arguments list --- black.py | 51 +++++++++++++++++++++------ blib2to3/pgen2/driver.py | 2 +- tests/data/collections.py | 4 +-- tests/data/comments7.py | 2 +- tests/data/expression.diff | 2 +- tests/data/expression.py | 2 +- tests/data/function.py | 2 +- tests/data/function2.py | 2 +- tests/data/function_trailing_comma.py | 17 +++++++-- 9 files changed, 63 insertions(+), 21 deletions(-) diff --git a/black.py b/black.py index 210120bae6c..3749c9006dc 100644 --- a/black.py +++ b/black.py @@ -1455,19 +1455,12 @@ def contains_multiline_strings(self) -> bool: return False - def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: - """Remove trailing comma if there is one and it's safe.""" - if not (self.leaves and self.leaves[-1].type == token.COMMA): - return False - - # We remove trailing commas only in the case of importing a - # single name from a module. - if not ( + def is_single_name_module_import(self, closing: Leaf) -> bool: + if ( self.leaves and self.is_import and len(self.leaves) > 4 and self.leaves[-1].type == token.COMMA - and closing.type in CLOSING_BRACKETS and self.leaves[-4].type == token.NAME and ( # regular `from foo import bar,` @@ -1487,10 +1480,46 @@ def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: ) and closing.type == token.RPAR ): + return True + + return False + + def is_function_argument_list(self, closing: Leaf) -> bool: + depth = closing.bracket_depth + 1 + opening = closing.opening_bracket + + try: + _opening_index = self.leaves.index(opening) + except ValueError: return False - self.remove_trailing_comma() - return True + for leaf in self.leaves[_opening_index + 1 :]: + if leaf is closing: + break + + if ( + leaf.bracket_depth == depth + and leaf.type == token.COMMA + and leaf.parent + and leaf.parent.type in {syms.arglist, syms.typedargslist} + ): + return True + + return False + + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: + """Remove trailing comma if there is one and it's safe.""" + if not (self.leaves and self.leaves[-1].type == token.COMMA): + return False + + if closing.type in CLOSING_BRACKETS and ( + self.is_single_name_module_import(closing) + or self.is_function_argument_list(closing) + ): + self.remove_trailing_comma() + return True + + return False def append_comment(self, comment: Leaf) -> bool: """Add an inline or standalone comment to the line.""" diff --git a/blib2to3/pgen2/driver.py b/blib2to3/pgen2/driver.py index 052c94883cf..81940f78f0f 100644 --- a/blib2to3/pgen2/driver.py +++ b/blib2to3/pgen2/driver.py @@ -128,7 +128,7 @@ def parse_stream(self, stream: IO[Text], debug: bool = False) -> NL: return self.parse_stream_raw(stream, debug) def parse_file( - self, filename: Path, encoding: Optional[Text] = None, debug: bool = False, + self, filename: Path, encoding: Optional[Text] = None, debug: bool = False ) -> NL: """Parse a file and return the syntax tree.""" with io.open(filename, "r", encoding=encoding) as stream: diff --git a/tests/data/collections.py b/tests/data/collections.py index ebe8d3c5200..40661b3373a 100644 --- a/tests/data/collections.py +++ b/tests/data/collections.py @@ -154,8 +154,8 @@ InstanceIds=[instance.id], WaiterConfig={"Delay": 5,} ) ec2client.get_waiter("instance_stopped").wait( - InstanceIds=[instance.id], WaiterConfig={"Delay": 5,}, + InstanceIds=[instance.id], WaiterConfig={"Delay": 5,} ) ec2client.get_waiter("instance_stopped").wait( - InstanceIds=[instance.id], WaiterConfig={"Delay": 5,}, + InstanceIds=[instance.id], WaiterConfig={"Delay": 5,} ) diff --git a/tests/data/comments7.py b/tests/data/comments7.py index 40951253f2e..ffc12a7f901 100644 --- a/tests/data/comments7.py +++ b/tests/data/comments7.py @@ -93,7 +93,7 @@ def func(): def func(): - c = call(0.0123, 0.0456, 0.0789, 0.0123, 0.0789, a[-1],) # type: ignore + c = call(0.0123, 0.0456, 0.0789, 0.0123, 0.0789, a[-1]) # type: ignore # The type: ignore exception only applies to line length, not # other types of formatting. diff --git a/tests/data/expression.diff b/tests/data/expression.diff index 629e1012f87..47493725533 100644 --- a/tests/data/expression.diff +++ b/tests/data/expression.diff @@ -212,7 +212,7 @@ + .filter( + models.Customer.account_id == account_id, models.Customer.email == email_address + ) -+ .order_by(models.Customer.id.asc(),) ++ .order_by(models.Customer.id.asc()) + .all() +) Ø = set() diff --git a/tests/data/expression.py b/tests/data/expression.py index 3bcf52b54c4..1163d340a63 100644 --- a/tests/data/expression.py +++ b/tests/data/expression.py @@ -459,7 +459,7 @@ async def f(): .filter( models.Customer.account_id == account_id, models.Customer.email == email_address ) - .order_by(models.Customer.id.asc(),) + .order_by(models.Customer.id.asc()) .all() ) Ø = set() diff --git a/tests/data/function.py b/tests/data/function.py index 51234a1e9b4..4754588e38d 100644 --- a/tests/data/function.py +++ b/tests/data/function.py @@ -230,7 +230,7 @@ def trailing_comma(): } -def f(a, **kwargs,) -> A: +def f(a, **kwargs) -> A: return ( yield from A( very_long_argument_name1=very_long_value_for_the_argument, diff --git a/tests/data/function2.py b/tests/data/function2.py index a6773d429cd..e08f62df1fa 100644 --- a/tests/data/function2.py +++ b/tests/data/function2.py @@ -25,7 +25,7 @@ def inner(): # output -def f(a, **kwargs,) -> A: +def f(a, **kwargs) -> A: with cache_dir(): if something: result = CliRunner().invoke( diff --git a/tests/data/function_trailing_comma.py b/tests/data/function_trailing_comma.py index fcd81ad7d96..280c80b550f 100644 --- a/tests/data/function_trailing_comma.py +++ b/tests/data/function_trailing_comma.py @@ -9,13 +9,21 @@ def xxxxxxxxxxxxxxxxxxxxxxxxxxxx() -> Set[ ]: pass +def _gopass_process() -> Popen: + """Spawn a Gopass process""" + return Popen( + ['gopass', 'jsonapi', 'listen'], + stdout=PIPE, + stdin=PIPE, + ) + # output -def f(a,): +def f(a): ... -def f(a: int = 1,): +def f(a: int = 1): ... @@ -23,3 +31,8 @@ def xxxxxxxxxxxxxxxxxxxxxxxxxxxx() -> Set[ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" ]: pass + + +def _gopass_process() -> Popen: + """Spawn a Gopass process""" + return Popen(["gopass", "jsonapi", "listen"], stdout=PIPE, stdin=PIPE)