diff --git a/CHANGES.rst b/CHANGES.rst index 6cee76fbc91..b2437254b06 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -18,6 +18,9 @@ Features added Patch by James Addison and Adam Turner .. _officially recommended: https://jinja.palletsprojects.com/en/latest/templates/#template-file-extension +* Flatten ``Union[Literal[T], Literal[U], ...]`` to ``Literal[T, U, ...]`` + when turning annotations into strings. + Patch by Adam Turner. Bugs fixed ---------- diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py index 002c85641da..c8e89ee0ad6 100644 --- a/sphinx/util/typing.py +++ b/sphinx/util/typing.py @@ -439,6 +439,15 @@ def stringify_annotation( # They must be a list or a tuple, otherwise they are considered 'broken'. annotation_args = getattr(annotation, '__args__', ()) if annotation_args and isinstance(annotation_args, (list, tuple)): + if ( + qualname in {'Union', 'types.UnionType'} + and all(getattr(a, '__origin__', ...) is typing.Literal for a in annotation_args) + ): + # special case to flatten a Union of Literals into a literal + flattened_args = typing.Literal[annotation_args].__args__ # type: ignore[attr-defined] + args = ', '.join(_format_literal_arg_stringify(a, mode=mode) + for a in flattened_args) + return f'{module_prefix}Literal[{args}]' if qualname in {'Optional', 'Union', 'types.UnionType'}: return ' | '.join(stringify_annotation(a, mode) for a in annotation_args) elif qualname == 'Callable': diff --git a/tests/test_util/test_util_inspect.py b/tests/test_util/test_util_inspect.py index 83a6b72d974..764ca20d1de 100644 --- a/tests/test_util/test_util_inspect.py +++ b/tests/test_util/test_util_inspect.py @@ -359,6 +359,10 @@ def test_signature_annotations(): sig = inspect.signature(mod.f25) assert stringify_signature(sig) == '(a, b, /)' + # collapse Literal types + sig = inspect.signature(mod.f26) + assert stringify_signature(sig) == "(x: typing.Literal[1, 2, 3] = 1, y: typing.Literal['a', 'b'] = 'a') -> None" + def test_signature_from_str_basic(): signature = '(a, b, *args, c=0, d="blah", **kwargs)' diff --git a/tests/test_util/typing_test_data.py b/tests/test_util/typing_test_data.py index e29b60050eb..05888366e75 100644 --- a/tests/test_util/typing_test_data.py +++ b/tests/test_util/typing_test_data.py @@ -1,6 +1,6 @@ from inspect import Signature from numbers import Integral -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union def f0(x: int, y: Integral) -> None: @@ -121,6 +121,10 @@ def f25(a, b, /): pass +def f26(x: Literal[1, 2, 3] = 1, y: Union[Literal["a"], Literal["b"]] = "a") -> None: + pass + + class Node: def __init__(self, parent: Optional['Node']) -> None: pass