diff --git a/changelog.d/3391.change.rst b/changelog.d/3391.change.rst new file mode 100644 index 0000000000..41cfea3355 --- /dev/null +++ b/changelog.d/3391.change.rst @@ -0,0 +1 @@ +Updated ``attr:`` to also extract simple constants with type annotations -- by :user:`karlotness` diff --git a/setuptools/config/expand.py b/setuptools/config/expand.py index be987df5b4..ed7564047a 100644 --- a/setuptools/config/expand.py +++ b/setuptools/config/expand.py @@ -66,25 +66,21 @@ def __init__(self, name: str, spec: ModuleSpec): vars(self).update(locals()) del self.self + def _find_assignments(self) -> Iterator[Tuple[ast.AST, ast.AST]]: + for statement in self.module.body: + if isinstance(statement, ast.Assign): + yield from ((target, statement.value) for target in statement.targets) + elif isinstance(statement, ast.AnnAssign) and statement.value: + yield (statement.target, statement.value) + def __getattr__(self, attr): """Attempt to load an attribute "statically", via :func:`ast.literal_eval`.""" try: - assignment_expressions = ( - statement - for statement in self.module.body - if isinstance(statement, ast.Assign) - ) - expressions_with_target = ( - (statement, target) - for statement in assignment_expressions - for target in statement.targets - ) - matching_values = ( - statement.value - for statement, target in expressions_with_target + return next( + ast.literal_eval(value) + for target, value in self._find_assignments() if isinstance(target, ast.Name) and target.id == attr ) - return next(ast.literal_eval(value) for value in matching_values) except Exception as e: raise AttributeError(f"{self.name} has no attribute {attr}") from e diff --git a/setuptools/tests/config/test_expand.py b/setuptools/tests/config/test_expand.py index 15053c8f24..523779a8ed 100644 --- a/setuptools/tests/config/test_expand.py +++ b/setuptools/tests/config/test_expand.py @@ -85,6 +85,22 @@ def test_read_attr(self, tmp_path, monkeypatch): values = expand.read_attr('lib.mod.VALUES', {'lib': 'pkg/sub'}, tmp_path) assert values['c'] == (0, 1, 1) + @pytest.mark.parametrize( + "example", + [ + "VERSION: str\nVERSION = '0.1.1'\nraise SystemExit(1)\n", + "VERSION: str = '0.1.1'\nraise SystemExit(1)\n", + ] + ) + def test_read_annotated_attr(self, tmp_path, example): + files = { + "pkg/__init__.py": "", + "pkg/sub/__init__.py": example, + } + write_files(files, tmp_path) + # Make sure this attribute can be read statically + assert expand.read_attr('pkg.sub.VERSION', root_dir=tmp_path) == '0.1.1' + def test_import_order(self, tmp_path): """ Sometimes the import machinery will import the parent package of a nested