Skip to content

Commit

Permalink
Typing improvements and updates for code in master
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed Oct 13, 2020
1 parent f091228 commit 8ac22c1
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 33 deletions.
2 changes: 1 addition & 1 deletion asyncpg/cluster.py
Expand Up @@ -36,7 +36,7 @@ class _ConnectionSpec(typing_extensions.TypedDict):
port: str


_system = platform.uname().system
_system: typing_extensions.Final = platform.uname().system

if _system == 'Windows':
def platform_exe(name: str) -> str:
Expand Down
7 changes: 4 additions & 3 deletions asyncpg/compat.py
Expand Up @@ -10,12 +10,13 @@
import platform
import sys
import typing
import typing_extensions


_T = typing.TypeVar('_T')
PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)
SYSTEM = platform.uname().system
PY_36: typing_extensions.Final = sys.version_info >= (3, 6)
PY_37: typing_extensions.Final = sys.version_info >= (3, 7)
SYSTEM: typing_extensions.Final = platform.uname().system


if SYSTEM == 'Windows':
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/connect_utils.py
Expand Up @@ -56,7 +56,7 @@ class _ClientConfiguration(typing.NamedTuple):
max_cacheable_statement_size: int


_system = platform.uname().system
_system: typing_extensions.Final = platform.uname().system


if _system == 'Windows':
Expand Down
21 changes: 13 additions & 8 deletions asyncpg/connection.py
Expand Up @@ -320,7 +320,9 @@ def get_settings(self) -> '_cprotocol.ConnectionSettings':
return self._protocol.get_settings()

