diff --git a/CHANGELOG.md b/CHANGELOG.md index b4bffb821..335c7fade 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - #391, #376 Fix improper replacement when extracting attribute access expression with `similar=True` (@climbus) - #396 Fix improper replacement when extracting index access expression with `similar=True` (@lieryan) - #293 Fix rename global var affects list comprehension (@climbus) +- #387 Implement extract refactoring for code containing `async with` ## Misc diff --git a/rope/refactor/extract.py b/rope/refactor/extract.py index 59d9bb43e..16b7d9b5d 100644 --- a/rope/refactor/extract.py +++ b/rope/refactor/extract.py @@ -386,11 +386,14 @@ def __init__(self, info, matched_lines): def find_lineno(self): if self.info.variable and not self.info.make_global: return self._get_before_line() - if self.info.make_global or self.info.global_: + if self.info.global_: toplevel = self._find_toplevel(self.info.scope) ast = self.info.pymodule.get_ast() newlines = sorted(self.matched_lines + [toplevel.get_end() + 1]) return suites.find_visible(ast, newlines) + if self.info.make_global: + toplevel = self._find_toplevel(self.info.scope) + return toplevel.get_end() + 1 return self._get_after_scope() def _find_toplevel(self, scope): diff --git a/rope/refactor/patchedast.py b/rope/refactor/patchedast.py index d8ed4e063..fda034e89 100644 --- a/rope/refactor/patchedast.py +++ b/rope/refactor/patchedast.py @@ -871,9 +871,11 @@ def _While(self, node): children.extend(node.orelse) self._handle(node, children) - def _With(self, node): + def _handle_with_node(self, node, is_async): children = [] + if is_async: + children.extend(["async"]) for item in pycompat.get_ast_with_items(node): children.extend([self.with_or_comma_context_manager, item.context_expr]) if item.optional_vars: @@ -885,8 +887,11 @@ def _With(self, node): children.extend(node.body) self._handle(node, children) + def _With(self, node): + self._handle_with_node(node, is_async=False) + def _AsyncWith(self, node): - return self._With(node) + self._handle_with_node(node, is_async=True) def _child_nodes(self, nodes, separator): children = [] diff --git a/rope/refactor/suites.py b/rope/refactor/suites.py index b7f5c9a66..a6597f87a 100644 --- a/rope/refactor/suites.py +++ b/rope/refactor/suites.py @@ -114,6 +114,9 @@ def _While(self, node): def _With(self, node): self.suites.append(Suite(node.body, node.lineno, self.suite)) + def _AsyncWith(self, node): + self.suites.append(Suite(node.body, node.lineno, self.suite)) + def _TryFinally(self, node): proceed_to_except_handler = False if len(node.finalbody) == 1: @@ -153,5 +156,8 @@ def _add_if_like_node(self, node): def _FunctionDef(self, node): self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True)) + def _AsyncFunctionDef(self, node): + self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True)) + def _ClassDef(self, node): self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True)) diff --git a/ropetest/refactor/extracttest.py b/ropetest/refactor/extracttest.py index abf575df1..d11548025 100644 --- a/ropetest/refactor/extracttest.py +++ b/ropetest/refactor/extracttest.py @@ -1601,7 +1601,7 @@ def new_func(): """) self.assertEqual(expected, refactored) - def test_extract_method_with_multiple_methods(self): # noqa + def test_global_extract_method_with_multiple_methods(self): code = dedent("""\ class AClass(object): def a_func(self): @@ -2953,3 +2953,41 @@ def extracted(): extracted() """) self.assertEqual(expected, refactored) + + @testutils.only_for_versions_higher('3.8') + def test_extract_method_async_with_simple(self): + code = dedent("""\ + async def afunc(): + async with open("test") as file1: + print(file1) + """) + start, end = self._convert_line_range_to_offset(code, 2, 3) + refactored = self.do_extract_method(code, start, end, 'extracted', global_=True) + expected = dedent("""\ + async def afunc(): + extracted() + + def extracted(): + async with open("test") as file1: + print(file1) + """) + self.assertEqual(expected, refactored) + + @testutils.only_for_versions_higher('3.8') + def test_extract_method_containing_async_with(self): + code = dedent("""\ + async def afunc(): + async with open("test") as file1, open("test") as file2: + print(file1, file2) + """) + start, end = self._convert_line_range_to_offset(code, 3, 3) + refactored = self.do_extract_method(code, start, end, 'extracted', global_=True) + expected = dedent("""\ + async def afunc(): + async with open("test") as file1, open("test") as file2: + extracted(file1, file2) + + def extracted(file1, file2): + print(file1, file2) + """) + self.assertEqual(expected, refactored) diff --git a/ropetest/refactor/patchedasttest.py b/ropetest/refactor/patchedasttest.py index 1bae47b22..479e46a7a 100644 --- a/ropetest/refactor/patchedasttest.py +++ b/ropetest/refactor/patchedasttest.py @@ -1095,6 +1095,20 @@ def test_with_node(self): ["with", " ", "Name", " ", "as", " ", "Name", "", ":", "\n ", "Pass"], ) + @testutils.only_for("3.5") + def test_async_with_node(self): + source = dedent("""\ + async def afunc(): + async with a as b: + pass\n + """) + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_children( + "AsyncWith", + ["async", " ", "with", " ", "Name", " ", "as", " ", "Name", "", ":", "\n ", "Pass"], + ) + def test_try_finally_node(self): source = dedent("""\ try: