diff --git a/docs/src/release_notes.rst b/docs/src/release_notes.rst index b6fbff5..e32f4e0 100644 --- a/docs/src/release_notes.rst +++ b/docs/src/release_notes.rst @@ -15,6 +15,7 @@ Unreleased :confval:`codeautolink_warn_on_failed_resolve` for debugging (:issue:`106`) - Define extension environment version for Sphinx (:issue:`107`) - Merge environments only when the extension is active (:issue:`107`) +- Link arguments and annotated assignment with type hints (:issue:`108`) 0.10.0 (2022-01-25) ------------------- diff --git a/src/sphinx_codeautolink/parse.py b/src/sphinx_codeautolink/parse.py index 201c1c9..a6360fa 100644 --- a/src/sphinx_codeautolink/parse.py +++ b/src/sphinx_codeautolink/parse.py @@ -21,27 +21,6 @@ def parse_names(source: str, doctree_node) -> List['Name']: return sum([split_access(a) for a in visitor.accessed], []) -@dataclass -class PendingAccess: - """Pending name access.""" - - components: List[ast.AST] - - -@dataclass -class PendingAssign: - """ - Pending assign target. - - `targets` represent the assignment targets. - If a single PendingAccess is found, it should be used to store the value - on the right hand side of the assignment. If multiple values are found, - they should overwrite any names in the current scope and not assign values. - """ - - targets: Union[Optional[PendingAccess], List[Optional[PendingAccess]]] - - @dataclass class Component: """Name access component.""" @@ -61,6 +40,8 @@ def from_ast(cls, node): elif isinstance(node, ast.Attribute): name = node.attr context = node.ctx.__class__.__name__.lower() + elif isinstance(node, ast.arg): + name = node.arg elif isinstance(node, ast.Call): name = NameBreak.call else: @@ -69,6 +50,27 @@ def from_ast(cls, node): return cls(name, node.lineno, end_lineno, context) +@dataclass +class PendingAccess: + """Pending name access.""" + + components: List[Component] + + +@dataclass +class PendingAssign: + """ + Pending assign target. + + `targets` represent the assignment targets. + If a single PendingAccess is found, it should be used to store the value + on the right hand side of the assignment. If multiple values are found, + they should overwrite any names in the current scope and not assign values. + """ + + targets: Union[Optional[PendingAccess], List[Optional[PendingAccess]]] + + class NameBreak(str, Enum): """Elements that break name access chains.""" @@ -269,7 +271,7 @@ def _assign(self, local_name: str, components: List[Component]): self.pseudo_scopes_stack[-1][local_name] = components def _access(self, access: PendingAccess) -> Optional[Access]: - components = [Component.from_ast(n) for n in access.components] + components = access.components prior = self.pseudo_scopes_stack[-1].get(components[0].name, None) if prior is None: @@ -306,7 +308,7 @@ def _resolve_assignment(self, assignment: Assignment): continue if len(target.components) == 1: - comp = Component.from_ast(target.components[0]) + comp = target.components[0] self._overwrite(comp.name) if access is not None: self._assign(comp.name, access.full_components) @@ -408,16 +410,16 @@ def visit_ImportFrom(self, node: ast.ImportFrom): self.visit_Import(node, prefix=node.module + '.') @track_parents - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): """Visit a Name node.""" - return PendingAccess([node]) + return PendingAccess([Component.from_ast(node)]) @track_parents - def visit_Attribute(self, node): + def visit_Attribute(self, node: ast.Attribute): """Visit an Attribute node.""" inner: Optional[PendingAccess] = self.visit(node.value) if inner is not None: - inner.components.append(node) + inner.components.append(Component.from_ast(node)) return inner @track_parents @@ -425,7 +427,7 @@ def visit_Call(self, node: ast.Call): """Visit a Call node.""" inner: Optional[PendingAccess] = self.visit(node.func) if inner is not None: - inner.components.append(node) + inner.components.append(Component.from_ast(node)) with self.reset_parents(): for arg in node.args + node.keywords: self.visit(arg) @@ -462,14 +464,22 @@ def visit_Assign(self, node: ast.Assign): @track_parents def visit_AnnAssign(self, node: ast.AnnAssign): """Visit an AnnAssign node.""" - if node.value is not None: - value = self.visit(node.value) - target = self.visit(node.target) - - with self.reset_parents(): - self.visit(node.annotation) + value = self.visit(node.value) if node.value is not None else None + annot = self.visit(node.annotation) + if annot is not None: + if value is not None: + self._access(value) + + annot.components.append(Component( + NameBreak.call, + node.annotation.lineno, + node.annotation.end_lineno, + 'load', + )) + value = annot - if node.value is not None: + target = self.visit(node.target) + if value is not None: return Assignment([PendingAssign(target)], value) def visit_AugAssign(self, node: ast.AugAssign): @@ -528,31 +538,41 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): self._overwrite(node.name) for dec in node.decorator_list: self.visit(dec) - if node.returns is not None: - self.visit(node.returns) for d in node.args.defaults + node.args.kw_defaults: if d is None: continue self.visit(d) args = self._get_args(node.args) args += [node.args.vararg, node.args.kwarg] - for arg in args: - if arg is None or arg.annotation is None: - continue - self.visit(arg.annotation) inner = self.__class__(self.doctree_node) inner.pseudo_scopes_stack[0] = self.pseudo_scopes_stack[0].copy() inner.outer_scopes_stack = list(self.outer_scopes_stack) inner.outer_scopes_stack.append(self.pseudo_scopes_stack[0]) + for arg in args: if arg is None: continue - inner._overwrite(arg.arg) + inner.visit(arg) + if node.returns is not None: + self.visit(node.returns) for n in node.body: inner.visit(n) self.accessed.extend(inner.accessed) + @track_parents + def visit_arg(self, arg: ast.arg): + """Handle function argument and its annotation.""" + target = PendingAccess([Component.from_ast(arg)]) + if arg.annotation is not None: + value = self.visit(arg.annotation) + if value is not None: + comp = Component(NameBreak.call, arg.lineno, arg.end_lineno, 'load') + value.components.append(comp) + else: + value = None + return Assignment([PendingAssign(target)], value) + def visit_Lambda(self, node: ast.Lambda): """Swap node order and separate inner scope.""" for d in node.args.defaults + node.args.kw_defaults: diff --git a/tests/parse/__init__.py b/tests/parse/__init__.py index 38ac8c7..c3b4eca 100644 --- a/tests/parse/__init__.py +++ b/tests/parse/__init__.py @@ -61,6 +61,24 @@ def test_simple_import_then_access(self): refs = [('lib', 'lib'), ('lib', 'lib')] return s, refs + @refs_equal + def test_inside_list_literal(self): + s = 'import lib\n[lib]' + refs = [('lib', 'lib'), ('lib', 'lib')] + return s, refs + + @refs_equal + def test_inside_subscript(self): + s = 'import lib\n0[lib]' + refs = [('lib', 'lib'), ('lib', 'lib')] + return s, refs + + @refs_equal + def test_outside_subscript(self): + s = 'import lib\nlib[0]' + refs = [('lib', 'lib'), ('lib', 'lib')] + return s, refs + @refs_equal def test_simple_import_then_attrib(self): s = 'import lib\nlib.attr' diff --git a/tests/parse/_util.py b/tests/parse/_util.py index 0992e2d..0366902 100644 --- a/tests/parse/_util.py +++ b/tests/parse/_util.py @@ -8,6 +8,8 @@ def wrapper(self): source, expected = func(self) names = parse_names(source, doctree_node=None) names = sorted(names, key=lambda name: name.lineno) + print('All names:') + [print(n) for n in names] for n, e in zip(names, expected): s = '.'.join(c for c in n.import_components) assert s == e[0], f'Wrong import! Expected\n{e}\ngot\n{n}' diff --git a/tests/parse/assign.py b/tests/parse/assign.py index 93bd2aa..8ff2a9d 100644 --- a/tests/parse/assign.py +++ b/tests/parse/assign.py @@ -78,27 +78,42 @@ def test_augassign_uses_and_assigns_imported(self): return s, refs @refs_equal - def test_annassign_uses_imported(self): + def test_annassign_overwrites_imported(self): s = 'import a\na: b = 1\na' refs = [('a', 'a')] return s, refs @refs_equal def test_annassign_uses_and_assigns_imported(self): - s = 'import a\na: b = a\na' - refs = [('a', 'a'), ('a', 'a'), ('a', 'a')] + s = 'import a\nb: 1 = a\nb.c' + refs = [('a', 'a'), ('a', 'a'), ('a.c', 'b.c')] + return s, refs + + @refs_equal + def test_annassign_uses_and_annotates_imported(self): + s = 'import a\nb: a = 1\nb.c' + refs = [('a', 'a'), ('a', 'a'), ('a.().c', 'b.c')] + return s, refs + + @refs_equal + def test_annassign_prioritises_annotation(self): + s = 'import a, b\nc: a = b\nc.d' + # note that AnnAssign is executed from value -> annot -> target + refs = [('a', 'a'), ('b', 'b'), ('b', 'b'), ('a', 'a'), ('a.().d', 'c.d')] return s, refs @refs_equal def test_annassign_why_would_anyone_do_this(self): - s = 'import a\na: a = a\na' - refs = [('a', 'a'), ('a', 'a'), ('a', 'a'), ('a', 'a')] + s = 'import a\na: a = a\na.b' + refs = [('a', 'a'), ('a', 'a'), ('a', 'a'), ('a.().b', 'a.b')] return s, refs @refs_equal def test_annassign_without_value_overrides_annotation_but_not_linked(self): + # note that this is different from runtime behavior + # which does not overwrite the variable value s = 'import a\na: b\na' - refs = [('a', 'a'), ('a', 'a')] + refs = [('a', 'a')] return s, refs @pytest.mark.skipif( diff --git a/tests/parse/scope.py b/tests/parse/scope.py index e6a702c..bd73239 100644 --- a/tests/parse/scope.py +++ b/tests/parse/scope.py @@ -42,11 +42,38 @@ def test_func_assigns_then_used_outside(self): return s, refs @refs_equal - def test_func_annotations_then_assigns(self): + def test_func_annotates_then_uses(self): + s = 'import a\ndef f(arg: a):\n arg.b' + refs = [('a', 'a'), ('a', 'a'), ('a.().b', 'arg.b')] + return s, refs + + @refs_equal + def test_func_annotates_then_assigns(self): s = 'import a\ndef f(arg: a) -> a:\n a = 1' refs = [('a', 'a'), ('a', 'a'), ('a', 'a')] return s, refs + @refs_equal + def test_func_annotates_as_generic_then_uses(self): + s = 'import a\ndef f(arg: a[0]):\n arg.b' + refs = [('a', 'a'), ('a', 'a')] + return s, refs + + @refs_equal + def test_func_annotates_inside_generic_then_uses(self): + s = 'import a\ndef f(arg: b[a]):\n arg.b' + refs = [('a', 'a'), ('a', 'a')] + return s, refs + + @pytest.mark.skipif( + sys.version_info < (3, 10), reason='Union syntax introduced in 3.10.' + ) + @refs_equal + def test_func_annotates_union_then_uses(self): + s = 'import a\ndef f(arg: a | 1):\n arg.b' + refs = [('a', 'a'), ('a', 'a')] + return s, refs + @refs_equal def test_func_kw_default_uses(self): s = 'import a\ndef f(*_, c, b=a):\n pass'