Skip to content

Commit

Permalink
Remove TypeAlias annotations and handle Literal properly
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed Jul 28, 2022
1 parent 7528e6d commit fb6d8ef
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 42 deletions.
2 changes: 0 additions & 2 deletions asyncpg/compat.py
Expand Up @@ -13,13 +13,11 @@

if sys.version_info >= (3, 8):
from typing import (
Literal as Literal,
Protocol as Protocol,
TypedDict as TypedDict
)
else:
from typing_extensions import ( # noqa: F401
Literal as Literal,
Protocol as Protocol,
TypedDict as TypedDict
)
Expand Down
23 changes: 11 additions & 12 deletions asyncpg/connect_utils.py
Expand Up @@ -21,15 +21,14 @@
import sys
import time
import typing
import typing_extensions
import urllib.parse
import warnings

# Work around https://github.com/microsoft/pyright/issues/3012
if sys.version_info >= (3, 8):
from typing import Final
from typing import Final, Literal
else:
from typing_extensions import Final
from typing_extensions import Final, Literal

from . import compat
from . import exceptions
Expand All @@ -52,30 +51,30 @@
_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record)
_SSLModeT = typing.TypeVar('_SSLModeT', bound='SSLMode')

_TPTupleType: typing_extensions.TypeAlias = typing.Tuple[
_TPTupleType = typing.Tuple[
asyncio.WriteTransport,
_AsyncProtocolT
]
_SSLStringValues = compat.Literal[
_SSLStringValues = Literal[
'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full'
]
AddrType: typing_extensions.TypeAlias = typing.Union[
AddrType = typing.Union[
typing.Tuple[str, int],
str
]
_ParsedSSLType: typing_extensions.TypeAlias = typing.Union[
ssl_module.SSLContext, compat.Literal[False]
_ParsedSSLType = typing.Union[
ssl_module.SSLContext, Literal[False]
]
SSLType: typing_extensions.TypeAlias = typing.Union[
SSLType = typing.Union[
_ParsedSSLType, _SSLStringValues, bool
]
HostType: typing_extensions.TypeAlias = typing.Union[typing.List[str], str]
PortListType: typing_extensions.TypeAlias = typing.Union[
HostType = typing.Union[typing.List[str], str]
PortListType = typing.Union[
typing.List[typing.Union[int, str]],
typing.List[int],
typing.List[str],
]
PortType: typing_extensions.TypeAlias = typing.Union[
PortType = typing.Union[
PortListType,
int,
str
Expand Down
40 changes: 22 additions & 18 deletions asyncpg/connection.py
Expand Up @@ -6,7 +6,6 @@


import asyncio
import typing_extensions
import asyncpg
import collections
import collections.abc
Expand All @@ -33,40 +32,45 @@
from . import types
from . import utils

# Work around https://github.com/microsoft/pyright/issues/3012
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

if typing.TYPE_CHECKING:
import io
from .protocol import protocol as _cprotocol
from . import pool_connection_proxy as _pool


_ConnectionT = typing.TypeVar('_ConnectionT', bound='Connection[typing.Any]')
_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record)
_OtherRecordT = typing.TypeVar('_OtherRecordT', bound=protocol.Record)

_Writer: typing_extensions.TypeAlias = typing.Callable[
_Writer = typing.Callable[
[bytes],
typing.Coroutine[typing.Any, typing.Any, None]
]
_RecordsType: typing_extensions.TypeAlias = typing.List[_RecordT]
_RecordsExtraType: typing_extensions.TypeAlias = typing.Tuple[
_RecordsType = typing.List[_RecordT]
_RecordsExtraType = typing.Tuple[
_RecordsType[_RecordT],
bytes,
bool
]

OutputType: typing_extensions.TypeAlias = typing.Union[
OutputType = typing.Union[
'os.PathLike[typing.Any]',
typing.BinaryIO,
_Writer
]
SourceType: typing_extensions.TypeAlias = typing.Union[
SourceType = typing.Union[
'os.PathLike[typing.Any]',
typing.BinaryIO,
typing.AsyncIterable[bytes]
]

CopyFormat = compat.Literal['text', 'csv', 'binary']
PasswordType: typing_extensions.TypeAlias = typing.Union[
CopyFormat = Literal['text', 'csv', 'binary']
PasswordType = typing.Union[
str,
typing.Callable[[], str],
typing.Callable[[], typing.Awaitable[str]]
Expand Down Expand Up @@ -2133,7 +2137,7 @@ async def _execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[False] = ...,
return_status: Literal[False] = ...,
ignore_custom_codec: bool = ...,
record_class: None = ...
) -> _RecordsType[_RecordT]:
Expand All @@ -2147,7 +2151,7 @@ async def _execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[False] = ...,
return_status: Literal[False] = ...,
ignore_custom_codec: bool = ...,
record_class: typing.Type[_OtherRecordT]
) -> _RecordsType[_OtherRecordT]:
Expand All @@ -2161,7 +2165,7 @@ async def _execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[False] = ...,
return_status: Literal[False] = ...,
ignore_custom_codec: bool = ...,
record_class: typing.Optional[typing.Type[_OtherRecordT]]
) -> typing.Union[_RecordsType[_RecordT], _RecordsType[_OtherRecordT]]:
Expand All @@ -2175,7 +2179,7 @@ async def _execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[True],
return_status: Literal[True],
ignore_custom_codec: bool = ...,
record_class: None = ...
) -> _RecordsExtraType[_RecordT]:
Expand All @@ -2189,7 +2193,7 @@ async def _execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[True],
return_status: Literal[True],
ignore_custom_codec: bool = ...,
record_class: typing.Type[_OtherRecordT]
) -> _RecordsExtraType[_OtherRecordT]:
Expand Down Expand Up @@ -2226,7 +2230,7 @@ async def __execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[False] = ...,
return_status: Literal[False] = ...,
ignore_custom_codec: bool = ...,
record_class: None = ...
) -> typing.Tuple[
Expand All @@ -2243,7 +2247,7 @@ async def __execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[False] = ...,
return_status: Literal[False] = ...,
ignore_custom_codec: bool = ...,
record_class: typing.Type[_OtherRecordT]
) -> typing.Tuple[
Expand All @@ -2260,7 +2264,7 @@ async def __execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[True],
return_status: Literal[True],
ignore_custom_codec: bool = ...,
record_class: None = ...
) -> typing.Tuple[
Expand All @@ -2277,7 +2281,7 @@ async def __execute(
limit: int,
timeout: typing.Optional[float],
*,
return_status: compat.Literal[True],
return_status: Literal[True],
ignore_custom_codec: bool = ...,
record_class: typing.Type[_OtherRecordT]
) -> typing.Tuple[
Expand Down
7 changes: 3 additions & 4 deletions asyncpg/protocol/protocol.pyi
Expand Up @@ -19,16 +19,15 @@ from typing import (
)

if sys.version_info >= (3, 10):
from typing import Final, TypeAlias, final
from typing import Final, Literal, TypeAlias, final
elif sys.version_info >= (3, 8):
from typing import Final, final
from typing import Final, Literal, final
from typing_extensions import TypeAlias
else:
from typing_extensions import Final, TypeAlias, final
from typing_extensions import Final, Literal, TypeAlias, final

import asyncpg.pgproto.pgproto

from ..compat import Literal
from ..connect_utils import _ConnectionParameters
from ..pgproto.pgproto import WriteBuffer
from ..types import Type, Attribute
Expand Down
13 changes: 7 additions & 6 deletions asyncpg/transaction.py
Expand Up @@ -11,11 +11,10 @@

# Work around https://github.com/microsoft/pyright/issues/3012
if sys.version_info >= (3, 8):
from typing import Final
from typing import Final, Literal
else:
from typing_extensions import Final
from typing_extensions import Final, Literal

from . import compat
from . import connresource
from . import exceptions as apg_errors

Expand All @@ -32,9 +31,11 @@ class TransactionState(enum.Enum):
FAILED = 4


IsolationLevels = compat.Literal['read_committed',
'serializable',
'repeatable_read']
IsolationLevels = Literal[
'read_committed',
'serializable',
'repeatable_read'
]
ISOLATION_LEVELS: Final[
typing.Set[IsolationLevels]
] = {
Expand Down

0 comments on commit fb6d8ef

Please sign in to comment.