Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement extracting async with #433

Merged
merged 4 commits into from Oct 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -16,6 +16,7 @@
- #391, #376 Fix improper replacement when extracting attribute access expression with `similar=True` (@climbus)
- #396 Fix improper replacement when extracting index access expression with `similar=True` (@lieryan)
- #293 Fix rename global var affects list comprehension (@climbus)
- #387 Implement extract refactoring for code containing `async with`

## Misc

Expand Down
5 changes: 4 additions & 1 deletion rope/refactor/extract.py
Expand Up @@ -386,11 +386,14 @@ def __init__(self, info, matched_lines):
def find_lineno(self):
if self.info.variable and not self.info.make_global:
return self._get_before_line()
if self.info.make_global or self.info.global_:
if self.info.global_:
toplevel = self._find_toplevel(self.info.scope)
ast = self.info.pymodule.get_ast()
newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
return suites.find_visible(ast, newlines)
if self.info.make_global:
toplevel = self._find_toplevel(self.info.scope)
return toplevel.get_end() + 1
return self._get_after_scope()

def _find_toplevel(self, scope):
Expand Down
9 changes: 7 additions & 2 deletions rope/refactor/patchedast.py
Expand Up @@ -871,9 +871,11 @@ def _While(self, node):
children.extend(node.orelse)
self._handle(node, children)

def _With(self, node):
def _handle_with_node(self, node, is_async):
children = []

if is_async:
children.extend(["async"])
for item in pycompat.get_ast_with_items(node):
children.extend([self.with_or_comma_context_manager, item.context_expr])
if item.optional_vars:
Expand All @@ -885,8 +887,11 @@ def _With(self, node):
children.extend(node.body)
self._handle(node, children)

def _With(self, node):
self._handle_with_node(node, is_async=False)

def _AsyncWith(self, node):
return self._With(node)
self._handle_with_node(node, is_async=True)

def _child_nodes(self, nodes, separator):
children = []
Expand Down
6 changes: 6 additions & 0 deletions rope/refactor/suites.py
Expand Up @@ -114,6 +114,9 @@ def _While(self, node):
def _With(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite))

def _AsyncWith(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite))

def _TryFinally(self, node):
proceed_to_except_handler = False
if len(node.finalbody) == 1:
Expand Down Expand Up @@ -153,5 +156,8 @@ def _add_if_like_node(self, node):
def _FunctionDef(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True))

def _AsyncFunctionDef(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True))

def _ClassDef(self, node):
self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True))
40 changes: 39 additions & 1 deletion ropetest/refactor/extracttest.py
Expand Up @@ -1601,7 +1601,7 @@ def new_func():
""")
self.assertEqual(expected, refactored)

def test_extract_method_with_multiple_methods(self): # noqa
def test_global_extract_method_with_multiple_methods(self):
code = dedent("""\
class AClass(object):
def a_func(self):
Expand Down Expand Up @@ -2953,3 +2953,41 @@ def extracted():
extracted()
""")
self.assertEqual(expected, refactored)

@testutils.only_for_versions_higher('3.8')
def test_extract_method_async_with_simple(self):
code = dedent("""\
async def afunc():
async with open("test") as file1:
print(file1)
""")
start, end = self._convert_line_range_to_offset(code, 2, 3)
refactored = self.do_extract_method(code, start, end, 'extracted', global_=True)
expected = dedent("""\
async def afunc():
extracted()

def extracted():
async with open("test") as file1:
print(file1)
""")
self.assertEqual(expected, refactored)

@testutils.only_for_versions_higher('3.8')
def test_extract_method_containing_async_with(self):
code = dedent("""\
async def afunc():
async with open("test") as file1, open("test") as file2:
print(file1, file2)
""")
start, end = self._convert_line_range_to_offset(code, 3, 3)
refactored = self.do_extract_method(code, start, end, 'extracted', global_=True)
expected = dedent("""\
async def afunc():
async with open("test") as file1, open("test") as file2:
extracted(file1, file2)

def extracted(file1, file2):
print(file1, file2)
""")
self.assertEqual(expected, refactored)
14 changes: 14 additions & 0 deletions ropetest/refactor/patchedasttest.py
Expand Up @@ -1095,6 +1095,20 @@ def test_with_node(self):
["with", " ", "Name", " ", "as", " ", "Name", "", ":", "\n ", "Pass"],
)

@testutils.only_for("3.5")
def test_async_with_node(self):
source = dedent("""\
async def afunc():
async with a as b:
pass\n
""")
ast_frag = patchedast.get_patched_ast(source, True)
checker = _ResultChecker(self, ast_frag)
checker.check_children(
"AsyncWith",
["async", " ", "with", " ", "Name", " ", "as", " ", "Name", "", ":", "\n ", "Pass"],
)

def test_try_finally_node(self):
source = dedent("""\
try:
Expand Down