Skip to content

Commit

Permalink
Use TypeAlias annotation and ensure typing_extensions is not a ru…
Browse files Browse the repository at this point in the history
…ntime dep
  • Loading branch information
bryanforbes committed Feb 15, 2024
1 parent 276bb46 commit 5ede989
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 57 deletions.
10 changes: 5 additions & 5 deletions asyncpg/cluster.py
Expand Up @@ -24,16 +24,16 @@
from asyncpg import serverversion
from asyncpg import exceptions

if sys.version_info < (3, 12):
from typing_extensions import Unpack
else:
from typing import Unpack

if typing.TYPE_CHECKING:
import _typeshed
from . import types
from . import connection

if sys.version_info < (3, 12):
from typing_extensions import Unpack
else:
from typing import Unpack


class _ConnectionSpec(typing.TypedDict):
host: str
Expand Down
14 changes: 8 additions & 6 deletions asyncpg/compat.py
Expand Up @@ -64,6 +64,14 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None:
else:
from asyncio import timeout as timeout # noqa: F401

if typing.TYPE_CHECKING:
if sys.version_info < (3, 10):
from typing_extensions import TypeAlias as TypeAlias
else:
from typing import TypeAlias as TypeAlias # noqa: F401
else:
TypeAlias = typing.NewType('TypeAlias', object)

if sys.version_info < (3, 9):
from typing import (
AsyncIterable as AsyncIterable,
Expand All @@ -74,17 +82,11 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None:
Generator as Generator,
Iterable as Iterable,
Iterator as Iterator,
List as list,
OrderedDict as OrderedDict,
Sequence as Sequence,
Sized as Sized,
Tuple as tuple,
)
else:
from builtins import ( # noqa: F401
list as list,
tuple as tuple,
)
from collections import ( # noqa: F401
deque as deque,
OrderedDict as OrderedDict,
Expand Down
57 changes: 22 additions & 35 deletions asyncpg/connect_utils.py
Expand Up @@ -49,37 +49,24 @@
'_AsyncProtocolT', bound='asyncio.protocols.Protocol'
)
_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record)
_ParsedSSLType = typing.Union[
ssl_module.SSLContext, typing.Literal[False]
]
_SSLStringValues = typing.Literal[

_ParsedSSLType: compat.TypeAlias = (
'ssl_module.SSLContext | typing.Literal[False]'
)
_SSLStringValues: compat.TypeAlias = typing.Literal[
'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full'
]
_TPTupleType = compat.tuple[
asyncio.WriteTransport,
_AsyncProtocolT
]
AddrType = typing.Union[
compat.tuple[str, int],
str
]
HostType = typing.Union[compat.list[str], compat.tuple[str, ...], str]
PasswordType = typing.Union[
str,
compat.Callable[[], str],
compat.Callable[[], compat.Awaitable[str]]
]
PortListType = typing.Union[
compat.list[typing.Union[int, str]],
compat.list[int],
compat.list[str],
]
PortType = typing.Union[
PortListType,
int,
str
]
SSLType = typing.Union[_ParsedSSLType, _SSLStringValues, bool]
_TPTupleType: compat.TypeAlias = (
'tuple[asyncio.WriteTransport, _AsyncProtocolT]'
)
AddrType: compat.TypeAlias = 'tuple[str, int] | str'
HostType: compat.TypeAlias = 'list[str] | tuple[str, ...] | str'
PasswordType: compat.TypeAlias = (
'str | compat.Callable[[], str] | compat.Callable[[], compat.Awaitable[str]]' # noqa: E501
)
PortListType: compat.TypeAlias = 'list[int | str] | list[int] | list[str]'
PortType: compat.TypeAlias = 'PortListType | int | str'
SSLType: compat.TypeAlias = '_ParsedSSLType | _SSLStringValues | bool'


class SSLMode(enum.IntEnum):
Expand Down Expand Up @@ -880,11 +867,11 @@ async def _create_ssl_connection(
*,
loop: asyncio.AbstractEventLoop,
ssl_context: ssl_module.SSLContext,
ssl_is_advisory: typing.Optional[bool] = False
ssl_is_advisory: bool | None = False
) -> _TPTupleType[typing.Any]:

tr, pr = typing.cast(
compat.tuple[asyncio.WriteTransport, TLSUpgradeProto],
'tuple[asyncio.WriteTransport, TLSUpgradeProto]',
await loop.create_connection(
lambda: TLSUpgradeProto(
loop, host, port, ssl_context, ssl_is_advisory
Expand Down Expand Up @@ -1008,7 +995,7 @@ async def __connect_addr(
# UNIX socket
connector = typing.cast(
compat.Coroutine[
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
typing.Any, None, '_TPTupleType[protocol.Protocol[_RecordT]]'
],
loop.create_unix_connection(proto_factory, addr)
)
Expand All @@ -1018,7 +1005,7 @@ async def __connect_addr(
# SSL connection
connector = typing.cast(
compat.Coroutine[
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
typing.Any, None, '_TPTupleType[protocol.Protocol[_RecordT]]'
],
loop.create_connection(proto_factory, *addr, ssl=params.ssl)
)
Expand All @@ -1030,7 +1017,7 @@ async def __connect_addr(
else:
connector = typing.cast(
compat.Coroutine[
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
typing.Any, None, '_TPTupleType[protocol.Protocol[_RecordT]]'
],
loop.create_connection(proto_factory, *addr)
)
Expand Down Expand Up @@ -1237,7 +1224,7 @@ async def _cancel(

if isinstance(addr, str):
tr, pr = typing.cast(
_TPTupleType[_CancelProto],
'_TPTupleType[_CancelProto]',
await loop.create_unix_connection(proto_factory, addr)
)
else:
Expand Down
24 changes: 13 additions & 11 deletions asyncpg/connection.py
Expand Up @@ -57,18 +57,20 @@
_OtherRecordT = typing.TypeVar('_OtherRecordT', bound=protocol.Record)
_P = ParamSpec('_P')

_WriterType = compat.Callable[
_WriterType: compat.TypeAlias = compat.Callable[
[bytes], compat.Coroutine[typing.Any, typing.Any, None]
]
_OutputType = typing.Union[
'os.PathLike[typing.Any]', typing.BinaryIO, _WriterType
]
_CopyFormat = typing.Literal['text', 'csv', 'binary']
_SourceType = typing.Union[
'os.PathLike[typing.Any]', typing.BinaryIO, compat.AsyncIterable[bytes]
]
_RecordsType = compat.list[_RecordT]
_RecordsTupleType = compat.tuple[_RecordsType[_RecordT], bytes, bool]
_OutputType: compat.TypeAlias = (
'os.PathLike[typing.Any] | typing.BinaryIO | _WriterType'
)
_CopyFormat: compat.TypeAlias = typing.Literal['text', 'csv', 'binary']
_SourceType: compat.TypeAlias = (
'os.PathLike[typing.Any] | typing.BinaryIO | compat.AsyncIterable[bytes]'
)
_RecordsType: compat.TypeAlias = 'list[_RecordT]'
_RecordsTupleType: compat.TypeAlias = (
'tuple[_RecordsType[_RecordT], bytes, bool]'
)


class Listener(typing.Protocol):
Expand Down Expand Up @@ -3104,7 +3106,7 @@ async def connect(
)


_StatementCacheKey = compat.tuple[str, 'type[_RecordT]', bool]
_StatementCacheKey: compat.TypeAlias = 'tuple[str, type[_RecordT], bool]'


class _StatementCacheEntry(typing.Generic[_RecordT]):
Expand Down

0 comments on commit 5ede989

Please sign in to comment.