def transaction(self, *,
isolation: transaction.IsolationLevels = 'read_committed',
isolation: typing.Optional[
transaction.IsolationLevels
] = None,
readonly: bool = False,
deferrable: bool = False) -> transaction.Transaction:
"""Create a :class:`~transaction.Transaction` object.
Expand Down Expand Up @@ -1749,7 +1751,7 @@ def _maybe_gc_stmt(
if (
stmt.refs == 0
and not self._stmt_cache.has(
(stmt.query, stmt.record_class, bool(stmt.ignore_custom_codec))
(stmt.query, stmt.record_class, stmt.ignore_custom_codec)
)
):
# If low-level `stmt` isn't referenced from any high-level
Expand Down Expand Up @@ -2736,14 +2738,17 @@ async def connect(dsn: typing.Optional[str] = None, *,
)


_StatementCacheKey = typing.Tuple[str, typing.Type[_Record], bool]


class _StatementCacheEntry:

__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')

def __init__(
self,
cache: '_StatementCache',
query: typing.Tuple[str, typing.Type[_Record], bool],
query: _StatementCacheKey[_Record],
statement: '_cprotocol.PreparedStatementState[_Record]'
) -> None:
self._cache = cache
Expand Down Expand Up @@ -2783,7 +2788,7 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
# entries dict, whereas the unused one will group in the
# beginning of it.
self._entries: collections.OrderedDict[
typing.Tuple[str, typing.Type['_cprotocol.Record'], bool],
_StatementCacheKey['_cprotocol.Record'],
_StatementCacheEntry
] = collections.OrderedDict()

Expand Down Expand Up @@ -2811,7 +2816,7 @@ def set_max_lifetime(self, new_lifetime: float) -> None:

def get(
self,
query: typing.Tuple[str, typing.Type[_Record], bool],
query: _StatementCacheKey[_Record],
*,
promote: bool = True
) -> typing.Optional['_cprotocol.PreparedStatementState[_Record]']:
Expand All @@ -2837,12 +2842,12 @@ def get(

return entry._statement

def has(self, query: typing.Tuple[str, typing.Type[_Record], bool]) -> bool:
def has(self, query: _StatementCacheKey[_Record]) -> bool:
return self.get(query, promote=False) is not None

def put(
self,
query: typing.Tuple[str, typing.Type[_Record], bool],
query: _StatementCacheKey[_Record],
statement: '_cprotocol.PreparedStatementState[_Record]'
) -> None:
if not self._max_size:
Expand Down Expand Up @@ -2884,7 +2889,7 @@ def _set_entry_timeout(self, entry: _StatementCacheEntry) -> None:

def _new_entry(
self,
query: typing.Tuple[str, typing.Type[_Record], bool],
query: _StatementCacheKey[_Record],
statement: '_cprotocol.PreparedStatementState[_Record]'
) -> _StatementCacheEntry:
entry = _StatementCacheEntry(self, query, statement)
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/cursor.py
Expand Up @@ -109,7 +109,7 @@ def __await__(self) -> typing.Generator[
self._query,
self._state,
self._args,
self._record_class
self._record_class,
)
return cursor._init(self._timeout).__await__()

Expand Down
11 changes: 6 additions & 5 deletions asyncpg/introspection.py
Expand Up @@ -6,12 +6,13 @@


import typing
import typing_extensions

if typing.TYPE_CHECKING:
from . import protocol


_TYPEINFO = '''\
_TYPEINFO: typing_extensions.Final = '''\
(
SELECT
t.oid AS oid,
Expand Down Expand Up @@ -102,7 +103,7 @@
'''


INTRO_LOOKUP_TYPES = '''\
INTRO_LOOKUP_TYPES: typing_extensions.Final = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim,
range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth)
Expand Down Expand Up @@ -140,7 +141,7 @@
'''.format(typeinfo=_TYPEINFO)


TYPE_BY_NAME = '''\
TYPE_BY_NAME: typing_extensions.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
Expand All @@ -153,7 +154,7 @@
'''


TYPE_BY_OID = '''\
TYPE_BY_OID: typing_extensions.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
Expand All @@ -166,7 +167,7 @@


# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
SCALAR_TYPE_KINDS: typing_extensions.Final = (b'b', b'd', b'e')


def is_scalar_type(typeinfo: 'protocol.Record') -> bool:
Expand Down
14 changes: 7 additions & 7 deletions asyncpg/protocol/protocol.pyi
Expand Up @@ -18,7 +18,7 @@ from typing import (
Union,
overload,
)
from typing_extensions import Protocol as _TEProtocol, Literal
from typing_extensions import Protocol as _TEProtocol, Literal, Final

import asyncpg.pgproto.pgproto

Expand All @@ -33,9 +33,9 @@ _OtherRecord = TypeVar('_OtherRecord', bound=Record)
_PreparedStatementState = TypeVar('_PreparedStatementState',
bound=PreparedStatementState[Any])

BUILTIN_TYPE_NAME_MAP: Dict[str, int]
BUILTIN_TYPE_OID_MAP: Dict[int, str]
NO_TIMEOUT: _NoTimeoutType
BUILTIN_TYPE_NAME_MAP: Final[Dict[str, int]]
BUILTIN_TYPE_OID_MAP: Final[Dict[int, str]]
NO_TIMEOUT: Final[_NoTimeoutType]

def hashlib_md5(*args: Any, **kwargs: Any) -> Any: ...

Expand Down Expand Up @@ -186,7 +186,7 @@ class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
format: Any,
) -> Any: ...
def clear_type_cache(self) -> None: ...
def get_data_codec(self, oid: int, format: Any = ...) -> Any: ...
def get_data_codec(self, oid: int, format: Any = ..., ignore_custom_codec: bool = ...) -> Any: ...
def get_text_codec(self) -> CodecInfo: ...
def register_data_types(self, types: Iterable[Any]) -> None: ...
def remove_python_codec(
Expand Down Expand Up @@ -247,15 +247,15 @@ class PreparedStatementState(Generic[_Record]):
query: str = ...
refs: int = ...
record_class: _TypingType[_Record] = ...
ignore_custom_codec: int = ...
ignore_custom_codec: bool = ...
__pyx_vtable__: Any = ...
def __init__(
self,
name: str,
query: str,
protocol: BaseProtocol[Any],
record_class: _TypingType[_Record],
ignore_custom_codec: int,
ignore_custom_codec: bool,
) -> None: ...
def _get_parameters(self) -> Tuple[Type, ...]: ...
def _get_attributes(self) -> Tuple[Attribute, ...]: ...
Expand Down
18 changes: 12 additions & 6 deletions asyncpg/transaction.py
Expand Up @@ -28,10 +28,16 @@ class TransactionState(enum.Enum):
IsolationLevels = typing_extensions.Literal['read_committed',
'serializable',
'repeatable_read']
ISOLATION_LEVELS: typing.Set[IsolationLevels] = {'read_committed',
'serializable',
'repeatable_read'}
ISOLATION_LEVELS_BY_VALUE: typing.Dict[str, IsolationLevels] = {
ISOLATION_LEVELS: typing_extensions.Final[
typing.Set[IsolationLevels]
] = {
'read_committed',
'serializable',
'repeatable_read'
}
ISOLATION_LEVELS_BY_VALUE: typing_extensions.Final[
typing.Dict[str, IsolationLevels]
] = {
'read committed': 'read_committed',
'serializable': 'serializable',
'repeatable read': 'repeatable_read',
Expand All @@ -50,7 +56,7 @@ class Transaction(connresource.ConnectionResource):
'_state', '_nested', '_id', '_managed')

def __init__(self, connection: '_connection.Connection[typing.Any]',
isolation: IsolationLevels,
isolation: typing.Optional[IsolationLevels],
readonly: bool, deferrable: bool) -> None:
super().__init__(connection)

Expand Down Expand Up @@ -249,7 +255,7 @@ def __repr__(self) -> str:
attrs = []
attrs.append('state:{}'.format(self._state.name.lower()))

attrs.append(self._isolation)
attrs.append(str(self._isolation))
if self._readonly:
attrs.append('readonly')
if self._deferrable:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -35,7 +35,7 @@
'pycodestyle~=2.6.0',
'flake8~=3.8.2',
'uvloop~=0.14.0;platform_system!="Windows"',
'mypy>=0.780'
'mypy>=0.790'
]

# Dependencies required to build documentation.
Expand Down

0 comments on commit 8ac22c1

Please sign in to comment.