Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C++] Support requires-clause in more places #10286

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
113 changes: 76 additions & 37 deletions sphinx/domains/cpp.py
Expand Up @@ -3659,24 +3659,39 @@ def describe_signature(self, signode: TextElement, mode: str,


class ASTTemplateParams(ASTBase):
def __init__(self, params: List[ASTTemplateParam]) -> None:
def __init__(self, params: List[ASTTemplateParam],
requiresClause: Optional["ASTRequiresClause"]) -> None:
assert params is not None
self.params = params
self.requiresClause = requiresClause

def get_id(self, version: int) -> str:
def get_id(self, version: int, exclude_requires: bool = False) -> str:
jbms marked this conversation as resolved.
Show resolved Hide resolved
# Note: For `version==4`, `exclude_requires` is set to `True` when
# encoding the id of the last template parameter list of a declaration,
# as that requires-clause, if any, is instead encoded by
# `ASTDeclaration.get_id` after encoding the template prefix, for
# consistency with the existing v4 format used when only a single
# requires-clause was supported.
assert version >= 2
res = []
res.append("I")
for param in self.params:
res.append(param.get_id(version))
res.append("E")
if not exclude_requires and self.requiresClause:
res.append('IQ')
res.append(self.requiresClause.expr.get_id(version))
res.append('E')
return ''.join(res)

def _stringify(self, transform: StringifyTransform) -> str:
res = []
res.append("template<")
res.append(", ".join(transform(a) for a in self.params))
res.append("> ")
if self.requiresClause is not None:
res.append(transform(self.requiresClause))
res.append(" ")
return ''.join(res)

def describe_signature(self, signode: TextElement, mode: str,
Expand All @@ -3691,6 +3706,9 @@ def describe_signature(self, signode: TextElement, mode: str,
first = False
param.describe_signature(signode, mode, env, symbol)
signode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
signode += addnodes.desc_sig_space()
self.requiresClause.describe_signature(signode, mode, env, symbol)

def describe_signature_as_introducer(
self, parentNode: desc_signature, mode: str, env: "BuildEnvironment",
Expand All @@ -3715,6 +3733,11 @@ def makeLine(parentNode: desc_signature) -> addnodes.desc_signature_line:
if lineSpec and not first:
lineNode = makeLine(parentNode)
lineNode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
parentNode += reqNode
self.requiresClause.describe_signature(reqNode, 'markType', env, symbol)


# Template introducers
Expand Down Expand Up @@ -3837,8 +3860,12 @@ def get_id(self, version: int) -> str:
assert version >= 2
# this is not part of a normal name mangling system
res = []
for t in self.templates:
res.append(t.get_id(version))
last_index = len(self.templates) - 1
for i, t in enumerate(self.templates):
if isinstance(t, ASTTemplateParams):
res.append(t.get_id(version, exclude_requires=(i == last_index)))
else:
res.append(t.get_id(version))
return ''.join(res)

def _stringify(self, transform: StringifyTransform) -> str:
Expand All @@ -3861,7 +3888,7 @@ def __init__(self, expr: ASTExpression) -> None:
def _stringify(self, transform: StringifyTransform) -> str:
return 'requires ' + transform(self.expr)

def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
def describe_signature(self, signode: nodes.TextElement, mode: str,
env: "BuildEnvironment", symbol: "Symbol") -> None:
signode += addnodes.desc_sig_keyword('requires', 'requires')
signode += addnodes.desc_sig_space()
Expand All @@ -3872,16 +3899,16 @@ def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
################################################################################

class ASTDeclaration(ASTBase):
def __init__(self, objectType: str, directiveType: str, visibility: str,
templatePrefix: ASTTemplateDeclarationPrefix,
requiresClause: ASTRequiresClause, declaration: Any,
trailingRequiresClause: ASTRequiresClause,
def __init__(self, objectType: str, directiveType: Optional[str] = None,
visibility: Optional[str] = None,
templatePrefix: Optional[ASTTemplateDeclarationPrefix] = None,
declaration: Any = None,
trailingRequiresClause: Optional[ASTRequiresClause] = None,
semicolon: bool = False) -> None:
self.objectType = objectType
self.directiveType = directiveType
self.visibility = visibility
self.templatePrefix = templatePrefix
self.requiresClause = requiresClause
self.declaration = declaration
self.trailingRequiresClause = trailingRequiresClause
self.semicolon = semicolon
Expand All @@ -3892,18 +3919,29 @@ def __init__(self, objectType: str, directiveType: str, visibility: str,

def clone(self) -> "ASTDeclaration":
templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None
requiresClasueClone = self.requiresClause.clone() if self.requiresClause else None
trailingRequiresClasueClone = self.trailingRequiresClause.clone() \
if self.trailingRequiresClause else None
return ASTDeclaration(self.objectType, self.directiveType, self.visibility,
templatePrefixClone, requiresClasueClone,
templatePrefixClone,
self.declaration.clone(), trailingRequiresClasueClone,
self.semicolon)

@property
def name(self) -> ASTNestedName:
return self.declaration.name

@property
def requiresClause(self) -> Optional[ASTRequiresClause]:
templatePrefix = self.templatePrefix
if templatePrefix is None:
return None
if not templatePrefix.templates:
return None
last_template = templatePrefix.templates[-1]
if not isinstance(last_template, ASTTemplateParams):
return None
return last_template.requiresClause

@property
def function_params(self) -> List[ASTFunctionParameter]:
if self.objectType != 'function':
Expand All @@ -3912,7 +3950,7 @@ def function_params(self) -> List[ASTFunctionParameter]:

def get_id(self, version: int, prefixed: bool = True) -> str:
if version == 1:
if self.templatePrefix:
if self.templatePrefix or self.trailingRequiresClause:
raise NoOldIdError()
if self.objectType == 'enumerator' and self.enumeratorScopedSymbol:
return self.enumeratorScopedSymbol.declaration.get_id(version)
Expand All @@ -3926,14 +3964,17 @@ def get_id(self, version: int, prefixed: bool = True) -> str:
res = []
if self.templatePrefix:
res.append(self.templatePrefix.get_id(version))
if self.requiresClause or self.trailingRequiresClause:
# Encode the last requires clause specially to avoid introducing a new
# id version number.
requiresClause = self.requiresClause
if requiresClause or self.trailingRequiresClause:
if version < 4:
raise NoOldIdError()
res.append('IQ')
if self.requiresClause and self.trailingRequiresClause:
if requiresClause and self.trailingRequiresClause:
res.append('aa')
if self.requiresClause:
res.append(self.requiresClause.expr.get_id(version))
if requiresClause:
res.append(requiresClause.expr.get_id(version))
if self.trailingRequiresClause:
res.append(self.trailingRequiresClause.expr.get_id(version))
res.append('E')
Expand All @@ -3950,9 +3991,6 @@ def _stringify(self, transform: StringifyTransform) -> str:
res.append(' ')
if self.templatePrefix:
res.append(transform(self.templatePrefix))
if self.requiresClause:
res.append(transform(self.requiresClause))
res.append(' ')
res.append(transform(self.declaration))
if self.trailingRequiresClause:
res.append(' ')
Expand All @@ -3977,11 +4015,6 @@ def describe_signature(self, signode: desc_signature, mode: str,
self.templatePrefix.describe_signature(signode, mode, env,
symbol=self.symbol,
lineSpec=options.get('tparam-line-spec'))
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
signode.append(reqNode)
self.requiresClause.describe_signature(reqNode, 'markType', env, self.symbol)
signode += mainDeclNode
if self.visibility and self.visibility != "public":
mainDeclNode += addnodes.desc_sig_keyword(self.visibility, self.visibility)
Expand Down Expand Up @@ -4164,7 +4197,7 @@ def _add_template_and_function_params(self) -> None:
continue
# only add a declaration if we our self are from a declaration
if self.declaration:
decl = ASTDeclaration('templateParam', None, None, None, None, tp, None)
decl = ASTDeclaration(objectType='templateParam', declaration=tp)
else:
decl = None
nne = ASTNestedNameElement(tp.get_identifier(), None)
Expand All @@ -4179,7 +4212,7 @@ def _add_template_and_function_params(self) -> None:
if nn is None:
continue
# (comparing to the template params: we have checked that we are a declaration)
decl = ASTDeclaration('functionParam', None, None, None, None, fp, None)
decl = ASTDeclaration(objectType='functionParam', declaration=fp)
assert not nn.rooted
assert len(nn.names) == 1
self._add_symbols(nn, [], decl, self.docname, self.line)
Expand Down Expand Up @@ -6475,7 +6508,14 @@ def _parse_type(self, named: Union[bool, str], outer: str = None) -> ASTType:
declSpecs = self._parse_decl_specs(outer=outer, typed=False)
decl = self._parse_declarator(named=True, paramMode=outer,
typed=False)
self.assert_end(allowSemicolon=True)
must_end = True
if outer == 'function':
# Allow trailing requires on constructors
self.skip_ws()
if re.compile(r'requires\b').match(self.definition, self.pos):
must_end = False
if must_end:
self.assert_end(allowSemicolon=True)
except DefinitionError as exUntyped:
if outer == 'type':
desc = "If just a name"
Expand Down Expand Up @@ -6747,7 +6787,8 @@ def _parse_template_parameter_list(self) -> ASTTemplateParams:
err = eParam
self.skip_ws()
if self.skip_string('>'):
return ASTTemplateParams(templateParams)
requiresClause = self._parse_requires_clause()
return ASTTemplateParams(templateParams, requiresClause)
elif self.skip_string(','):
continue
else:
Expand Down Expand Up @@ -6869,6 +6910,8 @@ def _parse_template_declaration_prefix(self, objectType: str
return ASTTemplateDeclarationPrefix(None)
else:
raise e
if objectType == 'concept' and params.requiresClause is not None:
self.fail('requires-clause not allowed for concept')
else:
params = self._parse_template_introduction()
if not params:
Expand Down Expand Up @@ -6917,7 +6960,7 @@ def _check_template_consistency(self, nestedName: ASTNestedName,

newTemplates: List[Union[ASTTemplateParams, ASTTemplateIntroduction]] = []
for _i in range(numExtra):
newTemplates.append(ASTTemplateParams([]))
newTemplates.append(ASTTemplateParams([], requiresClause=None))
if templatePrefix and not isMemberInstantiation:
newTemplates.extend(templatePrefix.templates)
templatePrefix = ASTTemplateDeclarationPrefix(newTemplates)
Expand All @@ -6933,18 +6976,15 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
raise Exception('Internal error, unknown directiveType "%s".' % directiveType)
visibility = None
templatePrefix = None
requiresClause = None
trailingRequiresClause = None
declaration: Any = None

self.skip_ws()
if self.match(_visibility_re):
visibility = self.matched_text

if objectType in ('type', 'concept', 'member', 'function', 'class'):
if objectType in ('type', 'concept', 'member', 'function', 'class', 'union'):
templatePrefix = self._parse_template_declaration_prefix(objectType)
if objectType == 'function' and templatePrefix is not None:
requiresClause = self._parse_requires_clause()

if objectType == 'type':
prevErrors = []
Expand All @@ -6970,8 +7010,7 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
declaration = self._parse_type_with_init(named=True, outer='member')
elif objectType == 'function':
declaration = self._parse_type(named=True, outer='function')
if templatePrefix is not None:
trailingRequiresClause = self._parse_requires_clause()
trailingRequiresClause = self._parse_requires_clause()
elif objectType == 'class':
declaration = self._parse_class()
elif objectType == 'union':
Expand All @@ -6989,7 +7028,7 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
self.skip_ws()
semicolon = self.skip_string(';')
return ASTDeclaration(objectType, directiveType, visibility,
templatePrefix, requiresClause, declaration,
templatePrefix, declaration,
trailingRequiresClause, semicolon)

def parse_namespace_object(self) -> ASTNamespace:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_domain_cpp.py
Expand Up @@ -896,8 +896,33 @@ def test_domain_cpp_ast_requires_clauses():
{4: 'I0EIQaa1A1BE1fvv'})
check('function', 'template<typename T> requires A || B or C void f()',
{4: 'I0EIQoo1Aoo1B1CE1fvv'})
check('function', 'void f() requires A || B || C',
{4: 'IQoo1Aoo1B1CE1fv'})
check('function', 'Foo() requires A || B || C',
{4: 'IQoo1Aoo1B1CE3Foov'})
check('function', 'template<typename T> requires A && B || C and D void f()',
{4: 'I0EIQooaa1A1Baa1C1DE1fvv'})
check('function',
'template<typename T> requires R<T> ' +
'template<typename U> requires S<T> ' +
'void A<T>::f() requires B',
{4: 'I0EIQ1RI1TEEI0EIQaa1SI1TE1BEN1AI1TE1fEvv'})
check('function',
'template<template<typename T> requires R<T> typename X> ' +
'void f()',
{2: 'II0EIQ1RI1TEE0E1fv', 4: 'II0EIQ1RI1TEE0E1fvv'})
check('type',
'template<typename T> requires IsValid<T> {key}T = true_type',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='using')
check('class',
'template<typename T> requires IsValid<T> {key}T : Base',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='class')
check('union',
'template<typename T> requires IsValid<T> {key}T',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='union')
check('member',
'template<typename T> requires IsValid<T> int Val = 7',
{4: 'I0EIQ7IsValidI1TEE3Val'})


def test_domain_cpp_ast_template_args():
Expand Down