Skip to content

Commit

Permalink
Add initial typings (#1127)
Browse files Browse the repository at this point in the history
* Added typings to miscellaneous files
* Added unit test to check codebase with mypy
* Updated release workflow and build to account for annotations
* Updated manifest to include stub files
  • Loading branch information
bryanforbes committed Mar 5, 2024
1 parent 1d4e568 commit d42432b
Show file tree
Hide file tree
Showing 16 changed files with 512 additions and 60 deletions.
4 changes: 3 additions & 1 deletion .flake8
@@ -1,3 +1,5 @@
[flake8]
select = C90,E,F,W,Y0
ignore = E402,E731,W503,W504,E252
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox
per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Expand Up @@ -22,7 +22,7 @@ jobs:
github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
version_file: asyncpg/_version.py
version_line_pattern: |
__version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
__version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
- name: Stop if not approved
if: steps.checkver.outputs.approved != 'true'
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -34,3 +34,5 @@ docs/_build
/.eggs
/.vscode
/.mypy_cache
/.venv*
/.tox
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,5 +1,5 @@
recursive-include docs *.py *.rst Makefile *.css
recursive-include examples *.py
recursive-include tests *.py *.pem
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h
include LICENSE README.rst Makefile performance.png .flake8
7 changes: 6 additions & 1 deletion asyncpg/__init__.py
Expand Up @@ -4,6 +4,7 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
Expand All @@ -14,6 +15,10 @@

from ._version import __version__ # NOQA

from . import exceptions

__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')

__all__: tuple[str, ...] = (
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
__all__ += exceptions.__all__ # NOQA
13 changes: 10 additions & 3 deletions asyncpg/_asyncio_compat.py
Expand Up @@ -4,18 +4,25 @@
#
# SPDX-License-Identifier: PSF-2.0

from __future__ import annotations

import asyncio
import functools
import sys
import typing

if typing.TYPE_CHECKING:
from . import compat

if sys.version_info < (3, 11):
from async_timeout import timeout as timeout_ctx
else:
from asyncio import timeout as timeout_ctx

_T = typing.TypeVar('_T')


async def wait_for(fut, timeout):
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
"""Wait for the single Future or coroutine to complete, with timeout.
Coroutine will be wrapped in Task.
Expand Down Expand Up @@ -65,7 +72,7 @@ async def wait_for(fut, timeout):
return await fut


async def _cancel_and_wait(fut):
async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
"""Cancel the *fut* future or task and wait until it completes."""

loop = asyncio.get_running_loop()
Expand All @@ -82,6 +89,6 @@ async def _cancel_and_wait(fut):
fut.remove_done_callback(cb)


def _release_waiter(waiter, *args):
def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
if not waiter.done():
waiter.set_result(None)
6 changes: 5 additions & 1 deletion asyncpg/_version.py
Expand Up @@ -10,4 +10,8 @@
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.

__version__ = '0.30.0.dev0'
from __future__ import annotations

import typing

__version__: typing.Final = '0.30.0.dev0'
24 changes: 18 additions & 6 deletions asyncpg/compat.py
Expand Up @@ -4,22 +4,25 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import pathlib
import platform
import typing
import sys

if typing.TYPE_CHECKING:
import asyncio

SYSTEM = platform.uname().system
SYSTEM: typing.Final = platform.uname().system


if SYSTEM == 'Windows':
if sys.platform == 'win32':
import ctypes.wintypes

CSIDL_APPDATA = 0x001a
CSIDL_APPDATA: typing.Final = 0x001a

def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
def get_pg_home_directory() -> pathlib.Path | None:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
Expand All @@ -31,14 +34,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
return pathlib.Path(buf.value) / 'postgresql'

else:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
def get_pg_home_directory() -> pathlib.Path | None:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None


async def wait_closed(stream):
async def wait_closed(stream: asyncio.StreamWriter) -> None:
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
Expand All @@ -59,3 +62,12 @@ async def wait_closed(stream):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401

if sys.version_info < (3, 9):
from typing import ( # noqa: F401
Awaitable as Awaitable,
)
else:
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)
22 changes: 14 additions & 8 deletions asyncpg/introspection.py
Expand Up @@ -4,8 +4,14 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

_TYPEINFO_13 = '''\
import typing

if typing.TYPE_CHECKING:
from . import protocol

_TYPEINFO_13: typing.Final = '''\
(
SELECT
t.oid AS oid,
Expand Down Expand Up @@ -124,7 +130,7 @@
'''.format(typeinfo=_TYPEINFO_13)


_TYPEINFO = '''\
_TYPEINFO: typing.Final = '''\
(
SELECT
t.oid AS oid,
Expand Down Expand Up @@ -248,7 +254,7 @@
'''.format(typeinfo=_TYPEINFO)


TYPE_BY_NAME = '''\
TYPE_BY_NAME: typing.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
Expand Down Expand Up @@ -277,16 +283,16 @@
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo) -> bool:
def is_scalar_type(typeinfo: protocol.Record) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)


def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
def is_domain_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'd' # type: ignore[no-any-return]


def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'
def is_composite_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'c' # type: ignore[no-any-return]
2 changes: 2 additions & 0 deletions asyncpg/protocol/__init__.py
Expand Up @@ -6,4 +6,6 @@

# flake8: NOQA

from __future__ import annotations

from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

0 comments on commit d42432b

Please sign in to comment.