diff --git a/asttokens/astroid_compat.py b/asttokens/astroid_compat.py new file mode 100644 index 0000000..3ba5e8d --- /dev/null +++ b/asttokens/astroid_compat.py @@ -0,0 +1,14 @@ +try: + from astroid import nodes as astroid_node_classes + + # astroid_node_classes should be whichever module has the NodeNG class + from astroid.nodes import NodeNG +except Exception: + try: + from astroid import node_classes as astroid_node_classes + from astroid.node_classes import NodeNG + except Exception: # pragma: no cover + astroid_node_classes = None + NodeNG = None + +__all__ = ["astroid_node_classes", "NodeNG"] diff --git a/asttokens/mark_tokens.py b/asttokens/mark_tokens.py index 0f935c0..0aa497f 100644 --- a/asttokens/mark_tokens.py +++ b/asttokens/mark_tokens.py @@ -24,12 +24,7 @@ from . import util from .asttokens import ASTTokens from .util import AstConstant - -try: - import astroid.node_classes as nc -except Exception: - # This is only used for type checking, we don't need it if astroid isn't installed. - nc = None +from .astroid_compat import astroid_node_classes as nc if TYPE_CHECKING: from .util import AstNode @@ -88,6 +83,9 @@ def _visit_after_children(self, node, parent_token, token): first = token last = None for child in cast(Callable, self._iter_children)(node): + # astroid slices have especially wrong positions, we don't want them to corrupt their parents. + if util.is_empty_astroid_slice(child): + continue if not first or child.first_token.index < first.index: first = child.first_token if not last or child.last_token.index > last.index: diff --git a/asttokens/util.py b/asttokens/util.py index 96fa931..4abc83e 100644 --- a/asttokens/util.py +++ b/asttokens/util.py @@ -24,8 +24,9 @@ from six import iteritems + if TYPE_CHECKING: # pragma: no cover - from astroid.node_classes import NodeNG + from .astroid_compat import NodeNG # Type class used to expand out the definition of AST to include fields added by this library # It's not actually used for anything other than type checking though! @@ -218,6 +219,15 @@ def is_slice(node): ) +def is_empty_astroid_slice(node): + # type: (AstNode) -> bool + return ( + node.__class__.__name__ == "Slice" + and not isinstance(node, ast.AST) + and node.lower is node.upper is node.step is None + ) + + # Sentinel value used by visit_tree(). _PREVISIT = object() diff --git a/pyproject.toml b/pyproject.toml index ea6e65f..2543e24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,5 +20,5 @@ disallow_untyped_calls=false ignore_missing_imports=true [[tool.mypy.overrides]] -module = ["astroid", "astroid.node_classes"] -ignore_missing_imports = true \ No newline at end of file +module = ["astroid", "astroid.node_classes", "astroid.nodes", "astroid.nodes.utils"] +ignore_missing_imports = true diff --git a/setup.cfg b/setup.cfg index a2bde84..f506500 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ install_requires = setup_requires = setuptools>=44; setuptools_scm[toml]>=3.4.3 [options.extras_require] -test = astroid<=2.5.3; pytest +test = astroid; pytest [options.package_data] asttokens = py.typed diff --git a/tests/test_astroid.py b/tests/test_astroid.py index 1608359..a5cc6d7 100644 --- a/tests/test_astroid.py +++ b/tests/test_astroid.py @@ -2,9 +2,9 @@ from __future__ import unicode_literals, print_function import astroid -from astroid.node_classes import NodeNG from asttokens import ASTTokens +from asttokens.astroid_compat import astroid_node_classes from . import test_mark_tokens @@ -13,7 +13,7 @@ class TestAstroid(test_mark_tokens.TestMarkTokens): is_astroid_test = True module = astroid - nodes_classes = NodeNG + nodes_classes = astroid_node_classes.NodeNG context_classes = [ (astroid.Name, astroid.DelName, astroid.AssignName), (astroid.Attribute, astroid.DelAttr, astroid.AssignAttr), diff --git a/tests/test_mark_tokens.py b/tests/test_mark_tokens.py index cebb226..5aba077 100644 --- a/tests/test_mark_tokens.py +++ b/tests/test_mark_tokens.py @@ -19,6 +19,11 @@ from . import tools +try: + from astroid.nodes.utils import Position as AstroidPosition +except Exception: + AstroidPosition = () + class TestMarkTokens(unittest.TestCase): maxDiff = None @@ -230,7 +235,7 @@ def test_deep_recursion(self): def test_slices(self): # Make sure we don't fail on parsing slices of the form `foo[4:]`. - source = "(foo.Area_Code, str(foo.Phone)[:3], str(foo.Phone)[3:], foo[:], bar[::2, :], [a[:]][::-1])" + source = "(foo.Area_Code, str(foo.Phone)[:3], str(foo.Phone)[3:], foo[:], bar[::2, :], bar2[:, ::2], [a[:]][::-1])" m = self.create_mark_checker(source) self.assertIn("Tuple:" + source, m.view_nodes_at(1, 0)) self.assertEqual(m.view_nodes_at(1, 1), @@ -243,7 +248,7 @@ def test_slices(self): # important, so we skip them here. self.assertEqual({n for n in m.view_nodes_at(1, 56) if 'Slice:' not in n}, { "Subscript:foo[:]", "Name:foo" }) - self.assertEqual({n for n in m.view_nodes_at(1, 64) if 'Slice:' not in n}, + self.assertEqual({n for n in m.view_nodes_at(1, 64) if 'Slice:' not in n and 'Tuple:' not in n}, { "Subscript:bar[::2, :]", "Name:bar" }) def test_adjacent_strings(self): @@ -814,6 +819,10 @@ def assert_nodes_equal(self, t1, t2): else: self.assertEqual(type(t1), type(t2)) + if isinstance(t1, AstroidPosition): + # Ignore the lineno/col_offset etc. from astroid + return + if isinstance(t1, (list, tuple)): self.assertEqual(len(t1), len(t2)) for vc1, vc2 in zip(t1, t2):