Skip to content

Commit

Permalink
Merge pull request #10703 from jakobandersen/cpp_requires_clause
Browse files Browse the repository at this point in the history
C++, improve requires clause support (#10286 update)
  • Loading branch information
jakobandersen committed Jul 24, 2022
2 parents 3c469c4 + feb4ac8 commit e867201
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 41 deletions.
3 changes: 3 additions & 0 deletions CHANGES
Expand Up @@ -13,6 +13,9 @@ Deprecated
Features added
--------------

- #10286: C++, support requires clauses not just between the template
parameter lists and the declaration.

Bugs fixed
----------

Expand Down
123 changes: 82 additions & 41 deletions sphinx/domains/cpp.py
Expand Up @@ -3687,24 +3687,33 @@ def describe_signature(self, signode: TextElement, mode: str,


class ASTTemplateParams(ASTBase):
def __init__(self, params: List[ASTTemplateParam]) -> None:
def __init__(self, params: List[ASTTemplateParam],
requiresClause: Optional["ASTRequiresClause"]) -> None:
assert params is not None
self.params = params
self.requiresClause = requiresClause

def get_id(self, version: int) -> str:
def get_id(self, version: int, excludeRequires: bool = False) -> str:
assert version >= 2
res = []
res.append("I")
for param in self.params:
res.append(param.get_id(version))
res.append("E")
if not excludeRequires and self.requiresClause:
res.append('IQ')
res.append(self.requiresClause.expr.get_id(version))
res.append('E')
return ''.join(res)

def _stringify(self, transform: StringifyTransform) -> str:
res = []
res.append("template<")
res.append(", ".join(transform(a) for a in self.params))
res.append("> ")
if self.requiresClause is not None:
res.append(transform(self.requiresClause))
res.append(" ")
return ''.join(res)

def describe_signature(self, signode: TextElement, mode: str,
Expand All @@ -3719,6 +3728,9 @@ def describe_signature(self, signode: TextElement, mode: str,
first = False
param.describe_signature(signode, mode, env, symbol)
signode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause is not None:
signode += addnodes.desc_sig_space()
self.requiresClause.describe_signature(signode, mode, env, symbol)

def describe_signature_as_introducer(
self, parentNode: desc_signature, mode: str, env: "BuildEnvironment",
Expand All @@ -3743,6 +3755,11 @@ def makeLine(parentNode: desc_signature) -> addnodes.desc_signature_line:
if lineSpec and not first:
lineNode = makeLine(parentNode)
lineNode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
parentNode += reqNode
self.requiresClause.describe_signature(reqNode, 'markType', env, symbol)


# Template introducers
Expand Down Expand Up @@ -3861,12 +3878,24 @@ def __init__(self,
# templates is None means it's an explicit instantiation of a variable
self.templates = templates

def get_id(self, version: int) -> str:
def get_requires_clause_in_last(self) -> Optional["ASTRequiresClause"]:
if self.templates is None:
return None
lastList = self.templates[-1]
if not isinstance(lastList, ASTTemplateParams):
return None
return lastList.requiresClause # which may be None

def get_id_except_requires_clause_in_last(self, version: int) -> str:
assert version >= 2
# this is not part of a normal name mangling system
# This is not part of the Itanium ABI mangling system.
res = []
for t in self.templates:
res.append(t.get_id(version))
lastIndex = len(self.templates) - 1
for i, t in enumerate(self.templates):
if isinstance(t, ASTTemplateParams):
res.append(t.get_id(version, excludeRequires=(i == lastIndex)))
else:
res.append(t.get_id(version))
return ''.join(res)

def _stringify(self, transform: StringifyTransform) -> str:
Expand All @@ -3889,7 +3918,7 @@ def __init__(self, expr: ASTExpression) -> None:
def _stringify(self, transform: StringifyTransform) -> str:
return 'requires ' + transform(self.expr)

def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
def describe_signature(self, signode: nodes.TextElement, mode: str,
env: "BuildEnvironment", symbol: "Symbol") -> None:
signode += addnodes.desc_sig_keyword('requires', 'requires')
signode += addnodes.desc_sig_space()
Expand All @@ -3900,16 +3929,16 @@ def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
################################################################################

class ASTDeclaration(ASTBase):
def __init__(self, objectType: str, directiveType: str, visibility: str,
templatePrefix: ASTTemplateDeclarationPrefix,
requiresClause: ASTRequiresClause, declaration: Any,
trailingRequiresClause: ASTRequiresClause,
def __init__(self, objectType: str, directiveType: Optional[str] = None,
visibility: Optional[str] = None,
templatePrefix: Optional[ASTTemplateDeclarationPrefix] = None,
declaration: Any = None,
trailingRequiresClause: Optional[ASTRequiresClause] = None,
semicolon: bool = False) -> None:
self.objectType = objectType
self.directiveType = directiveType
self.visibility = visibility
self.templatePrefix = templatePrefix
self.requiresClause = requiresClause
self.declaration = declaration
self.trailingRequiresClause = trailingRequiresClause
self.semicolon = semicolon
Expand All @@ -3920,11 +3949,10 @@ def __init__(self, objectType: str, directiveType: str, visibility: str,

def clone(self) -> "ASTDeclaration":
templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None
requiresClasueClone = self.requiresClause.clone() if self.requiresClause else None
trailingRequiresClasueClone = self.trailingRequiresClause.clone() \
if self.trailingRequiresClause else None
return ASTDeclaration(self.objectType, self.directiveType, self.visibility,
templatePrefixClone, requiresClasueClone,
templatePrefixClone,
self.declaration.clone(), trailingRequiresClasueClone,
self.semicolon)

Expand All @@ -3940,7 +3968,7 @@ def function_params(self) -> List[ASTFunctionParameter]:

def get_id(self, version: int, prefixed: bool = True) -> str:
if version == 1:
if self.templatePrefix:
if self.templatePrefix or self.trailingRequiresClause:
raise NoOldIdError()
if self.objectType == 'enumerator' and self.enumeratorScopedSymbol:
return self.enumeratorScopedSymbol.declaration.get_id(version)
Expand All @@ -3952,16 +3980,31 @@ def get_id(self, version: int, prefixed: bool = True) -> str:
res = [_id_prefix[version]]
else:
res = []
if self.templatePrefix:
res.append(self.templatePrefix.get_id(version))
if self.requiresClause or self.trailingRequiresClause:
# (See also https://github.com/sphinx-doc/sphinx/pull/10286#issuecomment-1168102147)
# The first implementation of requires clauses only supported a single clause after the
# template prefix, and no trailing clause. It put the ID after the template parameter
# list, i.e.,
# "I" + template_parameter_list_id + "E" + "IQ" + requires_clause_id + "E"
# but the second implementation associates the requires clause with each list, i.e.,
# "I" + template_parameter_list_id + "IQ" + requires_clause_id + "E" + "E"
# To avoid making a new ID version, we make an exception for the last requires clause
# in the template prefix, and still put it in the end.
# As we now support trailing requires clauses we add that as if it was a conjunction.
if self.templatePrefix is not None:
res.append(self.templatePrefix.get_id_except_requires_clause_in_last(version))
requiresClauseInLast = self.templatePrefix.get_requires_clause_in_last()
else:
requiresClauseInLast = None

if requiresClauseInLast or self.trailingRequiresClause:
if version < 4:
raise NoOldIdError()
res.append('IQ')
if self.requiresClause and self.trailingRequiresClause:
if requiresClauseInLast and self.trailingRequiresClause:
# make a conjunction of them
res.append('aa')
if self.requiresClause:
res.append(self.requiresClause.expr.get_id(version))
if requiresClauseInLast:
res.append(requiresClauseInLast.expr.get_id(version))
if self.trailingRequiresClause:
res.append(self.trailingRequiresClause.expr.get_id(version))
res.append('E')
Expand All @@ -3978,9 +4021,6 @@ def _stringify(self, transform: StringifyTransform) -> str:
res.append(' ')
if self.templatePrefix:
res.append(transform(self.templatePrefix))
if self.requiresClause:
res.append(transform(self.requiresClause))
res.append(' ')
res.append(transform(self.declaration))
if self.trailingRequiresClause:
res.append(' ')
Expand All @@ -4005,11 +4045,6 @@ def describe_signature(self, signode: desc_signature, mode: str,
self.templatePrefix.describe_signature(signode, mode, env,
symbol=self.symbol,
lineSpec=options.get('tparam-line-spec'))
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
signode.append(reqNode)
self.requiresClause.describe_signature(reqNode, 'markType', env, self.symbol)
signode += mainDeclNode
if self.visibility and self.visibility != "public":
mainDeclNode += addnodes.desc_sig_keyword(self.visibility, self.visibility)
Expand Down Expand Up @@ -4192,7 +4227,7 @@ def _add_template_and_function_params(self) -> None:
continue
# only add a declaration if we our self are from a declaration
if self.declaration:
decl = ASTDeclaration('templateParam', None, None, None, None, tp, None)
decl = ASTDeclaration(objectType='templateParam', declaration=tp)
else:
decl = None
nne = ASTNestedNameElement(tp.get_identifier(), None)
Expand All @@ -4207,7 +4242,7 @@ def _add_template_and_function_params(self) -> None:
if nn is None:
continue
# (comparing to the template params: we have checked that we are a declaration)
decl = ASTDeclaration('functionParam', None, None, None, None, fp, None)
decl = ASTDeclaration(objectType='functionParam', declaration=fp)
assert not nn.rooted
assert len(nn.names) == 1
self._add_symbols(nn, [], decl, self.docname, self.line)
Expand Down Expand Up @@ -6504,7 +6539,14 @@ def _parse_type(self, named: Union[bool, str], outer: str = None) -> ASTType:
declSpecs = self._parse_decl_specs(outer=outer, typed=False)
decl = self._parse_declarator(named=True, paramMode=outer,
typed=False)
self.assert_end(allowSemicolon=True)
mustEnd = True
if outer == 'function':
# Allow trailing requires on functions.
self.skip_ws()
if re.compile(r'requires\b').match(self.definition, self.pos):
mustEnd = False
if mustEnd:
self.assert_end(allowSemicolon=True)
except DefinitionError as exUntyped:
if outer == 'type':
desc = "If just a name"
Expand Down Expand Up @@ -6761,7 +6803,8 @@ def _parse_template_parameter_list(self) -> ASTTemplateParams:
err = eParam
self.skip_ws()
if self.skip_string('>'):
return ASTTemplateParams(templateParams)
requiresClause = self._parse_requires_clause()
return ASTTemplateParams(templateParams, requiresClause)
elif self.skip_string(','):
continue
else:
Expand Down Expand Up @@ -6883,6 +6926,8 @@ def _parse_template_declaration_prefix(self, objectType: str
return ASTTemplateDeclarationPrefix(None)
else:
raise e
if objectType == 'concept' and params.requiresClause is not None:
self.fail('requires-clause not allowed for concept')
else:
params = self._parse_template_introduction()
if not params:
Expand Down Expand Up @@ -6931,7 +6976,7 @@ def _check_template_consistency(self, nestedName: ASTNestedName,

newTemplates: List[Union[ASTTemplateParams, ASTTemplateIntroduction]] = []
for _i in range(numExtra):
newTemplates.append(ASTTemplateParams([]))
newTemplates.append(ASTTemplateParams([], requiresClause=None))
if templatePrefix and not isMemberInstantiation:
newTemplates.extend(templatePrefix.templates)
templatePrefix = ASTTemplateDeclarationPrefix(newTemplates)
Expand All @@ -6947,18 +6992,15 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
raise Exception('Internal error, unknown directiveType "%s".' % directiveType)
visibility = None
templatePrefix = None
requiresClause = None
trailingRequiresClause = None
declaration: Any = None

self.skip_ws()
if self.match(_visibility_re):
visibility = self.matched_text

if objectType in ('type', 'concept', 'member', 'function', 'class'):
if objectType in ('type', 'concept', 'member', 'function', 'class', 'union'):
templatePrefix = self._parse_template_declaration_prefix(objectType)
if objectType == 'function' and templatePrefix is not None:
requiresClause = self._parse_requires_clause()

if objectType == 'type':
prevErrors = []
Expand All @@ -6984,8 +7026,7 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
declaration = self._parse_type_with_init(named=True, outer='member')
elif objectType == 'function':
declaration = self._parse_type(named=True, outer='function')
if templatePrefix is not None:
trailingRequiresClause = self._parse_requires_clause()
trailingRequiresClause = self._parse_requires_clause()
elif objectType == 'class':
declaration = self._parse_class()
elif objectType == 'union':
Expand All @@ -7003,7 +7044,7 @@ def parse_declaration(self, objectType: str, directiveType: str) -> ASTDeclarati
self.skip_ws()
semicolon = self.skip_string(';')
return ASTDeclaration(objectType, directiveType, visibility,
templatePrefix, requiresClause, declaration,
templatePrefix, declaration,
trailingRequiresClause, semicolon)

def parse_namespace_object(self) -> ASTNamespace:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_domain_cpp.py
Expand Up @@ -891,8 +891,33 @@ def test_domain_cpp_ast_requires_clauses():
{4: 'I0EIQaa1A1BE1fvv'})
check('function', 'template<typename T> requires A || B or C void f()',
{4: 'I0EIQoo1Aoo1B1CE1fvv'})
check('function', 'void f() requires A || B || C',
{4: 'IQoo1Aoo1B1CE1fv'})
check('function', 'Foo() requires A || B || C',
{4: 'IQoo1Aoo1B1CE3Foov'})
check('function', 'template<typename T> requires A && B || C and D void f()',
{4: 'I0EIQooaa1A1Baa1C1DE1fvv'})
check('function',
'template<typename T> requires R<T> ' +
'template<typename U> requires S<T> ' +
'void A<T>::f() requires B',
{4: 'I0EIQ1RI1TEEI0EIQaa1SI1TE1BEN1AI1TE1fEvv'})
check('function',
'template<template<typename T> requires R<T> typename X> ' +
'void f()',
{2: 'II0EIQ1RI1TEE0E1fv', 4: 'II0EIQ1RI1TEE0E1fvv'})
check('type',
'template<typename T> requires IsValid<T> {key}T = true_type',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='using')
check('class',
'template<typename T> requires IsValid<T> {key}T : Base',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='class')
check('union',
'template<typename T> requires IsValid<T> {key}T',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='union')
check('member',
'template<typename T> requires IsValid<T> int Val = 7',
{4: 'I0EIQ7IsValidI1TEE3Val'})


def test_domain_cpp_ast_template_args():
Expand Down

0 comments on commit e867201

Please sign in to comment.