diff --git a/src/attr/__init__.pyi b/src/attr/__init__.pyi index 126db1ac0..600dd613a 100644 --- a/src/attr/__init__.pyi +++ b/src/attr/__init__.pyi @@ -8,6 +8,7 @@ from typing import ( List, Mapping, Optional, + Protocol, Sequence, Tuple, Type, @@ -26,8 +27,6 @@ from ._cmp import cmp_using as cmp_using from ._version_info import VersionInfo from ._typing_compat import AttrsInstance_ -AttrsInstance = AttrsInstance_ - if sys.version_info >= (3, 10): from typing import TypeGuard else: @@ -65,6 +64,10 @@ _FieldTransformer = Callable[ # _ValidatorType from working when passed in a list or tuple. _ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] +# We subclass this here to keep the protocol's qualified name clean. +class AttrsInstance(AttrsInstance_, Protocol): + pass + # _make -- NOTHING: object diff --git a/src/attr/_typing_compat.pyi b/src/attr/_typing_compat.pyi index e66ce0fd7..fb6c8ac48 100644 --- a/src/attr/_typing_compat.pyi +++ b/src/attr/_typing_compat.pyi @@ -2,11 +2,11 @@ from typing import Any, ClassVar, Protocol MYPY = False -# A protocol to be able to statically accept an attrs class. -class AttrsInstance(Protocol): - __attrs_attrs__: ClassVar[Any] - if MYPY: - AttrsInstance_ = AttrsInstance + # A protocol to be able to statically accept an attrs class. + class AttrsInstance_(Protocol): + __attrs_attrs__: ClassVar[Any] + else: - AttrsInstance_ = Any + class AttrsInstance_(Protocol): + pass diff --git a/tests/test_pyright.py b/tests/test_pyright.py index 7dc8ee13f..9b7945117 100644 --- a/tests/test_pyright.py +++ b/tests/test_pyright.py @@ -91,6 +91,7 @@ def test_pyright_attrsinstance_is_any(tmp_path): """\ import attrs +foo: attrs.AttrsInstance = object() # We can assign any old object to `AttrsInstance`. reveal_type(attrs.AttrsInstance) """ ) @@ -99,7 +100,7 @@ def test_pyright_attrsinstance_is_any(tmp_path): expected_diagnostics = { PyrightDiagnostic( severity="information", - message='Type of "attrs.AttrsInstance" is "Any"', + message='Type of "attrs.AttrsInstance" is "Type[AttrsInstance]"', ), } assert diagnostics == expected_diagnostics