diff --git a/rope/refactor/extract.py b/rope/refactor/extract.py index ff46e5a89..923f47248 100644 --- a/rope/refactor/extract.py +++ b/rope/refactor/extract.py @@ -670,16 +670,42 @@ def _find_function_returns(self): def _get_unindented_function_body(self, returns): if self.info.one_line: - if self.info.returning_named_expr: - return 'return ' + '(' + _join_lines(self.info.extracted) + ')' - else: - return 'return ' + _join_lines(self.info.extracted) - extracted_body = self.info.extracted - unindented_body = sourceutils.fix_indentation(extracted_body, 0) + return self._get_one_line_function_body() + return self._get_multiline_function_body(returns) + + def _get_multiline_function_body(self, returns): + unindented_body = sourceutils.fix_indentation(self.info.extracted, 0) + unindented_body = self._insert_globals(unindented_body) if returns: unindented_body += '\nreturn %s' % self._get_comma_form(returns) return unindented_body + def _get_one_line_function_body(self): + if self.info.returning_named_expr: + body = 'return ' + '(' + _join_lines(self.info.extracted) + ')' + else: + body = 'return ' + _join_lines(self.info.extracted) + return self._insert_globals(body) + + def _insert_globals(self, unindented_body): + globals_in_body = self._get_globals_in_body(unindented_body) + globals_ = self.info_collector.globals_ & self._get_internal_variables() + globals_ = globals_ - globals_in_body + + if globals_: + unindented_body = "global {}\n{}".format(", ".join(globals_), unindented_body) + return unindented_body + + @staticmethod + def _get_globals_in_body(unindented_body): + node = _parse_text(unindented_body) + visitor = _GlobalFinder() + ast.walk(node, visitor) + return visitor.globals_ + + def _get_internal_variables(self): + return self.info_collector.read | self.info_collector.written | self.info_collector.maybe_written + class _ExtractVariableParts(object): @@ -715,6 +741,10 @@ def __init__(self, start, end, is_global): self.postwritten = OrderedSet() self.host_function = True self.conditional = False + self.surrounded_by_loop = 0 + self.globals_ = OrderedSet() + self.surrounded_by_loop = 0 + self.loop_depth = 0 self.loop_depth = 0 def _read_variable(self, name, lineno): @@ -754,6 +784,9 @@ def _FunctionDef(self, node): for name in visitor.read - visitor.written: self._read_variable(name, node.lineno) + def _Global(self, node): + self.globals_.add(*node.names) + def _AsyncFunctionDef(self, node): self._FunctionDef(node) @@ -939,6 +972,14 @@ def _ClassDef(self, node): pass +class _GlobalFinder(object): + def __init__(self): + self.globals_ = OrderedSet() + + def _Global(self, node): + self.globals_.add(*node.names) + + def _get_function_kind(scope): return scope.pyobject.get_kind() diff --git a/ropetest/refactor/extracttest.py b/ropetest/refactor/extracttest.py index d824a5c88..32005d9eb 100644 --- a/ropetest/refactor/extracttest.py +++ b/ropetest/refactor/extracttest.py @@ -1811,6 +1811,101 @@ def second_method(someargs): self.assertEqual(expected, refactored) + def test_extraction_method_with_global_variable(self): + code = dedent('''\ + g = None + + def f(): + global g + + g = 2 + + f() + print(g) + ''') + extract_target = 'g = 2' + start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) + refactored = self.do_extract_method(code, start, end, '_g') + expected = dedent('''\ + g = None + + def f(): + global g + + _g() + + def _g(): + global g + g = 2 + + f() + print(g) + ''') + self.assertEqual(expected, refactored) + + def test_extraction_method_with_global_variable_and_global_declaration(self): + code = dedent('''\ + g = None + + def f(): + global g + + g = 2 + + f() + print(g) + ''') + start, end = 23, 42 + refactored = self.do_extract_method(code, start, end, '_g') + expected = dedent('''\ + g = None + + def f(): + _g() + + def _g(): + global g + + g = 2 + + f() + print(g) + ''') + self.assertEqual(expected, refactored) + + def test_extraction_one_line_with_global_variable(self): + code = dedent('''\ + g = None + + def f(): + global g + + a = g + + f() + print(g) + ''') + extract_target = '= g' + start, end = code.index(extract_target) + 2, code.index(extract_target) + 3 + refactored = self.do_extract_method(code, start, end, '_g') + print(refactored) + expected = dedent('''\ + g = None + + def f(): + global g + + a = _g() + + def _g(): + global g + return g + + f() + print(g) + ''') + self.assertEqual(expected, refactored) + if __name__ == '__main__': unittest.main()