Skip to content

Commit

Permalink
[C++] Support requires-clause in more places
Browse files Browse the repository at this point in the history
Previously a C++20 requires-clause was only supported on `function`
declarations.  However, the C++ standard allows a require-clause on
class/union templates, alias templates, and variable templates, and
also allows a requires clause after each template parameter list, not
just the final one.

This moves the requiresClause to be a property of `ASTTemplateParams`
rather than `ASTDeclaration` to better match the C++ grammar and
allows requires clauses in many places that are supported by C++20 but
were not previously allowed by Sphinx, namely:

- On class/union templates, alias templates, and variable templates

- After each template parameter list, not just the last one.

- After the template parameter list in template template parameters.

Additionally:

- This adds support for template parameters on unions.

- This adds support for trailing requires clauses on functions without
  a template prefix.  This is allowed by C++20 for non-template members
  of class templates.

When encoding the id, the requires clause of the last template
parameter list is treated specially in order to preserve compatibility
with existing v4 ids.
  • Loading branch information
jbms committed Mar 23, 2022
1 parent 670e8b1 commit c1dce07
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 36 deletions.
105 changes: 69 additions & 36 deletions sphinx/domains/cpp.py
Expand Up @@ -3659,24 +3659,33 @@ def describe_signature(self, signode: TextElement, mode: str,


class ASTTemplateParams(ASTBase):
def __init__(self, params: List[ASTTemplateParam]) -> None:
def __init__(self, params: List[ASTTemplateParam],
requiresClause: Optional["ASTRequiresClause"] = None) -> None:
assert params is not None
self.params = params
self.requiresClause = requiresClause

def get_id(self, version: int) -> str:
def get_id(self, version: int, exclude_requires: bool = False) -> str:
assert version >= 2
res = []
res.append("I")
for param in self.params:
res.append(param.get_id(version))
res.append("E")
if not exclude_requires and self.requiresClause:
res.append('IQ')
res.append(self.requiresClause.expr.get_id(version))
res.append('E')
return ''.join(res)

def _stringify(self, transform: StringifyTransform) -> str:
res = []
res.append("template<")
res.append(", ".join(transform(a) for a in self.params))
res.append("> ")
if self.requiresClause is not None:
res.append(transform(self.requiresClause))
res.append(" ")
return ''.join(res)

def describe_signature(self, signode: TextElement, mode: str,
Expand All @@ -3691,6 +3700,9 @@ def describe_signature(self, signode: TextElement, mode: str,
first = False
param.describe_signature(signode, mode, env, symbol)
signode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
signode += addnodes.desc_sig_space()
self.requiresClause.describe_signature(signode, mode, env, symbol)

def describe_signature_as_introducer(
self, parentNode: desc_signature, mode: str, env: "BuildEnvironment",
Expand All @@ -3715,6 +3727,11 @@ def makeLine(parentNode: desc_signature) -> addnodes.desc_signature_line:
if lineSpec and not first:
lineNode = makeLine(parentNode)
lineNode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
parentNode += reqNode
self.requiresClause.describe_signature(reqNode, 'markType', env, symbol)


# Template introducers
Expand Down Expand Up @@ -3837,8 +3854,12 @@ def get_id(self, version: int) -> str:
assert version >= 2
# this is not part of a normal name mangling system
res = []
for t in self.templates:
res.append(t.get_id(version))
last_index = len(self.templates) - 1
for i, t in enumerate(self.templates):
if isinstance(t, ASTTemplateParams):
res.append(t.get_id(version, exclude_requires=(i == last_index)))
else:
res.append(t.get_id(version))
return ''.join(res)

def _stringify(self, transform: StringifyTransform) -> str:
Expand All @@ -3861,7 +3882,7 @@ def __init__(self, expr: ASTExpression) -> None:
def _stringify(self, transform: StringifyTransform) -> str:
return 'requires ' + transform(self.expr)

def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
def describe_signature(self, signode: nodes.TextElement, mode: str,
env: "BuildEnvironment", symbol: "Symbol") -> None:
signode += addnodes.desc_sig_keyword('requires', 'requires')
signode += addnodes.desc_sig_space()
Expand All @@ -3872,17 +3893,19 @@ def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
################################################################################

class ASTDeclaration(ASTBase):
def __init__(self, objectType: str, directiveType: str, visibility: str,
templatePrefix: ASTTemplateDeclarationPrefix,
requiresClause: ASTRequiresClause, declaration: Any,
trailingRequiresClause: ASTRequiresClause,
def __init__(self, objectType: str, directiveType: Optional[str] = None,
visibility: Optional[str] = None,
templatePrefix: Optional[ASTTemplateDeclarationPrefix] = None,
declaration: Any = None,
trailingRequiresClause: Optional[ASTRequiresClause] = None,
semicolon: bool = False) -> None:
self.objectType = objectType
self.directiveType = directiveType
self.visibility = visibility
self.templatePrefix = templatePrefix
self.requiresClause = requiresClause
self.declaration = declaration
if trailingRequiresClause is not None:
assert isinstance(trailingRequiresClause, ASTRequiresClause)
self.trailingRequiresClause = trailingRequiresClause
self.semicolon = semicolon

Expand All @@ -3892,18 +3915,29 @@ def __init__(self, objectType: str, directiveType: str, visibility: str,

def clone(self) -> "ASTDeclaration":
templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None
requiresClasueClone = self.requiresClause.clone() if self.requiresClause else None
trailingRequiresClasueClone = self.trailingRequiresClause.clone() \
if self.trailingRequiresClause else None
return ASTDeclaration(self.objectType, self.directiveType, self.visibility,
templatePrefixClone, requiresClasueClone,
templatePrefixClone,
self.declaration.clone(), trailingRequiresClasueClone,
self.semicolon)

@property
def name(self) -> ASTNestedName:
return self.declaration.name

@property
def requiresClause(self) -> Optional[ASTRequiresClause]:
templatePrefix = self.templatePrefix
if templatePrefix is None:
return None
if not templatePrefix.templates:
return None
last_template = templatePrefix.templates[-1]
if not isinstance(last_template, ASTTemplateParams):
return None
return last_template.requiresClause

@property
def function_params(self) -> List[ASTFunctionParameter]:
if self.objectType != 'function':
Expand All @@ -3912,7 +3946,7 @@ def function_params(self) -> List[ASTFunctionParameter]:

def get_id(self, version: int, prefixed: bool = True) -> str:
if version == 1:
if self.templatePrefix:
if self.templatePrefix or self.trailingRequiresClause:
raise NoOldIdError()
if self.objectType == 'enumerator' and self.enumeratorScopedSymbol:
return self.enumeratorScopedSymbol.declaration.get_id(version)
Expand All @@ -3926,14 +3960,17 @@ def get_id(self, version: int, prefixed: bool = True) -> str:
res = []
if self.templatePrefix:
res.append(self.templatePrefix.get_id(version))
if self.requiresClause or self.trailingRequiresClause:
# Encode the last requires clause specially to avoid introducing a new
# id version number.
requiresClause = self.requiresClause
if requiresClause or self.trailingRequiresClause:
if version < 4:
raise NoOldIdError()
res.append('IQ')
if self.requiresClause and self.trailingRequiresClause:
if requiresClause and self.trailingRequiresClause:
res.append('aa')
if self.requiresClause:
res.append(self.requiresClause.expr.get_id(version))
if requiresClause:
res.append(requiresClause.expr.get_id(version))
if self.trailingRequiresClause:
res.append(self.trailingRequiresClause.expr.get_id(version))
res.append('E')
Expand All @@ -3950,9 +3987,6 @@ def _stringify(self, transform: StringifyTransform) -> str:
res.append(' ')
if self.templatePrefix:
res.append(transform(self.templatePrefix))
if self.requiresClause:
res.append(transform(self.requiresClause))
res.append(' ')
res.append(transform(self.declaration))
if self.trailingRequiresClause:
res.append(' ')
Expand All @@ -3977,11 +4011,6 @@ def describe_signature(self, signode: desc_signature, mode: str,
self.templatePrefix.describe_signature(signode, mode, env,
symbol=self.symbol,
lineSpec=options.get('tparam-line-spec'))
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
signode.append(reqNode)
self.requiresClause.describe_signature(reqNode, 'markType', env, self.symbol)
signode += mainDeclNode
if self.visibility and self.visibility != "public":
mainDeclNode += addnodes.desc_sig_keyword(self.visibility, self.visibility)
Expand Down Expand Up @@ -4164,7 +4193,7 @@ def _add_template_and_function_params(self) -> None:
continue
# only add a declaration if we our self are from a declaration
if self.declaration:
decl = ASTDeclaration('templateParam', None, None, None, None, tp, None)
decl = ASTDeclaration(objectType='templateParam', declaration=tp)
else:
decl = None
nne = ASTNestedNameElement(tp.get_identifier(), None)
Expand All @@ -4179,7 +4208,7 @@ def _add_template_and_function_params(self) -> None:
if nn is None:
continue
# (comparing to the template params: we have checked that we are a declaration)
decl = ASTDeclaration('functionParam', None, None, None, None, fp, None)
decl = ASTDeclaration(objectType='functionParam', declaration=fp)
assert not nn.rooted
assert len(nn.names) == 1
self._add_symbols(nn, [], decl, self.docname, self.line)
Expand Down Expand Up @@ -6475,7 +6504,14 @@ def _parse_type(self, named: Union[bool, str], outer: str = None) -> ASTType:
declSpecs = self._parse_decl_specs(outer=outer, typed=False)
decl = self._parse_declarator(named=True, paramMode=outer,
typed=False)
self.assert_end(allowSemicolon=True)
must_end = True
if outer == 'function':
# Allow trailing requires on constructors
self.skip_ws()
if re.compile(r'requires\b').match(self.definition, self.pos):
must_end = False
if must_end:
self.assert_end(allowSemicolon=True)
except DefinitionError as exUntyped:
if outer == 'type':
desc = "If just a name"
Expand Down Expand Up @@ -6747,7 +6783,8 @@ def _parse_template_parameter_list(self) -> ASTTemplateParams:
err = eParam
self.skip_ws()
if self.skip_string('>'):
return ASTTemplateParams(templateParams)
requiresClause = self._parse_requires_clause()
return ASTTemplateParams(templateParams, requiresClause)
elif self.skip_string(','):
continue
else:
Expand Down Expand Up @@ -6933,18 +6970,15 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
raise Exception('Internal error, unknown directiveType "%s".' % directiveType)
visibility = None
templatePrefix = None
requiresClause = None
trailingRequiresClause = None
declaration: Any = None

self.skip_ws()
if self.match(_visibility_re):
visibility = self.matched_text

if objectType in ('type', 'concept', 'member', 'function', 'class'):
if objectType in ('type', 'concept', 'member', 'function', 'class', 'union'):
templatePrefix = self._parse_template_declaration_prefix(objectType)
if objectType == 'function' and templatePrefix is not None:
requiresClause = self._parse_requires_clause()

if objectType == 'type':
prevErrors = []
Expand All @@ -6970,8 +7004,7 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
declaration = self._parse_type_with_init(named=True, outer='member')
elif objectType == 'function':
declaration = self._parse_type(named=True, outer='function')
if templatePrefix is not None:
trailingRequiresClause = self._parse_requires_clause()
trailingRequiresClause = self._parse_requires_clause()
elif objectType == 'class':
declaration = self._parse_class()
elif objectType == 'union':
Expand All @@ -6989,7 +7022,7 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
self.skip_ws()
semicolon = self.skip_string(';')
return ASTDeclaration(objectType, directiveType, visibility,
templatePrefix, requiresClause, declaration,
templatePrefix, declaration,
trailingRequiresClause, semicolon)

def parse_namespace_object(self) -> ASTNamespace:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_domain_cpp.py
Expand Up @@ -896,8 +896,33 @@ def test_domain_cpp_ast_requires_clauses():
{4: 'I0EIQaa1A1BE1fvv'})
check('function', 'template<typename T> requires A || B or C void f()',
{4: 'I0EIQoo1Aoo1B1CE1fvv'})
check('function', 'void f() requires A || B || C',
{4: 'IQoo1Aoo1B1CE1fv'})
check('function', 'Foo() requires A || B || C',
{4: 'IQoo1Aoo1B1CE3Foov'})
check('function', 'template<typename T> requires A && B || C and D void f()',
{4: 'I0EIQooaa1A1Baa1C1DE1fvv'})
check('function',
'template<typename T> requires R<T> ' +
'template<typename U> requires S<T> ' +
'void A<T>::f() requires B',
{4: 'I0EIQ1RI1TEEI0EIQaa1SI1TE1BEN1AI1TE1fEvv'})
check('function',
'template<template<typename T> requires R<T> typename X> ' +
'void f()',
{2: 'II0EIQ1RI1TEE0E1fv', 4: 'II0EIQ1RI1TEE0E1fvv'})
check('type',
'template<typename T> requires IsValid<T> {key}T = true_type',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='using')
check('class',
'template<typename T> requires IsValid<T> {key}T : Base',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='class')
check('union',
'template<typename T> requires IsValid<T> {key}T',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='union')
check('member',
'template<typename T> requires IsValid<T> int Val = 7',
{4: 'I0EIQ7IsValidI1TEE3Val'})


def test_domain_cpp_ast_template_args():
Expand Down

0 comments on commit c1dce07

Please sign in to comment.