From d42432bff29c5fe7c42926a24bd1272dea39af90 Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Mon, 4 Mar 2024 18:37:43 -0600 Subject: [PATCH] Add initial typings (#1127) * Added typings to miscellaneous files * Added unit test to check codebase with mypy * Updated release workflow and build to account for annotations * Updated manifest to include stub files --- .flake8 | 4 +- .github/workflows/release.yml | 2 +- .gitignore | 2 + MANIFEST.in | 2 +- asyncpg/__init__.py | 7 +- asyncpg/_asyncio_compat.py | 13 +- asyncpg/_version.py | 6 +- asyncpg/compat.py | 24 ++- asyncpg/introspection.py | 22 ++- asyncpg/protocol/__init__.py | 2 + asyncpg/protocol/protocol.pyi | 300 ++++++++++++++++++++++++++++++++++ asyncpg/serverversion.py | 24 ++- asyncpg/types.py | 102 ++++++++---- pyproject.toml | 27 ++- setup.py | 2 +- tests/test__sourcecode.py | 33 +++- 16 files changed, 512 insertions(+), 60 deletions(-) create mode 100644 asyncpg/protocol/protocol.pyi diff --git a/.flake8 b/.flake8 index decf40da..d4e76b7a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,5 @@ [flake8] +select = C90,E,F,W,Y0 ignore = E402,E731,W503,W504,E252 -exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox +exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox +per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index eef0799e..450f471e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,7 +22,7 @@ jobs: github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} version_file: asyncpg/_version.py version_line_pattern: | - __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) + __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Stop if not approved if: steps.checkver.outputs.approved != 'true' diff --git a/.gitignore b/.gitignore index 21286094..53c0daa1 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ docs/_build /.eggs /.vscode /.mypy_cache +/.venv* +/.tox diff --git a/MANIFEST.in b/MANIFEST.in index 2389f6fa..3eac0565 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ recursive-include docs *.py *.rst Makefile *.css recursive-include examples *.py recursive-include tests *.py *.pem -recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h +recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h include LICENSE README.rst Makefile performance.png .flake8 diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index e8cd11eb..e8811a9d 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -4,6 +4,7 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations from .connection import connect, Connection # NOQA from .exceptions import * # NOQA @@ -14,6 +15,10 @@ from ._version import __version__ # NOQA +from . import exceptions -__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection') + +__all__: tuple[str, ...] = ( + 'connect', 'create_pool', 'Pool', 'Record', 'Connection' +) __all__ += exceptions.__all__ # NOQA diff --git a/asyncpg/_asyncio_compat.py b/asyncpg/_asyncio_compat.py index ad7dfd8c..a211d0a9 100644 --- a/asyncpg/_asyncio_compat.py +++ b/asyncpg/_asyncio_compat.py @@ -4,18 +4,25 @@ # # SPDX-License-Identifier: PSF-2.0 +from __future__ import annotations import asyncio import functools import sys +import typing + +if typing.TYPE_CHECKING: + from . import compat if sys.version_info < (3, 11): from async_timeout import timeout as timeout_ctx else: from asyncio import timeout as timeout_ctx +_T = typing.TypeVar('_T') + -async def wait_for(fut, timeout): +async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T: """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -65,7 +72,7 @@ async def wait_for(fut, timeout): return await fut -async def _cancel_and_wait(fut): +async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None: """Cancel the *fut* future or task and wait until it completes.""" loop = asyncio.get_running_loop() @@ -82,6 +89,6 @@ async def _cancel_and_wait(fut): fut.remove_done_callback(cb) -def _release_waiter(waiter, *args): +def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None: if not waiter.done(): waiter.set_result(None) diff --git a/asyncpg/_version.py b/asyncpg/_version.py index 67fd67ab..383fe4d2 100644 --- a/asyncpg/_version.py +++ b/asyncpg/_version.py @@ -10,4 +10,8 @@ # supported platforms, publish the packages on PyPI, merge the PR # to the target branch, create a Git tag pointing to the commit. -__version__ = '0.30.0.dev0' +from __future__ import annotations + +import typing + +__version__: typing.Final = '0.30.0.dev0' diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 3eec9eb7..435b4c48 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -4,22 +4,25 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import pathlib import platform import typing import sys +if typing.TYPE_CHECKING: + import asyncio -SYSTEM = platform.uname().system +SYSTEM: typing.Final = platform.uname().system -if SYSTEM == 'Windows': +if sys.platform == 'win32': import ctypes.wintypes - CSIDL_APPDATA = 0x001a + CSIDL_APPDATA: typing.Final = 0x001a - def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + def get_pg_home_directory() -> pathlib.Path | None: # We cannot simply use expanduser() as that returns the user's # home directory, whereas Postgres stores its config in # %AppData% on Windows. @@ -31,14 +34,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]: return pathlib.Path(buf.value) / 'postgresql' else: - def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + def get_pg_home_directory() -> pathlib.Path | None: try: return pathlib.Path.home() except (RuntimeError, KeyError): return None -async def wait_closed(stream): +async def wait_closed(stream: asyncio.StreamWriter) -> None: # Not all asyncio versions have StreamWriter.wait_closed(). if hasattr(stream, 'wait_closed'): try: @@ -59,3 +62,12 @@ async def wait_closed(stream): from ._asyncio_compat import timeout_ctx as timeout # noqa: F401 else: from asyncio import timeout as timeout # noqa: F401 + +if sys.version_info < (3, 9): + from typing import ( # noqa: F401 + Awaitable as Awaitable, + ) +else: + from collections.abc import ( # noqa: F401 + Awaitable as Awaitable, + ) diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index 6c2caf03..641cf700 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -4,8 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations -_TYPEINFO_13 = '''\ +import typing + +if typing.TYPE_CHECKING: + from . import protocol + +_TYPEINFO_13: typing.Final = '''\ ( SELECT t.oid AS oid, @@ -124,7 +130,7 @@ '''.format(typeinfo=_TYPEINFO_13) -_TYPEINFO = '''\ +_TYPEINFO: typing.Final = '''\ ( SELECT t.oid AS oid, @@ -248,7 +254,7 @@ '''.format(typeinfo=_TYPEINFO) -TYPE_BY_NAME = '''\ +TYPE_BY_NAME: typing.Final = '''\ SELECT t.oid, t.typelem AS elemtype, @@ -277,16 +283,16 @@ SCALAR_TYPE_KINDS = (b'b', b'd', b'e') -def is_scalar_type(typeinfo) -> bool: +def is_scalar_type(typeinfo: protocol.Record) -> bool: return ( typeinfo['kind'] in SCALAR_TYPE_KINDS and not typeinfo['elemtype'] ) -def is_domain_type(typeinfo) -> bool: - return typeinfo['kind'] == b'd' +def is_domain_type(typeinfo: protocol.Record) -> bool: + return typeinfo['kind'] == b'd' # type: ignore[no-any-return] -def is_composite_type(typeinfo) -> bool: - return typeinfo['kind'] == b'c' +def is_composite_type(typeinfo: protocol.Record) -> bool: + return typeinfo['kind'] == b'c' # type: ignore[no-any-return] diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py index 8b3e06a0..af9287bd 100644 --- a/asyncpg/protocol/__init__.py +++ b/asyncpg/protocol/__init__.py @@ -6,4 +6,6 @@ # flake8: NOQA +from __future__ import annotations + from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/asyncpg/protocol/protocol.pyi b/asyncpg/protocol/protocol.pyi new file mode 100644 index 00000000..f85c5b6d --- /dev/null +++ b/asyncpg/protocol/protocol.pyi @@ -0,0 +1,300 @@ +import asyncio +import asyncio.protocols +import hmac +from codecs import CodecInfo +from collections.abc import Callable, Iterable, Iterator, Sequence +from hashlib import md5, sha256 +from typing import ( + Any, + ClassVar, + Final, + Generic, + Literal, + NewType, + TypeVar, + final, + overload, +) +from typing_extensions import TypeAlias + +import asyncpg.pgproto.pgproto + +from ..connect_utils import _ConnectionParameters +from ..pgproto.pgproto import WriteBuffer +from ..types import Attribute, Type + +_T = TypeVar('_T') +_Record = TypeVar('_Record', bound=Record) +_OtherRecord = TypeVar('_OtherRecord', bound=Record) +_PreparedStatementState = TypeVar( + '_PreparedStatementState', bound=PreparedStatementState[Any] +) + +_NoTimeoutType = NewType('_NoTimeoutType', object) +_TimeoutType: TypeAlias = float | None | _NoTimeoutType + +BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]] +BUILTIN_TYPE_OID_MAP: Final[dict[int, str]] +NO_TIMEOUT: Final[_NoTimeoutType] + +hashlib_md5 = md5 + +@final +class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext): + __pyx_vtable__: Any + def __init__(self, conn_key: object) -> None: ... + def add_python_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typeinfos: Iterable[object], + typekind: str, + encoder: Callable[[Any], Any], + decoder: Callable[[Any], Any], + format: object, + ) -> Any: ... + def clear_type_cache(self) -> None: ... + def get_data_codec( + self, oid: int, format: object = ..., ignore_custom_codec: bool = ... + ) -> Any: ... + def get_text_codec(self) -> CodecInfo: ... + def register_data_types(self, types: Iterable[object]) -> None: ... + def remove_python_codec( + self, typeoid: int, typename: str, typeschema: str + ) -> None: ... + def set_builtin_type_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + alias_to: str, + format: object = ..., + ) -> Any: ... + def __getattr__(self, name: str) -> Any: ... + def __reduce__(self) -> Any: ... + +@final +class PreparedStatementState(Generic[_Record]): + closed: bool + prepared: bool + name: str + query: str + refs: int + record_class: type[_Record] + ignore_custom_codec: bool + __pyx_vtable__: Any + def __init__( + self, + name: str, + query: str, + protocol: BaseProtocol[Any], + record_class: type[_Record], + ignore_custom_codec: bool, + ) -> None: ... + def _get_parameters(self) -> tuple[Type, ...]: ... + def _get_attributes(self) -> tuple[Attribute, ...]: ... + def _init_types(self) -> set[int]: ... + def _init_codecs(self) -> None: ... + def attach(self) -> None: ... + def detach(self) -> None: ... + def mark_closed(self) -> None: ... + def mark_unprepared(self) -> None: ... + def __reduce__(self) -> Any: ... + +class CoreProtocol: + backend_pid: Any + backend_secret: Any + __pyx_vtable__: Any + def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ... + def is_in_transaction(self) -> bool: ... + def __reduce__(self) -> Any: ... + +class BaseProtocol(CoreProtocol, Generic[_Record]): + queries_count: Any + is_ssl: bool + __pyx_vtable__: Any + def __init__( + self, + addr: object, + connected_fut: object, + con_params: _ConnectionParameters, + record_class: type[_Record], + loop: object, + ) -> None: ... + def set_connection(self, connection: object) -> None: ... + def get_server_pid(self, *args: object, **kwargs: object) -> int: ... + def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ... + def get_record_class(self) -> type[_Record]: ... + def abort(self) -> None: ... + async def bind( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + timeout: _TimeoutType, + ) -> Any: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: Literal[False], + timeout: _TimeoutType, + ) -> list[_OtherRecord]: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: Literal[True], + timeout: _TimeoutType, + ) -> tuple[list[_OtherRecord], bytes, bool]: ... + @overload + async def bind_execute( + self, + state: PreparedStatementState[_OtherRecord], + args: Sequence[object], + portal_name: str, + limit: int, + return_extra: bool, + timeout: _TimeoutType, + ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ... + async def bind_execute_many( + self, + state: PreparedStatementState[_OtherRecord], + args: Iterable[Sequence[object]], + portal_name: str, + timeout: _TimeoutType, + ) -> None: ... + async def close(self, timeout: _TimeoutType) -> None: ... + def _get_timeout(self, timeout: _TimeoutType) -> float | None: ... + def _is_cancelling(self) -> bool: ... + async def _wait_for_cancellation(self) -> None: ... + async def close_statement( + self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType + ) -> Any: ... + async def copy_in(self, *args: object, **kwargs: object) -> str: ... + async def copy_out(self, *args: object, **kwargs: object) -> str: ... + async def execute(self, *args: object, **kwargs: object) -> Any: ... + def is_closed(self, *args: object, **kwargs: object) -> Any: ... + def is_connected(self, *args: object, **kwargs: object) -> Any: ... + def data_received(self, data: object) -> None: ... + def connection_made(self, transport: object) -> None: ... + def connection_lost(self, exc: Exception | None) -> None: ... + def pause_writing(self, *args: object, **kwargs: object) -> Any: ... + @overload + async def prepare( + self, + stmt_name: str, + query: str, + timeout: float | None = ..., + *, + state: _PreparedStatementState, + ignore_custom_codec: bool = ..., + record_class: None, + ) -> _PreparedStatementState: ... + @overload + async def prepare( + self, + stmt_name: str, + query: str, + timeout: float | None = ..., + *, + state: None = ..., + ignore_custom_codec: bool = ..., + record_class: type[_OtherRecord], + ) -> PreparedStatementState[_OtherRecord]: ... + async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ... + async def query(self, *args: object, **kwargs: object) -> str: ... + def resume_writing(self, *args: object, **kwargs: object) -> Any: ... + def __reduce__(self) -> Any: ... + +@final +class Codec: + __pyx_vtable__: Any + def __reduce__(self) -> Any: ... + +class DataCodecConfig: + __pyx_vtable__: Any + def __init__(self, cache_key: object) -> None: ... + def add_python_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + typeinfos: Iterable[object], + encoder: Callable[[ConnectionSettings, WriteBuffer, object], object], + decoder: Callable[..., object], + format: object, + xformat: object, + ) -> Any: ... + def add_types(self, types: Iterable[object]) -> Any: ... + def clear_type_cache(self) -> None: ... + def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ... + def remove_python_codec( + self, typeoid: int, typename: str, typeschema: str + ) -> Any: ... + def set_builtin_type_codec( + self, + typeoid: int, + typename: str, + typeschema: str, + typekind: str, + alias_to: str, + format: object = ..., + ) -> Any: ... + def __reduce__(self) -> Any: ... + +class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ... + +class Record: + @overload + def get(self, key: str) -> Any | None: ... + @overload + def get(self, key: str, default: _T) -> Any | _T: ... + def items(self) -> Iterator[tuple[str, Any]]: ... + def keys(self) -> Iterator[str]: ... + def values(self) -> Iterator[Any]: ... + @overload + def __getitem__(self, index: str) -> Any: ... + @overload + def __getitem__(self, index: int) -> Any: ... + @overload + def __getitem__(self, index: slice) -> tuple[Any, ...]: ... + def __iter__(self) -> Iterator[Any]: ... + def __contains__(self, x: object) -> bool: ... + def __len__(self) -> int: ... + +class Timer: + def __init__(self, budget: float | None) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, et: object, e: object, tb: object) -> None: ... + def get_remaining_budget(self) -> float: ... + def has_budget_greater_than(self, amount: float) -> bool: ... + +@final +class SCRAMAuthentication: + AUTHENTICATION_METHODS: ClassVar[list[str]] + DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int] + DIGEST = sha256 + REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]] + REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]] + SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]] + authentication_method: bytes + authorization_message: bytes | None + client_channel_binding: bytes + client_first_message_bare: bytes | None + client_nonce: bytes | None + client_proof: bytes | None + password_salt: bytes | None + password_iterations: int + server_first_message: bytes | None + server_key: hmac.HMAC | None + server_nonce: bytes | None diff --git a/asyncpg/serverversion.py b/asyncpg/serverversion.py index 31568a2e..ee9647b4 100644 --- a/asyncpg/serverversion.py +++ b/asyncpg/serverversion.py @@ -4,12 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import re +import typing from .types import ServerVersion -version_regex = re.compile( +version_regex: typing.Final = re.compile( r"(Postgre[^\s]*)?\s*" r"(?P[0-9]+)\.?" r"((?P[0-9]+)\.?)?" @@ -19,7 +21,15 @@ ) -def split_server_version_string(version_string): +class _VersionDict(typing.TypedDict): + major: int + minor: int | None + micro: int | None + releaselevel: str | None + serial: int | None + + +def split_server_version_string(version_string: str) -> ServerVersion: version_match = version_regex.search(version_string) if version_match is None: @@ -28,17 +38,17 @@ def split_server_version_string(version_string): f'version from "{version_string}"' ) - version = version_match.groupdict() + version: _VersionDict = version_match.groupdict() # type: ignore[assignment] # noqa: E501 for ver_key, ver_value in version.items(): # Cast all possible versions parts to int try: - version[ver_key] = int(ver_value) + version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501 except (TypeError, ValueError): pass - if version.get("major") < 10: + if version["major"] < 10: return ServerVersion( - version.get("major"), + version["major"], version.get("minor") or 0, version.get("micro") or 0, version.get("releaselevel") or "final", @@ -52,7 +62,7 @@ def split_server_version_string(version_string): # want to keep that behaviour consistent, i.e not fail # a major version check due to a bugfix release. return ServerVersion( - version.get("major"), + version["major"], 0, version.get("minor") or 0, version.get("releaselevel") or "final", diff --git a/asyncpg/types.py b/asyncpg/types.py index bd5813fc..7a24e24c 100644 --- a/asyncpg/types.py +++ b/asyncpg/types.py @@ -4,14 +4,18 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations -import collections +import typing from asyncpg.pgproto.types import ( BitString, Point, Path, Polygon, Box, Line, LineSegment, Circle, ) +if typing.TYPE_CHECKING: + from typing_extensions import Self + __all__ = ( 'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon', @@ -19,7 +23,13 @@ ) -Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema']) +class Type(typing.NamedTuple): + oid: int + name: str + kind: str + schema: str + + Type.__doc__ = 'Database data type.' Type.oid.__doc__ = 'OID of the type.' Type.name.__doc__ = 'Type name. For example "int2".' @@ -28,25 +38,61 @@ Type.schema.__doc__ = 'Name of the database schema that defines the type.' -Attribute = collections.namedtuple('Attribute', ['name', 'type']) +class Attribute(typing.NamedTuple): + name: str + type: Type + + Attribute.__doc__ = 'Database relation attribute.' Attribute.name.__doc__ = 'Attribute name.' Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.' -ServerVersion = collections.namedtuple( - 'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial']) +class ServerVersion(typing.NamedTuple): + major: int + minor: int + micro: int + releaselevel: str + serial: int + + ServerVersion.__doc__ = 'PostgreSQL server version tuple.' -class Range: - """Immutable representation of PostgreSQL `range` type.""" +class _RangeValue(typing.Protocol): + def __eq__(self, __value: object) -> bool: + ... + + def __lt__(self, __other: _RangeValue) -> bool: + ... + + def __gt__(self, __other: _RangeValue) -> bool: + ... + - __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty' +_RV = typing.TypeVar('_RV', bound=_RangeValue) + + +class Range(typing.Generic[_RV]): + """Immutable representation of PostgreSQL `range` type.""" - def __init__(self, lower=None, upper=None, *, - lower_inc=True, upper_inc=False, - empty=False): + __slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty') + + _lower: _RV | None + _upper: _RV | None + _lower_inc: bool + _upper_inc: bool + _empty: bool + + def __init__( + self, + lower: _RV | None = None, + upper: _RV | None = None, + *, + lower_inc: bool = True, + upper_inc: bool = False, + empty: bool = False + ) -> None: self._empty = empty if empty: self._lower = self._upper = None @@ -58,34 +104,34 @@ def __init__(self, lower=None, upper=None, *, self._upper_inc = upper is not None and upper_inc @property - def lower(self): + def lower(self) -> _RV | None: return self._lower @property - def lower_inc(self): + def lower_inc(self) -> bool: return self._lower_inc @property - def lower_inf(self): + def lower_inf(self) -> bool: return self._lower is None and not self._empty @property - def upper(self): + def upper(self) -> _RV | None: return self._upper @property - def upper_inc(self): + def upper_inc(self) -> bool: return self._upper_inc @property - def upper_inf(self): + def upper_inf(self) -> bool: return self._upper is None and not self._empty @property - def isempty(self): + def isempty(self) -> bool: return self._empty - def _issubset_lower(self, other): + def _issubset_lower(self, other: Self) -> bool: if other._lower is None: return True if self._lower is None: @@ -96,7 +142,7 @@ def _issubset_lower(self, other): and (other._lower_inc or not self._lower_inc) ) - def _issubset_upper(self, other): + def _issubset_upper(self, other: Self) -> bool: if other._upper is None: return True if self._upper is None: @@ -107,7 +153,7 @@ def _issubset_upper(self, other): and (other._upper_inc or not self._upper_inc) ) - def issubset(self, other): + def issubset(self, other: Self) -> bool: if self._empty: return True if other._empty: @@ -115,13 +161,13 @@ def issubset(self, other): return self._issubset_lower(other) and self._issubset_upper(other) - def issuperset(self, other): + def issuperset(self, other: Self) -> bool: return other.issubset(self) - def __bool__(self): + def __bool__(self) -> bool: return not self._empty - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Range): return NotImplemented @@ -132,14 +178,14 @@ def __eq__(self, other): self._upper_inc, self._empty ) == ( - other._lower, - other._upper, + other._lower, # pyright: ignore [reportUnknownMemberType] + other._upper, # pyright: ignore [reportUnknownMemberType] other._lower_inc, other._upper_inc, other._empty ) - def __hash__(self): + def __hash__(self) -> int: return hash(( self._lower, self._upper, @@ -148,7 +194,7 @@ def __hash__(self): self._empty )) - def __repr__(self): + def __repr__(self) -> str: if self._empty: desc = 'empty' else: diff --git a/pyproject.toml b/pyproject.toml index 8209d838..0019dadc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Topic :: Database :: Front-Ends", ] dependencies = [ - 'async_timeout>=4.0.3; python_version < "3.12.0"' + 'async_timeout>=4.0.3; python_version < "3.12.0"', ] [project.urls] @@ -40,9 +40,11 @@ gssapi = [ ] test = [ 'flake8~=6.1', + 'flake8-pyi~=24.1.0', 'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"', 'gssapi; platform_system == "Linux"', 'k5test; platform_system == "Linux"', + 'mypy~=1.8.0', ] docs = [ 'Sphinx~=5.3.0', @@ -107,3 +109,26 @@ exclude_lines = [ "if __name__ == .__main__.", ] show_missing = true + +[tool.mypy] +incremental = true +strict = true +implicit_reexport = true + +[[tool.mypy.overrides]] +module = [ + "asyncpg._testbase", + "asyncpg._testbase.*", + "asyncpg.cluster", + "asyncpg.connect_utils", + "asyncpg.connection", + "asyncpg.connresource", + "asyncpg.cursor", + "asyncpg.exceptions", + "asyncpg.exceptions.*", + "asyncpg.pool", + "asyncpg.prepared_stmt", + "asyncpg.transaction", + "asyncpg.utils", +] +ignore_errors = true diff --git a/setup.py b/setup.py index c4d42d82..f7c3c471 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ with open(str(_ROOT / 'asyncpg' / '_version.py')) as f: for line in f: - if line.startswith('__version__ ='): + if line.startswith('__version__: typing.Final ='): _, _, version = line.partition('=') VERSION = version.strip(" \n'\"") break diff --git a/tests/test__sourcecode.py b/tests/test__sourcecode.py index 28ffdea7..b19044d4 100644 --- a/tests/test__sourcecode.py +++ b/tests/test__sourcecode.py @@ -14,7 +14,7 @@ def find_root(): return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -class TestFlake8(unittest.TestCase): +class TestCodeQuality(unittest.TestCase): def test_flake8(self): try: @@ -38,3 +38,34 @@ def test_flake8(self): output = ex.output.decode() raise AssertionError( 'flake8 validation failed:\n{}'.format(output)) from None + + def test_mypy(self): + try: + import mypy # NoQA + except ImportError: + raise unittest.SkipTest('mypy module is missing') + + root_path = find_root() + config_path = os.path.join(root_path, 'pyproject.toml') + if not os.path.exists(config_path): + raise RuntimeError('could not locate mypy.ini file') + + try: + subprocess.run( + [ + sys.executable, + '-m', + 'mypy', + '--config-file', + config_path, + 'asyncpg' + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=root_path + ) + except subprocess.CalledProcessError as ex: + output = ex.output.decode() + raise AssertionError( + 'mypy validation failed:\n{}'.format(output)) from None