From 1db8d16c4fbd6dea2e0c17315121565f15e5b8d0 Mon Sep 17 00:00:00 2001 From: Neil Souza Date: Thu, 9 Mar 2023 15:01:15 -0800 Subject: [PATCH] Use blib2to3 parser from `black` package to support match statement (#80) * Use blib2to3 parser to support match statement The built-in lib2to3 does not support pattern matching (Python 3.10+): https://docs.python.org/3.11/library/2to3.html#module-lib2to3 The [black][] project managed to get some level of parsing support for `match` out of their modified version `blib2to3`, see: 1. https://github.com/psf/black/issues/2242 2. https://github.com/psf/black/pull/2586 [black]: https://github.com/psf/black This change adds `black` as a dependency and switches to using `blib2to3` to parse. Tests pass, but that's all that's been attempted thus far. * Add a _unreleased changelog for blib2to3 integration * fix mypy in docspec/src/docspec/__init__.py * fix mypy * update changelog format * update GitHub workflow * fix workflow * insert PR url * use `--no-venv-check` also for `slap run` in docs job --------- Co-authored-by: Niklas Rosenstein --- .github/workflows/python.yml | 14 +- docspec-python/.changelog/_unreleased.toml | 6 + docspec-python/pyproject.toml | 7 +- docspec-python/src/docspec_python/__init__.py | 12 +- docspec-python/src/docspec_python/parser.py | 373 +++++++++++++++--- docspec-python/test/test_parser.py | 26 ++ docspec/src/docspec/__init__.py | 4 +- 7 files changed, 365 insertions(+), 77 deletions(-) create mode 100644 docspec-python/.changelog/_unreleased.toml diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 08f54bf..7cdeead 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -17,9 +17,9 @@ jobs: python-version: ["3.7", "3.8", "3.9", "3.10", "3.x"] project: ["docspec", "docspec-python"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: NiklasRosenstein/slap@gha/install/v1 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v3 
with: { python-version: "${{ matrix.python-version }}" } - run: slap install --only ${{ matrix.project }} --no-venv-check -v - run: DOCSPEC_TEST_NO_DEVELOP=true slap test ${{ matrix.project }} @@ -28,16 +28,16 @@ jobs: runs-on: ubuntu-latest if: github.event_name == 'pull_request' steps: - - uses: actions/checkout@v2 - - uses: NiklasRosenstein/slap@gha/changelog/update/v1 + - uses: actions/checkout@v3 + - uses: NiklasRosenstein/slap@gha/changelog/update/v2 docs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: NiklasRosenstein/slap@gha/install/v1 - - run: slap install --only-extras docs --no-venv-check - - run: slap run docs:build + - run: slap install --no-venv-check --only-extras docs + - run: slap run --no-venv-check docs:build - uses: JamesIves/github-pages-deploy-action@4.1.4 if: github.ref == 'refs/heads/develop' with: { branch: gh-pages, folder: docs/_site, ssh-key: "${{ secrets.DEPLOY_KEY }}" } diff --git a/docspec-python/.changelog/_unreleased.toml b/docspec-python/.changelog/_unreleased.toml new file mode 100644 index 0000000..f795770 --- /dev/null +++ b/docspec-python/.changelog/_unreleased.toml @@ -0,0 +1,6 @@ +[[entries]] +id = "8628524b-3376-45db-a676-240b00c20d08" +type = "fix" +description = "Swap in `blib2to3` parser (bundled with the `black` package) for the stdlib `lib2to3` module in order to support `match` statements (PEP 634 - Structural Pattern Matching)." +author = "@nrser" +pr = "https://github.com/NiklasRosenstein/docspec/pull/80" diff --git a/docspec-python/pyproject.toml b/docspec-python/pyproject.toml index e8544df..2a8ac85 100644 --- a/docspec-python/pyproject.toml +++ b/docspec-python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "docspec-python" -version = "2.0.2" +version = "2.0.2+blib2to3" description = "A parser based on lib2to3 producing docspec data from Python source code." 
authors = ["Niklas Rosenstein "] license = "MIT" @@ -12,6 +12,7 @@ packages = [{ include = "docspec_python", from="src" }] python = "^3.7" docspec = "^2.0.2" "nr.util" = ">=0.7.0" +black = "^23.1.0" [tool.poetry.dev-dependencies] mypy = "*" @@ -28,6 +29,10 @@ typed = true pytest = "pytest test/ -vv" mypy = "mypy src/ test/ --check-untyped-defs" +[[tool.mypy.overrides]] +module = "blib2to3.*" +ignore_missing_imports = true + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/docspec-python/src/docspec_python/__init__.py b/docspec-python/src/docspec_python/__init__.py index 65717c6..15f4417 100644 --- a/docspec-python/src/docspec_python/__init__.py +++ b/docspec-python/src/docspec_python/__init__.py @@ -46,10 +46,10 @@ def load_python_modules( - modules: t.Sequence[str] = None, - packages: t.Sequence[str] = None, - search_path: t.Sequence[t.Union[str, Path]] = None, - options: ParserOptions = None, + modules: t.Optional[t.Sequence[str]] = None, + packages: t.Optional[t.Sequence[str]] = None, + search_path: t.Optional[t.Sequence[t.Union[str, Path]]] = None, + options: t.Optional[ParserOptions] = None, raise_: bool = True, encoding: t.Optional[str] = None, ) -> t.Iterable[Module]: @@ -133,7 +133,7 @@ def parse_python_module( # type: ignore return parser.parse(ast, filename, module_name) -def find_module(module_name: str, search_path: t.Sequence[t.Union[str, Path]] = None) -> str: +def find_module(module_name: str, search_path: t.Optional[t.Sequence[t.Union[str, Path]]] = None) -> str: """ Finds the filename of a module that can be parsed with #parse_python_module(). If *search_path* is not set, the default #sys.path is used to search for the module. If *module_name* is a Python package, it will return the path to the package's `__init__.py` file. If the module does not exist, an #ImportError is raised. 
This is also @@ -165,7 +165,7 @@ def find_module(module_name: str, search_path: t.Sequence[t.Union[str, Path]] = def iter_package_files( package_name: str, - search_path: t.Sequence[t.Union[str, Path]] = None, + search_path: t.Optional[t.Sequence[t.Union[str, Path]]] = None, ) -> t.Iterable[t.Tuple[str, str]]: """ Returns an iterator for the Python source files in the specified package. The items returned by the iterator are tuples of the module name and filename. Supports a PEP 420 namespace package diff --git a/docspec-python/src/docspec_python/parser.py b/docspec-python/src/docspec_python/parser.py index 839e9eb..8b65315 100644 --- a/docspec-python/src/docspec_python/parser.py +++ b/docspec-python/src/docspec_python/parser.py @@ -23,11 +23,18 @@ Note: The `docspec_python.parser` module is not public API. """ +from __future__ import annotations + import dataclasses +from io import StringIO import os import re +import sys import textwrap import typing as t +import logging + +from typing_extensions import TypeGuard from nr.util.iter import SequenceWalker @@ -42,11 +49,58 @@ Location, Module, _ModuleMembers) -from lib2to3.refactor import RefactoringTool # type: ignore -from lib2to3.pgen2 import token -from lib2to3.pgen2.parse import ParseError -from lib2to3.pygram import python_symbols as syms -from lib2to3.pytree import Leaf, Node +from black.parsing import lib2to3_parse +from blib2to3.pgen2 import token +import blib2to3.pgen2.parse +from blib2to3.pygram import python_symbols as syms +from blib2to3.pytree import Leaf, Node, Context, NL, type_repr + +#: Logger for debugging. Slap it in when and where needed. +#: +#: Note to self and others, you can get debug log output with something like +#: +#: do +#: name: "debug-logging" +#: closure: { +#: precedes "copy-files" +#: } +#: action: { +#: logging.getLogger("").setLevel(logging.DEBUG) +#: } +#: +#: in your `build.novella` file. Be warned, it's a _lot_ of output, and lags the +#: build out considerably. 
+#: +_LOG = logging.getLogger(__name__) + +class ParseError(blib2to3.pgen2.parse.ParseError): + """Extends `blib2to3.pgen2.parse.ParseError` to add a `filename` attribute.""" + + msg: t.Text + type: t.Optional[int] + value: t.Optional[t.Text] + context: Context + filename: t.Text + + def __init__( + self, + msg: t.Text, + type: t.Optional[int], + value: t.Optional[t.Text], + context: Context, + filename: t.Text + ) -> None: + Exception.__init__( + self, "%s: type=%r, value=%r, context=%r, filename=%r" % ( + msg, type, value, context, filename + ) + ) + self.msg = msg + self.type = type + self.value = value + self.context = context + self.filename = filename + def dedent_docstring(s): @@ -55,23 +109,149 @@ def dedent_docstring(s): lines[1:] = textwrap.dedent('\n'.join(lines[1:])).split('\n') return '\n'.join(lines).strip() +T = t.TypeVar("T") +V = t.TypeVar("V") + +@t.overload +def find(predicate: t.Callable[[T], TypeGuard[V]], iterable: t.Iterable[T]) -> V | None: + ... +@t.overload +def find(predicate: t.Callable[[T], t.Any], iterable: t.Iterable[T]) -> T | None: + ... +@t.overload +def find(predicate: t.Callable[[T], t.Any], iterable: t.Iterable[T], as_type: type[V]) -> V | None: + ... -def find(predicate, iterable): + +def find(predicate, iterable, as_type=None): + """Basic find function, plus the ability to add an `as_type` argument and + receive a typed result (or raise `TypeError`). + + As you might expect, this is really only to make typing easier. + """ for item in iterable: if predicate(item): + if (as_type is not None) and (not isinstance(item, as_type)): + raise TypeError( + "expected predicate to only match type {}, matched {!r}".format( + as_type, + item, + ) + ) return item return None +@t.overload +def get(predicate: t.Callable[[T], object], iterable: t.Iterable[T]) -> T: + ... +@t.overload +def get(predicate: t.Callable[[T], object], iterable: t.Iterable[T], as_type: type[V]) -> V: + ... 
+
+def get(predicate, iterable, as_type=None):
+  """Like `find`, but raises `ValueError` if `predicate` does not match. Assumes
+  that `None` means "no match", so don't try to use it to get `None` values in
+  `iterable`.
+  """
+  if isinstance(as_type, type):
+    found = find(predicate, iterable, as_type)
+  else:
+    found = find(predicate, iterable)
+
+  if found is None:
+    raise ValueError(
+      "item not found for predicate {!r} in iterable {!r}".format(
+        predicate, iterable
+      )
+    )
+
+  return found
+
+
+def is_node(x: object) -> TypeGuard[Node]:
+  """A simple `typing.TypeGuard` for `blib2to3.pytree.Node` instances.
+
+  Useful because things like `lambda x: isinstance(x, Node)` seemingly _do not_
+  infer their return type as `typing.TypeGuard[Node]`.
+  """
+  return isinstance(x, Node)
+
+
+def get_type_name(nl: NL) -> str:
+  """Get the "type name" for a `blib2to3.pytree.NL`, which is a `Node` or
+  `Leaf`. For display / debugging purposes.
+  """
+  if isinstance(nl, Node):
+    return str(type_repr(nl.type))
+  return str(token.tok_name.get(nl.type, nl.type))
+
+
+def pprint_nl(
+  nl: NL,
+  file: t.IO[str] = sys.stdout,
+  indent: int = 4,
+  _depth: int = 0
+) -> None:
+  """Pretty-print a `blib2to3.pytree.NL` over a bunch of lines, with indents,
+  to make it easier to read. Display / debugging use.
+  """
+  assert nl.type is not None
+
+  indent_s = (" " * indent * _depth)
+
+  if nl.children:
+    print(
+      "{indent_s}{class_name}({type_name}, [".format(
+        indent_s=indent_s,
+        class_name=nl.__class__.__name__,
+        type_name=get_type_name(nl),
+      ),
+      file=file,
+    )
+    for child in nl.children:
+      pprint_nl(child, file=file, _depth=_depth + 1)
+    print("{indent_s}])".format(indent_s=indent_s), file=file)
+  else:
+    print(
+      "{indent_s}{class_name}({type_name}, [])".format(
+        indent_s=indent_s,
+        class_name=nl.__class__.__name__,
+        type_name=get_type_name(nl),
+      ),
+      file=file,
+    )
+
+def pformat_nl(nl: NL) -> str:
+  """Same as `pprint_nl`, but writes to a `str`.
+ """ + sio = StringIO() + pprint_nl(nl, file=sio) + return sio.getvalue() + +def get_value(node: NL) -> str: + if isinstance(node, Leaf): + return node.value + raise TypeError( + "expected node to have a `value` attribute (be a Leaf), given {!r}".format(node) + ) + @dataclasses.dataclass class ParserOptions: + # NOTE (@nrser) This is no longer used. It was passed to + # `lib2to3.refactor.RefactoringTool`, but that's been swapped out for + # `black.parsing.lib2to3_parse`, which does not take the same options. + # + # It looks like it supported Python 2.x code, and I don't see anything + # before 3.3 in `black.mode.TargetVersion`, so 2.x might be completely off + # the table when using the Black parser. print_function: bool = True treat_singleline_comment_blocks_as_docstrings: bool = False class Parser: - def __init__(self, options: ParserOptions = None) -> None: + def __init__(self, options: t.Optional[ParserOptions] = None) -> None: self.options = options or ParserOptions() def parse_to_ast(self, code, filename): @@ -79,15 +259,13 @@ def parse_to_ast(self, code, filename): Parses the string *code* to an AST with #lib2to3. """ - options = {'print_function': self.options.print_function} - try: # NOTE (@NiklasRosenstein): Adding newline at the end, a ParseError # could be raised without a trailing newline (tested in CPython 3.6 # and 3.7). 
- return RefactoringTool([], options).refactor_string(code + '\n', filename) + return lib2to3_parse(code + '\n') except ParseError as exc: - raise ParseError(exc.msg, exc.type, exc.value, tuple(exc.context) + (filename,)) + raise ParseError(exc.msg, exc.type, exc.value, exc.context, filename) def parse(self, ast, filename, module_name=None): self.filename = filename # pylint: disable=attribute-defined-outside-init @@ -116,7 +294,12 @@ def parse(self, ast, filename, module_name=None): module.sync_hierarchy() return module - def parse_declaration(self, parent, node, decorations=None) -> t.Union[None, _ModuleMembers, t.List[_ModuleMembers]]: + def parse_declaration( + self, + parent, + node, + decorations: t.Optional[list[Decoration]] = None + ) -> t.Union[None, _ModuleMembers, t.List[_ModuleMembers]]: if node.type == syms.simple_stmt: assert not decorations stmt = node.children[0] @@ -147,7 +330,7 @@ def parse_declaration(self, parent, node, decorations=None) -> t.Union[None, _Mo return self.parse_declaration(parent, node.children[1], decorations) return None - def _split_statement(self, stmt): + def _split_statement(self, stmt: Node) -> tuple[list[NL], list[NL], list[NL]]: """ Parses a statement node into three lists, consisting of the leaf nodes that are the name(s), annotation and value of the expression. The value @@ -155,7 +338,11 @@ def _split_statement(self, stmt): a plain expression). 
""" - def _parse(stack, current, stmt): + def _parse( + stack: list[tuple[str, list[NL]]], + current: tuple[str, list[NL]], + stmt: Node + ) -> list[tuple[str, list[NL]]]: for child in stmt.children: if not isinstance(child, Node) and child.value == '=': stack.append(current) @@ -170,7 +357,7 @@ def _parse(stack, current, stmt): stack.append(current) return stack - result = dict(_parse([], ('names', []), stmt)) + result: dict[str, list[NL]] = dict(_parse([], ('names', []), stmt)) return result.get('names', []), result.get('annotation', []), result.get('value', []) def parse_import(self, parent, node: Node) -> t.Iterable[Indirection]: @@ -222,37 +409,63 @@ def _from_import_to_indirection(prefix: str, node: t.Union[Node, Leaf]) -> Indir else: raise RuntimeError(f'dont know how to deal with {node!r}') - def parse_statement(self, parent, stmt): + def parse_statement(self, parent: Node, stmt: Node) -> t.Optional[Variable]: names, annotation, value = self._split_statement(stmt) + data: t.Optional[Variable] = None if value or annotation: docstring = self.get_statement_docstring(stmt) expr = self.nodes_to_string(value) if value else None - annotation = self.nodes_to_string(annotation) if annotation else None + annotation_as_string = self.nodes_to_string(annotation) if annotation else None assert names + # NOTE (@nrser) Does this have some sort of side-effect from creating + # the `Variable` instance? Why loop versus directly use `names[-1]`? for name in names: name = self.nodes_to_string([name]) data = Variable( name=name, location=self.location_from(stmt), docstring=docstring, - datatype=annotation, + datatype=annotation_as_string, value=expr, ) - return data - return None + return data + + def parse_decorator(self, node: Node): + assert get_value(node.children[0]) == '@' + + # NOTE (@nrser)I have no idea why `blib2to3` parses some decorators with a 'power' + # node (which _seems_ refer to the exponent operator `**`), but it + # does. 
+ # + # The hint I eventually found was: + # + # https://github.com/psf/black/blob/b0d1fba7ac3be53c71fb0d3211d911e629f8aecb/src/black/nodes.py#L657 + # + # Anyways, this works around that curiosity. + if node.children[1].type == syms.power: + name = self.name_to_string(node.children[1].children[0]) + call_expr = self.nodes_to_string(node.children[1].children[1:]).strip() + + else: + name = self.name_to_string(node.children[1]) + call_expr = self.nodes_to_string(node.children[2:]).strip() - def parse_decorator(self, node): - assert node.children[0].value == '@' - name = self.name_to_string(node.children[1]) - call_expr = self.nodes_to_string(node.children[2:]).strip() return Decoration(location=self.location_from(node), name=name, args=call_expr or None) - def parse_funcdef(self, parent, node, is_async, decorations): - parameters = find(lambda x: x.type == syms.parameters, node.children) - body = find(lambda x: x.type == syms.suite, node.children) or \ - find(lambda x: x.type == syms.simple_stmt, node.children) + def parse_funcdef( + self, + parent: Node, + node: Node, + is_async: bool, + decorations: t.Optional[list[Decoration]] + ) -> Function: + parameters = get(lambda x: x.type == syms.parameters, node.children, as_type=Node) + body = ( + find(lambda x: x.type == syms.suite, node.children, as_type=Node) + or get(lambda x: x.type == syms.simple_stmt, node.children, as_type=Node) + ) - name = node.children[1].value + name = get_value(node.children[1]) docstring = self.get_docstring_from_first_node(body) args = self.parse_parameters(parameters) return_ = self.get_return_annotation(node) @@ -267,23 +480,28 @@ def parse_funcdef(self, parent, node, is_async, decorations): return_type=return_, decorations=decorations) - def parse_argument(self, node: t.Union[Leaf, Node, None], argtype: Argument.Type, scanner: 'SequenceWalker[Leaf | Node]') -> Argument: + def parse_argument( + self, + node: t.Optional[NL], + argtype: Argument.Type, + scanner: SequenceWalker[NL], + ) -> 
Argument: """ Parses an argument from the AST. *node* must be the current node at the current position of the *scanner*. The scanner is used to extract the optional default argument value that follows the *node*. """ - def parse_annotated_name(node): - if node.type == syms.tname: + def parse_annotated_name(node: NL) -> tuple[str, t.Optional[str]]: + if node.type in (syms.tname, syms.tname_star): scanner = SequenceWalker(node.children) - name = scanner.current.value + name = get_value(scanner.current) node = scanner.next() assert node.type == token.COLON, node.parent node = scanner.next() annotation = self.nodes_to_string([node]) elif node: - name = node.value + name = get_value(node) annotation = None else: raise RuntimeError('unexpected node: {!r}'.format(node)) @@ -298,6 +516,7 @@ def parse_annotated_name(node): default = None if node and node.type == token.EQUAL: node = scanner.advance() + assert node is not None default = self.nodes_to_string([node]) scanner.advance() @@ -309,7 +528,7 @@ def parse_annotated_name(node): default_value=default, ) - def parse_parameters(self, parameters): + def parse_parameters(self, parameters: Node) -> list[Argument]: assert parameters.type == syms.parameters, parameters.type result: t.List[Argument] = [] @@ -327,7 +546,7 @@ def parse_parameters(self, parameters): if len(parameters.children) == 3: result.append(Argument( location=self.location_from(parameters.children[1]), - name=parameters.children[1].value, + name=get_value(parameters.children[1]), type=Argument.Type.POSITIONAL, decorations=None, datatype=None, @@ -347,8 +566,14 @@ def parse_parameters(self, parameters): for arg in result: assert arg.type == Argument.Type.POSITIONAL, arg arg.type = Argument.Type.POSITIONAL_ONLY - node = index.next() - if node.type == token.COMMA: + # There may not be another token after the '/' -- seems like it totally + # works to define a function like + # + # def f(x, y, /): + # ... 
+ # + node = index.advance() + if node is not None and node.type == token.COMMA: index.advance() elif node.type == token.STAR: @@ -405,8 +630,13 @@ def parse_classdef_rawargs(self, classdef): index.next() return metaclass, bases - def parse_classdef(self, parent, node, decorations): - name = node.children[1].value + def parse_classdef( + self, + parent: Node, + node: Node, + decorations: t.Optional[list[Decoration]] + ) -> Class: + name = get_value(node.children[1]) bases = [] metaclass = None @@ -445,17 +675,23 @@ def parse_classdef(self, parent, node, decorations): class_.metaclass = metaclass return class_ - def location_from(self, node: t.Union[Node, Leaf]) -> Location: - return Location(self.filename, node.get_lineno()) + def location_from(self, node: NL) -> Location: + # NOTE (@nrser) `blib2to3.pytree.Base.get_lineno` may return `None`, but + # `Location` expects an `int`, so not sure exactly what to do here... for + # the moment just return a bogus value of -1 + lineno = node.get_lineno() + if lineno is None: + lineno = -1 + return Location(self.filename, lineno) def get_return_annotation(self, node: Node) -> t.Optional[str]: rarrow = find(lambda x: x.type == token.RARROW, node.children) if rarrow: - node = rarrow.next_sibling - return self.nodes_to_string([node]) + assert rarrow.next_sibling # satisfy type checker + return self.nodes_to_string([rarrow.next_sibling]) return None - def get_most_recent_prefix(self, node) -> str: + def get_most_recent_prefix(self, node: NL) -> str: if node.prefix: return node.prefix while not node.prev_sibling and not node.prefix: @@ -464,30 +700,41 @@ def get_most_recent_prefix(self, node) -> str: node = node.parent if node.prefix: return node.prefix - node = node.prev_sibling - while isinstance(node, Node) and node.children: - node = node.children[-1] + while isinstance(node.prev_sibling, Node) and node.prev_sibling.children: + node = node.prev_sibling.children[-1] return node.prefix - def get_docstring_from_first_node(self, 
parent: Node, module_level: bool = False) -> t.Optional[Docstring]: + def get_docstring_from_first_node( + self, + parent: NL, + module_level: bool = False + ) -> t.Optional[Docstring]: """ This method retrieves the docstring for the block node *parent*. The node either declares a class or function. """ assert parent is not None - node = find(lambda x: isinstance(x, Node), parent.children) - if node and node.type == syms.simple_stmt and node.children[0].type == token.STRING: - return self.prepare_docstring(node.children[0].value, parent) + node = find(is_node, parent.children) + + if ( + node + and node.type == syms.simple_stmt + and node.children[0].type == token.STRING + ): + return self.prepare_docstring(get_value(node.children[0]), parent) + if not node and not module_level: return None + if self.options.treat_singleline_comment_blocks_as_docstrings: docstring, doc_type = self.get_hashtag_docstring_from_prefix(node or parent) if doc_type == 'block': return docstring + return None - def get_statement_docstring(self, node: Node) -> t.Optional[Docstring]: + def get_statement_docstring(self, node: NL) -> t.Optional[Docstring]: prefix = self.get_most_recent_prefix(node) match = re.match(r'\s*', prefix[::-1]) assert match is not None @@ -497,7 +744,7 @@ def get_statement_docstring(self, node: Node) -> t.Optional[Docstring]: if doc_type == 'statement': return docstring # Look for the next string literal instead. 
- curr: t.Optional[Node] = node + curr: t.Optional[NL] = node while curr and curr.type != syms.simple_stmt: curr = curr.parent if curr and curr.next_sibling and curr.next_sibling.type == syms.simple_stmt: @@ -507,7 +754,10 @@ def get_statement_docstring(self, node: Node) -> t.Optional[Docstring]: return self.prepare_docstring(string_literal.value, string_literal) return None - def get_hashtag_docstring_from_prefix(self, node: Node) -> t.Tuple[t.Optional[Docstring], t.Optional[str]]: + def get_hashtag_docstring_from_prefix( + self, + node: NL, + ) -> t.Tuple[t.Optional[Docstring], t.Optional[str]]: """ Given a node in the AST, this method retrieves the docstring from the closest prefix of this node (ie. any block of single-line comments that @@ -536,7 +786,7 @@ def get_hashtag_docstring_from_prefix(self, node: Node) -> t.Tuple[t.Optional[Do return self.prepare_docstring('\n'.join(reversed(lines)), node), doc_type - def prepare_docstring(self, s: str, node_for_location: t.Union[Node, Leaf]) -> t.Optional[Docstring]: + def prepare_docstring(self, s: str, node_for_location: NL) -> t.Optional[Docstring]: # TODO @NiklasRosenstein handle u/f prefixes of string literal? location = self.location_from(node_for_location) s = s.strip() @@ -557,12 +807,12 @@ def prepare_docstring(self, s: str, node_for_location: t.Union[Node, Leaf]) -> t return Docstring(location, dedent_docstring(s[1:-1]).strip()) return None - def nodes_to_string(self, nodes): + def nodes_to_string(self, nodes: list[NL]) -> str: """ Converts a list of AST nodes to a string. 
""" - def generator(nodes: t.List[t.Union[Node, Leaf]], skip_prefix: bool = True) -> t.Iterable[str]: + def generator(nodes: t.List[NL], skip_prefix: bool = True) -> t.Iterable[str]: for i, node in enumerate(nodes): if not skip_prefix or i != 0: yield node.prefix @@ -573,8 +823,9 @@ def generator(nodes: t.List[t.Union[Node, Leaf]], skip_prefix: bool = True) -> t return ''.join(generator(nodes)) - def name_to_string(self, node): + def name_to_string(self, node: NL) -> str: if node.type == syms.dotted_name: - return ''.join(x.value for x in node.children) + return ''.join(get_value(x) for x in node.children) else: - return node.value + return get_value(node) + diff --git a/docspec-python/test/test_parser.py b/docspec-python/test/test_parser.py index 01ed949..ecedc07 100644 --- a/docspec-python/test/test_parser.py +++ b/docspec-python/test/test_parser.py @@ -472,3 +472,29 @@ def build_docker_image( return_type="Task" ), ] + +@docspec_test() +def test_funcdef_with_match_statement(): + """ + def f(x): + match x: + case str(s): + return "string" + case Path() as p: + return "path" + case int(n) | float(n): + return "number" + case _: + return "idk" + """ + + return [ + mkfunc( + "f", + None, + 0, + [ + Argument(loc, "x", Argument.Type.POSITIONAL, None), + ], + ), + ] diff --git a/docspec/src/docspec/__init__.py b/docspec/src/docspec/__init__.py index debbbc1..802d406 100644 --- a/docspec/src/docspec/__init__.py +++ b/docspec/src/docspec/__init__.py @@ -470,7 +470,7 @@ def load_module( # we ar sure the type is "IO" since the source has a read attribute. 
source = loader(source) # type: ignore[arg-type] - module = databind.json.load(source, Module, filename=filename) + module = databind.json.load(source, Module, filename=filename or '') module.sync_hierarchy() return module @@ -496,7 +496,7 @@ def load_modules( source = (loader(io.StringIO(line)) for line in t.cast(t.IO[str], source)) for data in source: - module = databind.json.load(data, Module, filename=filename) + module = databind.json.load(data, Module, filename=filename or '') module.sync_hierarchy() yield module