Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial typings #1127

Merged
merged 1 commit into from Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion .flake8
@@ -1,3 +1,5 @@
[flake8]
select = C90,E,F,W,Y0
ignore = E402,E731,W503,W504,E252
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox
per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Expand Up @@ -22,7 +22,7 @@ jobs:
github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
version_file: asyncpg/_version.py
version_line_pattern: |
__version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
__version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])

- name: Stop if not approved
if: steps.checkver.outputs.approved != 'true'
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -34,3 +34,5 @@ docs/_build
/.eggs
/.vscode
/.mypy_cache
/.venv*
/.tox
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,5 +1,5 @@
recursive-include docs *.py *.rst Makefile *.css
recursive-include examples *.py
recursive-include tests *.py *.pem
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h
include LICENSE README.rst Makefile performance.png .flake8
7 changes: 6 additions & 1 deletion asyncpg/__init__.py
Expand Up @@ -4,6 +4,7 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
Expand All @@ -14,6 +15,10 @@

from ._version import __version__ # NOQA

from . import exceptions

__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')

__all__: tuple[str, ...] = (
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
__all__ += exceptions.__all__ # NOQA
13 changes: 10 additions & 3 deletions asyncpg/_asyncio_compat.py
Expand Up @@ -4,18 +4,25 @@
#
# SPDX-License-Identifier: PSF-2.0

from __future__ import annotations

import asyncio
import functools
import sys
import typing

if typing.TYPE_CHECKING:
from . import compat

if sys.version_info < (3, 11):
from async_timeout import timeout as timeout_ctx
else:
from asyncio import timeout as timeout_ctx

_T = typing.TypeVar('_T')


async def wait_for(fut, timeout):
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
"""Wait for the single Future or coroutine to complete, with timeout.

Coroutine will be wrapped in Task.
Expand Down Expand Up @@ -65,7 +72,7 @@ async def wait_for(fut, timeout):
return await fut


async def _cancel_and_wait(fut):
async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
"""Cancel the *fut* future or task and wait until it completes."""

loop = asyncio.get_running_loop()
Expand All @@ -82,6 +89,6 @@ async def _cancel_and_wait(fut):
fut.remove_done_callback(cb)


def _release_waiter(waiter, *args):
def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
if not waiter.done():
waiter.set_result(None)
6 changes: 5 additions & 1 deletion asyncpg/_version.py
Expand Up @@ -10,4 +10,8 @@
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.

__version__ = '0.30.0.dev0'
from __future__ import annotations

import typing

__version__: typing.Final = '0.30.0.dev0'
24 changes: 18 additions & 6 deletions asyncpg/compat.py
Expand Up @@ -4,22 +4,25 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import pathlib
import platform
import typing
import sys

if typing.TYPE_CHECKING:
import asyncio

SYSTEM = platform.uname().system
SYSTEM: typing.Final = platform.uname().system


if SYSTEM == 'Windows':
if sys.platform == 'win32':
import ctypes.wintypes

CSIDL_APPDATA = 0x001a
CSIDL_APPDATA: typing.Final = 0x001a

def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
def get_pg_home_directory() -> pathlib.Path | None:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
Expand All @@ -31,14 +34,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
return pathlib.Path(buf.value) / 'postgresql'

else:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
def get_pg_home_directory() -> pathlib.Path | None:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None


async def wait_closed(stream):
async def wait_closed(stream: asyncio.StreamWriter) -> None:
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
Expand All @@ -59,3 +62,12 @@ async def wait_closed(stream):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401

if sys.version_info < (3, 9):
from typing import ( # noqa: F401
Awaitable as Awaitable,
)
else:
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)
22 changes: 14 additions & 8 deletions asyncpg/introspection.py
Expand Up @@ -4,8 +4,14 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

_TYPEINFO_13 = '''\
import typing

if typing.TYPE_CHECKING:
from . import protocol

_TYPEINFO_13: typing.Final = '''\
(
SELECT
t.oid AS oid,
Expand Down Expand Up @@ -124,7 +130,7 @@
'''.format(typeinfo=_TYPEINFO_13)


_TYPEINFO = '''\
_TYPEINFO: typing.Final = '''\
(
SELECT
t.oid AS oid,
Expand Down Expand Up @@ -248,7 +254,7 @@
'''.format(typeinfo=_TYPEINFO)


TYPE_BY_NAME = '''\
TYPE_BY_NAME: typing.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
Expand Down Expand Up @@ -277,16 +283,16 @@
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo) -> bool:
def is_scalar_type(typeinfo: protocol.Record) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)


def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
def is_domain_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'd' # type: ignore[no-any-return]


def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'
def is_composite_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'c' # type: ignore[no-any-return]
2 changes: 2 additions & 0 deletions asyncpg/protocol/__init__.py
Expand Up @@ -6,4 +6,6 @@

# flake8: NOQA

from __future__ import annotations

from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP