diff --git a/CHANGES b/CHANGES index 946870089e3..d60b832e029 100644 --- a/CHANGES +++ b/CHANGES @@ -29,6 +29,8 @@ Features added down the build * #6837: LaTeX: Support a nested table * #6966: graphviz: Support ``:class:`` option +* #2755: autodoc: Support type_comment style (ex. ``# type: (str) -> str``) + annotation (python3.8+ or typed_ast is required) Bugs fixed ---------- diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py index 0b55a92bd2e..ecbd0b3ce20 100644 --- a/sphinx/util/inspect.py +++ b/sphinx/util/inspect.py @@ -8,6 +8,7 @@ :license: BSD, see LICENSE for details. """ +import ast import builtins import enum import inspect @@ -20,7 +21,7 @@ isclass, ismethod, ismethoddescriptor, isroutine ) from io import StringIO -from typing import Any, Callable, Mapping, List, Tuple +from typing import Any, Callable, Dict, Generator, Mapping, List, Tuple, Union from sphinx.deprecation import RemovedInSphinx30Warning from sphinx.util import logging @@ -40,6 +41,7 @@ logger = logging.getLogger(__name__) memory_address_re = re.compile(r' at 0x[0-9a-f]{8,16}(?=>)', re.IGNORECASE) +type_comment_re = re.compile(r'\s*\((.*)\)\s* -> \s*(.*)\s*') # Copied from the definition of inspect.getfullargspec from Python master, @@ -315,6 +317,76 @@ def is_builtin_class_method(obj: Any, attr_name: str) -> bool: return getattr(builtins, safe_getattr(cls, '__name__', '')) is cls +def iter_args(func: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ) -> Generator[str, None, None]: + """Get an iterator for arguments names from FunctionDef node.""" + if hasattr(func.args, "posonlyargs"): # py38 or above + yield from (a.arg for a in func.args.posonlyargs) # type: ignore + yield from (a.arg for a in func.args.args) + if func.args.vararg: + yield func.args.vararg.arg + if func.args.kwarg: + yield func.args.kwarg.arg + + +def parse_argtypes(s: str) -> Generator[str, None, None]: + """Parse argument part of type_comment.""" + start = 0 + parens = 0 + for i, char in enumerate(s): + if char == '[': + parens += 1 + elif char == ']': + parens -= 1 + elif char == ',' and parens == 0: + yield s[start:i] + start = i + 1 + + yield s[start:] + + +def get_type_hints_from_type_comment(obj: Any) -> Dict[str, str]: + """Get type hints from py2 style type_comment. + + Python3.8+ or typed_ast is required. + """ + if sys.version_info > (3, 8): + parse = partial(ast.parse, type_comments=True) + else: + try: + from typed_ast import ast3 + parse = ast3.parse + except ImportError: + return {} + + try: + source = inspect.getsource(obj) + if source.startswith((' ', r'\t')): + # subject is placed inside class or block. To read its docstring, + # this adds if-block before the declaration. + module = parse('if True:\n' + source) + subject = module.body[0].body[0] # type: ignore + else: + module = parse(source) + subject = module.body[0] # type: ignore + type_comment = subject.type_comment + + if type_comment is None: # no type_comment + return {} + else: + type_hints = {} # type: Dict[str, Any] + argtypes, rtype = type_comment_re.match(type_comment).groups() + type_hints['return'] = rtype.strip() + + if argtypes.strip() != '...': + for name, typ in zip(iter_args(subject), parse_argtypes(argtypes)): + type_hints[name] = typ.strip() + + return type_hints + except (OSError, TypeError): # failed to load source code + return {} + + class Parameter: """Fake parameter class for python2.""" POSITIONAL_ONLY = 0 @@ -372,6 +444,14 @@ def __init__(self, subject: Callable, bound_method: bool = False, # we try to build annotations from argspec. self.annotations = {} + # merge type_comment-based type hints + self.annotations.update(get_type_hints_from_type_comment(subject)) + for param in self.parameters.values(): + if param.annotation is param.empty and param.name in self.annotations: + param._annotation = self.annotations[param.name] + if self.return_annotation is inspect.Parameter.empty and 'return' in self.annotations: + self.signature._return_annotation = self.annotations['return'] # type: ignore + if bound_method: # client gives a hint that the subject is a bound method diff --git a/tests/test_util_inspect.py b/tests/test_util_inspect.py index 2f463196556..151d46ece50 100644 --- a/tests/test_util_inspect.py +++ b/tests/test_util_inspect.py @@ -195,7 +195,8 @@ def meth2(self, arg1, arg2): def test_Signature_annotations(): from typing_test_data import (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, - f11, f12, f13, f14, f15, f16, f17, f18, f19, Node) + f11, f12, f13, f14, f15, f16, f17, f18, f19, + f20, Node) # Class annotations sig = inspect.Signature(f0).format_args() @@ -279,6 +280,11 @@ def test_Signature_annotations(): sig = inspect.Signature(f19).format_args() assert sig == '(*args: int, **kwargs: str)' + # annotations by type_comment + sig = inspect.Signature(f20).format_args() + assert sig == ('(arg1: str, arg2: List[int], arg3: Tuple[int, Union[str, int]] = None, ' + '*args: str, **kwargs: str) -> int') + # type hints by string sig = inspect.Signature(Node.children).format_args() if (3, 5, 0) <= sys.version_info < (3, 5, 3): diff --git a/tests/typing_test_data.py b/tests/typing_test_data.py index 76db7c898c0..63c3b927a53 100644 --- a/tests/typing_test_data.py +++ b/tests/typing_test_data.py @@ -96,6 +96,10 @@ def f19(*args: int, **kwargs: str): pass +def f20(arg1, arg2, arg3=None, *args, **kwargs): + # type: (str, List[int], Tuple[int, Union[str, int]], str, str) -> int + pass + class Node: def __init__(self, parent: Optional['Node']) -> None: