diff --git a/CHANGELOG.md b/CHANGELOG.md index a977cb797..c5aa37b52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - #533 Refactoring to Remove usage of unicode type - #559 Improve handling of whitespace in import and from-import statements +- #581 Remove functions in rope.base.ast that has functionally identical implementation in stdlib's ast # Release 1.5.1 diff --git a/rope/base/ast.py b/rope/base/ast.py index efe95fcbc..f64a5b3f1 100644 --- a/rope/base/ast.py +++ b/rope/base/ast.py @@ -22,47 +22,17 @@ def parse(source, filename=""): raise error -def walk(node, walker) -> None: - """Walk the syntax tree""" - method_name = "_" + node.__class__.__name__ - method = getattr(walker, method_name, None) - if method is not None: - method(node) - return - for child in get_child_nodes(node): - walk(child, walker) - - -def get_child_nodes(node): - if isinstance(node, ast.Module): - return node.body - result = [] - if node._fields is not None: - for name in node._fields: - child = getattr(node, name) - if isinstance(child, list): - for entry in child: - if isinstance(entry, ast.AST): - result.append(entry) - if isinstance(child, ast.AST): - result.append(child) - return result - - def call_for_nodes(node, callback, recursive=False): """If callback returns `True` the child nodes are skipped""" result = callback(node) if recursive and not result: - for child in get_child_nodes(node): + for child in ast.iter_child_nodes(node): call_for_nodes(child, callback, recursive) -def get_children(node): - result = [] - if node._fields is not None: - for name in node._fields: - if name in ["lineno", "col_offset"]: - continue - child = getattr(node, name) - result.append(child) - return result +class RopeNodeVisitor(ast.NodeVisitor): + def visit(self, node): + """Modified from ast.NodeVisitor to match rope's existing Visitor implementation""" + method = "_" + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) diff --git a/rope/base/astutils.py b/rope/base/astutils.py index 3eef2ebb6..fd7f1d2ea 100644 --- a/rope/base/astutils.py +++ b/rope/base/astutils.py @@ -14,11 +14,11 @@ def get_name_levels(node): """ visitor = _NodeNameCollector() - ast.walk(node, visitor) + visitor.visit(node) return visitor.names -class _NodeNameCollector: +class _NodeNameCollector(ast.RopeNodeVisitor): def __init__(self, levels=None): self.names = [] self.levels = levels @@ -49,8 +49,8 @@ def _Tuple(self, node): new_levels.append(self.index) self.index += 1 visitor = _NodeNameCollector(new_levels) - for child in ast.get_child_nodes(node): - ast.walk(child, visitor) + for child in ast.iter_child_nodes(node): + visitor.visit(child) self.names.extend(visitor.names) def _Subscript(self, node): diff --git a/rope/base/evaluate.py b/rope/base/evaluate.py index 4aa9fdbe0..519ebc4d7 100644 --- a/rope/base/evaluate.py +++ b/rope/base/evaluate.py @@ -39,7 +39,7 @@ def eval_node(scope, node): def eval_node2(scope, node): evaluator = StatementEvaluator(scope) - ast.walk(node, evaluator) + evaluator.visit(node) return evaluator.old_result, evaluator.result @@ -158,7 +158,7 @@ def _find_module(self, module_name): ) -class StatementEvaluator: +class StatementEvaluator(ast.RopeNodeVisitor): def __init__(self, scope): self.scope = scope self.result = None diff --git a/rope/base/oi/soa.py b/rope/base/oi/soa.py index 4e8cab4a8..fc3a9dbac 100644 --- a/rope/base/oi/soa.py +++ b/rope/base/oi/soa.py @@ -33,11 +33,11 @@ def _follow(pyfunction): if not followed_calls: _follow = None visitor = SOAVisitor(pycore, pydefined, _follow) - for child in rope.base.ast.get_child_nodes(pydefined.get_ast()): - rope.base.ast.walk(child, visitor) + for child in rope.base.ast.iter_child_nodes(pydefined.get_ast()): + visitor.visit(child) -class SOAVisitor: +class SOAVisitor(rope.base.ast.RopeNodeVisitor): def __init__(self, pycore, pydefined, follow_callback=None): self.pycore = pycore self.pymodule = pydefined.get_module() @@ -51,8 +51,8 @@ def _ClassDef(self, node): pass def _Call(self, node): - for child in rope.base.ast.get_child_nodes(node): - rope.base.ast.walk(child, self) + for child in rope.base.ast.iter_child_nodes(node): + self.visit(child) primary, pyname = evaluate.eval_node2(self.scope, node.func) if pyname is None: return @@ -99,23 +99,23 @@ def _parameter_objects(self, pyfunction): ] def _AnnAssign(self, node): - for child in rope.base.ast.get_child_nodes(node): - rope.base.ast.walk(child, self) + for child in rope.base.ast.iter_child_nodes(node): + self.visit(child) visitor = _SOAAssignVisitor() nodes = [] - rope.base.ast.walk(node.target, visitor) + visitor.visit(node.target) nodes.extend(visitor.nodes) self._evaluate_assign_value(node, nodes, type_hint=node.annotation) def _Assign(self, node): - for child in rope.base.ast.get_child_nodes(node): - rope.base.ast.walk(child, self) + for child in rope.base.ast.iter_child_nodes(node): + self.visit(child) visitor = _SOAAssignVisitor() nodes = [] for child in node.targets: - rope.base.ast.walk(child, visitor) + visitor.visit(child) nodes.extend(visitor.nodes) self._evaluate_assign_value(node, nodes) diff --git a/rope/base/pyobjects.py b/rope/base/pyobjects.py index 820fc2d1b..62702e7a3 100644 --- a/rope/base/pyobjects.py +++ b/rope/base/pyobjects.py @@ -243,8 +243,8 @@ def _create_structural_attributes(self): if self.visitor_class is None: return {} new_visitor = self.visitor_class(self.pycore, self) - for child in ast.get_child_nodes(self.ast_node): - ast.walk(child, new_visitor) + for child in ast.iter_child_nodes(self.ast_node): + new_visitor.visit(child) self.defineds = new_visitor.defineds return new_visitor.names diff --git a/rope/base/pyobjectsdef.py b/rope/base/pyobjectsdef.py index e0f0efa75..92e9e78c6 100644 --- a/rope/base/pyobjectsdef.py +++ b/rope/base/pyobjectsdef.py @@ -291,7 +291,7 @@ def get_name(self): return rope.base.libutils.modname(self.resource) if self.resource else "" -class _AnnAssignVisitor: +class _AnnAssignVisitor(ast.RopeNodeVisitor): def __init__(self, scope_visitor): self.scope_visitor = scope_visitor self.assigned_ast = None @@ -301,7 +301,7 @@ def _AnnAssign(self, node): self.assigned_ast = node.value self.type_hint = node.annotation - ast.walk(node.target, self) + self.visit(node.target) def _assigned(self, name, assignment=None): self.scope_visitor._assigned(name, assignment) @@ -333,7 +333,7 @@ def _Slice(self, node): pass -class _ExpressionVisitor: +class _ExpressionVisitor(ast.RopeNodeVisitor): def __init__(self, scope_visitor): self.scope_visitor = scope_visitor @@ -356,11 +356,11 @@ def _DictComp(self, node): self._GeneratorExp(node) def _NamedExpr(self, node): - ast.walk(node.target, _AssignVisitor(self)) - ast.walk(node.value, self) + _AssignVisitor(self).visit(node.target) + self.visit(node.value) -class _AssignVisitor: +class _AssignVisitor(ast.RopeNodeVisitor): def __init__(self, scope_visitor): self.scope_visitor = scope_visitor self.assigned_ast = None @@ -368,8 +368,8 @@ def __init__(self, scope_visitor): def _Assign(self, node): self.assigned_ast = node.value for child_node in node.targets: - ast.walk(child_node, self) - ast.walk(node.value, _ExpressionVisitor(self.scope_visitor)) + self.visit(child_node) + _ExpressionVisitor(self.scope_visitor).visit(node.value) def _assigned(self, name, assignment=None): self.scope_visitor._assigned(name, assignment) @@ -446,10 +446,10 @@ def _AsyncFunctionDef(self, node): return self._FunctionDef(node) def _Assign(self, node): - ast.walk(node, _AssignVisitor(self)) + _AssignVisitor(self).visit(node) def _AnnAssign(self, node): - ast.walk(node, _AnnAssignVisitor(self)) + _AnnAssignVisitor(self).visit(node) def _AugAssign(self, node): pass @@ -457,7 +457,7 @@ def _AugAssign(self, node): def _For(self, node): self._update_evaluated(node.target, node.iter, ".__iter__().next()") for child in node.body + node.orelse: - ast.walk(child, self) + self.visit(child) def _AsyncFor(self, node): return self._For(node) @@ -494,7 +494,7 @@ def _With(self, node): item.optional_vars, item.context_expr, ".__enter__()" ) for child in node.body: - ast.walk(child, self) + self.visit(child) def _AsyncWith(self, node): return self._With(node) @@ -508,7 +508,7 @@ def _excepthandler(self, node): self._update_evaluated(node.name, type_node, eval_type=True) for child in node.body: - ast.walk(child, self) + self.visit(child) def _ExceptHandler(self, node): self._excepthandler(node) @@ -571,8 +571,8 @@ def _Global(self, node): class _ComprehensionVisitor(_ScopeVisitor): def _comprehension(self, node): - ast.walk(node.target, self) - ast.walk(node.iter, self) + self.visit(node.target) + self.visit(node.iter) def _Name(self, node): if isinstance(node.ctx, ast.Store): @@ -599,8 +599,8 @@ def _FunctionDef(self, node): if isinstance(first, ast.arg): new_visitor = _ClassInitVisitor(self, first.arg) if new_visitor is not None: - for child in ast.get_child_nodes(node): - ast.walk(child, new_visitor) + for child in ast.iter_child_nodes(node): + new_visitor.visit(child) class _FunctionVisitor(_ScopeVisitor): @@ -642,8 +642,8 @@ def _Attribute(self, node): def _Tuple(self, node): if not isinstance(node.ctx, ast.Store): return - for child in ast.get_child_nodes(node): - ast.walk(child, self) + for child in ast.iter_child_nodes(node): + self.visit(child) def _Name(self, node): pass diff --git a/rope/base/pyscopes.py b/rope/base/pyscopes.py index 74e16b415..94d5b82e5 100644 --- a/rope/base/pyscopes.py +++ b/rope/base/pyscopes.py @@ -186,8 +186,8 @@ def get_names(self): def _visit_comprehension(self): if self.names is None: new_visitor = self.visitor(self.pycore, self.pyobject) - for node in ast.get_child_nodes(self.pyobject.get_ast()): - ast.walk(node, new_visitor) + for node in ast.iter_child_nodes(self.pyobject.get_ast()): + new_visitor.visit(node) self.names = dict(self.parent.get_names()) self.names.update(new_visitor.names) self.defineds = new_visitor.defineds @@ -218,8 +218,8 @@ def _get_names(self): def _visit_function(self): if self.names is None: new_visitor = self.visitor(self.pycore, self.pyobject) - for n in ast.get_child_nodes(self.pyobject.get_ast()): - ast.walk(n, new_visitor) + for n in ast.iter_child_nodes(self.pyobject.get_ast()): + new_visitor.visit(n) self.names = new_visitor.names self.names.update(self.pyobject.get_parameters()) self.returned_asts = new_visitor.returned_asts diff --git a/rope/contrib/finderrors.py b/rope/contrib/finderrors.py index 9d533195b..335e6d84a 100644 --- a/rope/contrib/finderrors.py +++ b/rope/contrib/finderrors.py @@ -33,11 +33,11 @@ def find_errors(project, resource): """ pymodule = project.get_pymodule(resource) finder = _BadAccessFinder(pymodule) - ast.walk(pymodule.get_ast(), finder) + finder.visit(pymodule.get_ast()) return finder.errors -class _BadAccessFinder: +class _BadAccessFinder(ast.RopeNodeVisitor): def __init__(self, pymodule): self.pymodule = pymodule self.scope = pymodule.get_scope() @@ -60,7 +60,7 @@ def _Attribute(self, node): if pyname is not None and pyname.get_object() != pyobjects.get_unknown(): if node.attr not in pyname.get_object(): self._add_error(node, "Unresolved attribute") - ast.walk(node.value, self) + self.visit(node.value) def _add_error(self, node, msg): if isinstance(node, ast.Attribute): diff --git a/rope/refactor/extract.py b/rope/refactor/extract.py index 7c9edda64..47628c75c 100644 --- a/rope/refactor/extract.py +++ b/rope/refactor/extract.py @@ -514,7 +514,7 @@ def _is_on_a_word(self, info, offset): return next.isalnum() or next == "_" -class _ExtractMethodParts: +class _ExtractMethodParts(ast.RopeNodeVisitor): def __init__(self, info): self.info = info self.info_collector = self._create_info_collector() @@ -589,7 +589,7 @@ def _create_info_collector(self): ) body = self.info.source[self.info.scope_region[0] : self.info.scope_region[1]] node = _parse_text(body) - ast.walk(node, info_collector) + info_collector.visit(node) return info_collector def _get_function_definition(self): @@ -755,7 +755,7 @@ def _insert_globals(self, unindented_body): def _get_globals_in_body(unindented_body): node = _parse_text(unindented_body) visitor = _GlobalFinder() - ast.walk(node, visitor) + visitor.visit(node) return visitor.globals_ @@ -777,7 +777,7 @@ def get_checks(self): return {} -class _FunctionInformationCollector: +class _FunctionInformationCollector(ast.RopeNodeVisitor): def __init__(self, start, end, is_global): self.start = start self.end = end @@ -822,12 +822,12 @@ def _FunctionDef(self, node): for name in _get_argnames(node.args): self._written_variable(name, node.lineno) for child in node.body: - ast.walk(child, self) + self.visit(child) else: self._written_variable(node.name, node.lineno) visitor = _VariableReadsAndWritesFinder() for child in node.body: - ast.walk(child, visitor) + visitor.visit(child) for name in visitor.read - visitor.written: self._read_variable(name, node.lineno) @@ -846,21 +846,21 @@ def _Name(self, node): def _MatchAs(self, node): self._written_variable(node.name, node.lineno) if node.pattern: - ast.walk(node.pattern, self) + self.visit(node.pattern) def _Assign(self, node): - ast.walk(node.value, self) + self.visit(node.value) for child in node.targets: - ast.walk(child, self) + self.visit(child) def _AugAssign(self, node): - ast.walk(node.value, self) + self.visit(node.value) if isinstance(node.target, ast.Name): target_id = node.target.id self._read_variable(target_id, node.target.lineno) self._written_variable(target_id, node.target.lineno) else: - ast.walk(node.target, self) + self.visit(node.target) def _ClassDef(self, node): self._written_variable(node.name, node.lineno) @@ -882,8 +882,8 @@ def _comp_exp(self, node): written = OrderedSet(self.written) maybe_written = OrderedSet(self.maybe_written) - for child in ast.get_child_nodes(node): - ast.walk(child, self) + for child in ast.iter_child_nodes(node): + self.visit(child) comp_names = list( chain.from_iterable( @@ -914,18 +914,18 @@ def _While(self, node): def _For(self, node): with self._handle_loop_context(node), self._handle_conditional_context(node): # iter has to be checked before the target variables - ast.walk(node.iter, self) - ast.walk(node.target, self) + self.visit(node.iter) + self.visit(node.target) for child in node.body: - ast.walk(child, self) + self.visit(child) for child in node.orelse: - ast.walk(child, self) + self.visit(child) def _handle_conditional_node(self, node): with self._handle_conditional_context(node): - for child in ast.get_child_nodes(node): - ast.walk(child, self) + for child in ast.iter_child_nodes(node): + self.visit(child) @contextmanager def _handle_conditional_context(self, node): @@ -955,7 +955,7 @@ def _get_argnames(arguments): return result -class _VariableReadsAndWritesFinder: +class _VariableReadsAndWritesFinder(ast.RopeNodeVisitor): def __init__(self): self.written = set() self.read = set() @@ -969,8 +969,8 @@ def _Name(self, node): def _FunctionDef(self, node): self.written.add(node.name) visitor = _VariableReadsAndWritesFinder() - for child in ast.get_child_nodes(node): - ast.walk(child, visitor) + for child in ast.iter_child_nodes(node): + visitor.visit(child) self.read.update(visitor.read - visitor.written) def _Class(self, node): @@ -982,7 +982,7 @@ def find_reads_and_writes(code): return set(), set() node = _parse_text(code) visitor = _VariableReadsAndWritesFinder() - ast.walk(node, visitor) + visitor.visit(node) return visitor.read, visitor.written @staticmethod @@ -991,18 +991,18 @@ def find_reads_for_one_liners(code): return set(), set() node = _parse_text(code) visitor = _VariableReadsAndWritesFinder() - ast.walk(node, visitor) + visitor.visit(node) return visitor.read -class _BaseErrorFinder: +class _BaseErrorFinder(ast.RopeNodeVisitor): @classmethod def has_errors(cls, code): if code.strip() == "": return False node = _parse_text(code) visitor = cls() - ast.walk(node, visitor) + visitor.visit(node) return visitor.error @@ -1020,14 +1020,14 @@ def _While(self, node): def loop_encountered(self, node): self.loop_count += 1 for child in node.body: - ast.walk(child, self) + self.visit(child) self.loop_count -= 1 if node.orelse: if isinstance(node.orelse, (list, tuple)): for node_ in node.orelse: - ast.walk(node_, self) + self.visit(node_) else: - ast.walk(node.orelse, self) + self.visit(node.orelse) def _Break(self, node): self.check_loop() @@ -1063,7 +1063,7 @@ def _ClassDef(self, node): pass -class _GlobalFinder: +class _GlobalFinder(ast.RopeNodeVisitor): def __init__(self): self.globals_ = OrderedSet() diff --git a/rope/refactor/importutils/module_imports.py b/rope/refactor/importutils/module_imports.py index 39cb718f6..e448528c0 100644 --- a/rope/refactor/importutils/module_imports.py +++ b/rope/refactor/importutils/module_imports.py @@ -26,7 +26,7 @@ def imports(self): def _get_unbound_names(self, defined_pyobject): visitor = _GlobalUnboundNameFinder(self.pymodule, defined_pyobject) - ast.walk(self.pymodule.get_ast(), visitor) + visitor.visit(self.pymodule.get_ast()) return visitor.unbound def _get_all_star_list(self, pymodule): @@ -415,7 +415,7 @@ def _can_name_be_added(self, imported_primary): return False -class _UnboundNameFinder: +class _UnboundNameFinder(ast.RopeNodeVisitor): def __init__(self, pyobject): self.pyobject = pyobject @@ -427,8 +427,8 @@ def _visit_child_scope(self, node): .pyobject ) visitor = _LocalUnboundNameFinder(pyobject, self) - for child in ast.get_child_nodes(node): - ast.walk(child, visitor) + for child in ast.iter_child_nodes(node): + visitor.visit(child) def _FunctionDef(self, node): self._visit_child_scope(node) @@ -453,7 +453,7 @@ def _Attribute(self, node): ): self.add_unbound(primary) else: - ast.walk(node, self) + self.visit(node) def _get_root(self): pass diff --git a/rope/refactor/patchedast.py b/rope/refactor/patchedast.py index c545f1a6f..74bb6024b 100644 --- a/rope/refactor/patchedast.py +++ b/rope/refactor/patchedast.py @@ -90,7 +90,7 @@ def __call__(self, node): ) node.region = (self.source.offset, self.source.offset) if self.children: - node.sorted_children = ast.get_children(node) + node.sorted_children = [child for field, child in ast.iter_fields(node)] def _handle(self, node, base_children, eat_parens=False, eat_spaces=False): if hasattr(node, "region"): diff --git a/rope/refactor/restructure.py b/rope/refactor/restructure.py index ac46223f8..68694cdbf 100644 --- a/rope/refactor/restructure.py +++ b/rope/refactor/restructure.py @@ -310,7 +310,7 @@ def _auto_indent(self, offset, text): def _get_nearest_roots(self, node): if node not in self._nearest_roots: result = [] - for child in ast.get_child_nodes(node): + for child in ast.iter_child_nodes(node): if child in self.matched_asts: result.append(child) else: diff --git a/rope/refactor/similarfinder.py b/rope/refactor/similarfinder.py index 5bc042587..2a672802d 100644 --- a/rope/refactor/similarfinder.py +++ b/rope/refactor/similarfinder.py @@ -2,8 +2,7 @@ import re import rope.refactor.wildcards -from rope.base import libutils -from rope.base import codeanalyze, exceptions, ast, builtins +from rope.base import libutils, codeanalyze, exceptions, ast, builtins from rope.refactor import patchedast, wildcards from rope.refactor.patchedast import MismatchedTokenError @@ -169,7 +168,7 @@ def _check_expression(self, node): self.matches.append(ExpressionMatch(node, mapping)) def _check_statements(self, node): - for child in ast.get_children(node): + for field, child in ast.iter_fields(node): if isinstance(child, (list, tuple)): self.__check_stmt_list(child) @@ -211,8 +210,11 @@ def _match_nodes(self, expected, node, mapping): def _get_children(self, node): """Return not `ast.expr_context` children of `node`""" - children = ast.get_children(node) - return [child for child in children if not isinstance(child, ast.expr_context)] + return [ + child + for field, child in ast.iter_fields(node) + if not isinstance(child, ast.expr_context) + ] def _match_stmts(self, current_stmts, mapping): if len(current_stmts) != len(self.pattern): diff --git a/rope/refactor/suites.py b/rope/refactor/suites.py index 6e7eb454e..133652b20 100644 --- a/rope/refactor/suites.py +++ b/rope/refactor/suites.py @@ -71,7 +71,7 @@ def get_children(self): if self._children is None: walker = _SuiteWalker(self) for child in self.child_nodes: - ast.walk(child, walker) + walker.visit(child) self._children = walker.suites return self._children @@ -98,7 +98,7 @@ def _get_level(self): return self.parent._get_level() + 1 -class _SuiteWalker: +class _SuiteWalker(ast.RopeNodeVisitor): def __init__(self, suite): self.suite = suite self.suites = [] diff --git a/rope/refactor/usefunction.py b/rope/refactor/usefunction.py index 6a9fdd200..656175abb 100644 --- a/rope/refactor/usefunction.py +++ b/rope/refactor/usefunction.py @@ -173,7 +173,7 @@ def _named_expr_count(node): return visitor.named_expression -class _ReturnOrYieldFinder: +class _ReturnOrYieldFinder(ast.RopeNodeVisitor): def __init__(self): self.returns = 0 self.named_expression = 0 @@ -197,6 +197,6 @@ def _ClassDef(self, node): def start_walking(self, node): nodes = [node] if isinstance(node, ast.FunctionDef): - nodes = ast.get_child_nodes(node) + nodes = list(ast.iter_child_nodes(node)) for child in nodes: - ast.walk(child, self) + self.visit(child)