
Commit

Clean up types
bryanforbes committed Feb 10, 2022
1 parent 6a3b888 commit 23154c1
Showing 8 changed files with 179 additions and 75 deletions.
30 changes: 21 additions & 9 deletions asyncpg/cluster.py
@@ -68,6 +68,15 @@ class ClusterError(Exception):


class Cluster:
_data_dir: str
_pg_config_path: typing.Optional[str]
_pg_bin_dir: typing.Optional[str]
_pg_ctl: typing.Optional[str]
_daemon_pid: typing.Optional[int]
_daemon_process: typing.Optional['subprocess.Popen[bytes]']
_connection_addr: typing.Optional[_ConnectionSpec]
_connection_spec_override: typing.Optional[_ConnectionSpec]

def __init__(self, data_dir: str, *,
pg_config_path: typing.Optional[str] = None) -> None:
self._data_dir = data_dir
@@ -76,11 +85,11 @@ def __init__(self, data_dir: str, *,
os.environ.get('PGINSTALLATION')
or os.environ.get('PGBIN')
)
self._pg_ctl: typing.Optional[str] = None
self._daemon_pid: typing.Optional[int] = None
self._daemon_process: typing.Optional[subprocess.Popen[bytes]] = None
self._connection_addr: typing.Optional[_ConnectionSpec] = None
self._connection_spec_override: typing.Optional[_ConnectionSpec] = None
self._pg_ctl = None
self._daemon_pid = None
self._daemon_process = None
self._connection_addr = None
self._connection_spec_override = None

def get_pg_version(self) -> 'types.ServerVersion':
return self._pg_version
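
The hunks above show the core of this commit: attribute types move out of __init__ and into bare class-level annotations, so the assignments themselves stay plain. A minimal sketch of the pattern, using a made-up Service class rather than asyncpg's actual Cluster (the quoted 'subprocess.Popen[bytes]' mirrors the diff and keeps the subscripted type out of runtime evaluation):

import subprocess
import typing


class Service:
    # Bare class-level annotations record the attribute types once, in
    # __annotations__, without creating class attributes.
    _name: str
    _port: typing.Optional[int]
    _proc: typing.Optional['subprocess.Popen[bytes]']

    def __init__(self, name: str) -> None:
        # The assignments no longer need inline annotations; a type
        # checker picks the types up from the class body above.
        self._name = name
        self._port = None
        self._proc = None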
@@ -653,6 +662,9 @@ def __init__(self, *,


class HotStandbyCluster(TempCluster):
_master: _ConnectionSpec
_repl_user: str

def __init__(self, *,
master: _ConnectionSpec, replication_user: str,
data_dir_suffix: typing.Optional[str] = None,
@@ -739,16 +751,16 @@ def get_status(self) -> str:
return 'running'

def init(self, **settings: str) -> str:
pass
...

def start(self, wait: int = 60, **settings: typing.Any) -> None:
pass
...

def stop(self, wait: int = 60) -> None:
pass
...

def destroy(self) -> None:
pass
...

def reset_hba(self) -> None:
raise ClusterError('cannot modify HBA records of unmanaged cluster')
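
The unmanaged-cluster stubs switch their no-op bodies from pass to ... (Ellipsis). The two are equivalent at runtime; ... is simply the conventional body for stubs and overload declarations. A tiny illustration with an invented NullService, not taken from asyncpg:

class NullService:
    """Stub whose lifecycle hooks intentionally do nothing."""

    def start(self, wait: int = 60) -> None:
        ...  # same runtime effect as `pass`; the statement just evaluates Ellipsis

    def stop(self, wait: int = 60) -> None:
        ...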
79 changes: 55 additions & 24 deletions asyncpg/connection.py
@@ -157,16 +157,41 @@ class Connection(typing.Generic[_Record], metaclass=ConnectionMeta):
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')

_protocol: '_cprotocol.BaseProtocol[_Record]'
_transport: typing.Any
_loop: asyncio.AbstractEventLoop
_top_xact: typing.Optional[transaction.Transaction]
_aborted: bool
_pool_release_ctr: int
_stmt_cache: '_StatementCache'
_stmts_to_close: typing.Set[
'_cprotocol.PreparedStatementState[typing.Any]'
]
_listeners: typing.Dict[str, typing.Set['_Callback']]
_server_version: types.ServerVersion
_server_caps: 'ServerCapabilities'
_intro_query: str
_reset_query: typing.Optional[str]
_proxy: typing.Optional['_pool.PoolConnectionProxy[typing.Any]']
_stmt_exclusive_section: '_Atomic'
_config: connect_utils._ClientConfiguration
_params: connect_utils._ConnectionParameters
_addr: typing.Union[typing.Tuple[str, int], str]
_log_listeners: typing.Set['_Callback']
_termination_listeners: typing.Set['_Callback']
_cancellations: typing.Set['asyncio.Task[typing.Any]']
_source_traceback: typing.Optional[str]

def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
transport: typing.Any,
loop: asyncio.AbstractEventLoop,
addr: typing.Union[typing.Tuple[str, int], str],
config: connect_utils._ClientConfiguration,
params: connect_utils._ConnectionParameters) -> None:
self._protocol: '_cprotocol.BaseProtocol[_Record]' = protocol
self._protocol = protocol
self._transport = transport
self._loop = loop
self._top_xact: typing.Optional[transaction.Transaction] = None
self._top_xact = None
self._aborted = False
# Incremented every time the connection is released back to a pool.
# Used to catch invalid references to connection-related resources
@@ -184,14 +209,12 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
_weak_maybe_gc_stmt, weakref.ref(self)),
max_lifetime=config.max_cached_statement_lifetime)

self._stmts_to_close: typing.Set[
'_cprotocol.PreparedStatementState[typing.Any]'
] = set()
self._stmts_to_close = set()

self._listeners: typing.Dict[str, typing.Set[_Callback]] = {}
self._log_listeners: typing.Set[_Callback] = set()
self._cancellations: typing.Set[asyncio.Task[typing.Any]] = set()
self._termination_listeners: typing.Set[_Callback] = set()
self._listeners = {}
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
@@ -206,10 +229,8 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
else:
self._intro_query = introspection.INTRO_LOOKUP_TYPES

self._reset_query: typing.Optional[str] = None
self._proxy: typing.Optional[
'_pool.PoolConnectionProxy[typing.Any]'
] = None
self._reset_query = None
self._proxy = None

# Used to serialize operations that might involve anonymous
# statements. Specifically, we want to make the following
@@ -221,7 +242,7 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
self._stmt_exclusive_section = _Atomic()

if loop.get_debug():
self._source_traceback: typing.Optional[str] = _extract_stack()
self._source_traceback = _extract_stack()
else:
self._source_traceback = None
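
Connection keeps its __slots__ tuple while gaining the block of class-level annotations above. The combination works because a bare annotation only records an entry in __annotations__ and creates no class attribute, so it cannot collide with a slot descriptor the way an annotated assignment would. A stripped-down sketch, unrelated to asyncpg's real class:

import typing


class Slotted:
    __slots__ = ('_x', '_y')

    # Bare annotations are compatible with __slots__: they create no
    # class attributes, so they cannot shadow the slot descriptors.
    _x: int
    _y: typing.Optional[str]

    # An annotated *assignment* would fail at class creation time with
    # a ValueError about '_x' conflicting with a class variable:
    # _x: int = 0

    def __init__(self, x: int) -> None:
        self._x = x
        self._y = None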

@@ -2007,7 +2028,7 @@ def _set_proxy(
self._proxy = proxy

def _check_listeners(self,
listeners: 'typing.Sized',
listeners: typing.Sized,
listener_type: str) -> None:
if listeners:
count = len(listeners)
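
The quotes around typing.Sized are dropped here because typing is a real runtime import, so no forward reference is needed. typing.Sized accepts any object that implements __len__, which is all _check_listeners needs. A small usage sketch (invented function, not asyncpg API):

import typing


def describe(items: typing.Sized, kind: str) -> str:
    # Anything implementing __len__ (list, dict, set, a listener
    # registry, ...) satisfies typing.Sized.
    return f'{len(items)} {kind} registered'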
@@ -2915,6 +2936,11 @@ class _StatementCacheEntry(typing.Generic[_Record]):

__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')

_query: _StatementCacheKey[_Record]
_statement: '_cprotocol.PreparedStatementState[_Record]'
_cache: '_StatementCache'
_cleanup_cb: typing.Optional[asyncio.TimerHandle]

def __init__(
self,
cache: '_StatementCache',
@@ -2924,21 +2950,27 @@ def __init__(
self._cache = cache
self._query = query
self._statement = statement
self._cleanup_cb: typing.Optional[asyncio.TimerHandle] = None
self._cleanup_cb = None


class _StatementCache:

__slots__ = ('_loop', '_entries', '_max_size', '_on_remove',
'_max_lifetime')

_loop: asyncio.AbstractEventLoop
_entries: 'collections.OrderedDict[_StatementCacheKey[typing.Any], _StatementCacheEntry[typing.Any]]' # noqa: E501
_max_size: int
_on_remove: OnRemove[typing.Any]
_max_lifetime: float

def __init__(self, *, loop: asyncio.AbstractEventLoop,
max_size: int, on_remove: OnRemove[typing.Any],
max_lifetime: float) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._max_size: int = max_size
self._on_remove: OnRemove[typing.Any] = on_remove
self._max_lifetime: float = max_lifetime
self._loop = loop
self._max_size = max_size
self._on_remove = on_remove
self._max_lifetime = max_lifetime

# We use an OrderedDict for LRU implementation. Operations:
#
@@ -2957,10 +2989,7 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
# So new entries and hits are always promoted to the end of the
# entries dict, whereas the unused one will group in the
# beginning of it.
self._entries: collections.OrderedDict[
_StatementCacheKey[typing.Any],
_StatementCacheEntry[typing.Any]
] = collections.OrderedDict()
self._entries = collections.OrderedDict()

def __len__(self) -> int:
return len(self._entries)
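
The comment block above spells out the LRU discipline: new entries and cache hits are promoted to the end of the OrderedDict, so the least recently used entry sits at the front and is the one evicted. A compact, generic sketch of that discipline (an assumed stand-in, not asyncpg's _StatementCache):

import collections
import typing

K = typing.TypeVar('K')
V = typing.TypeVar('V')


class LRUCache(typing.Generic[K, V]):
    _entries: 'collections.OrderedDict[K, V]'
    _max_size: int

    def __init__(self, max_size: int) -> None:
        self._entries = collections.OrderedDict()
        self._max_size = max_size

    def get(self, key: K) -> typing.Optional[V]:
        if key in self._entries:
            # A hit promotes the entry to the most-recently-used end.
            self._entries.move_to_end(key)
            return self._entries[key]
        return None

    def put(self, key: K, value: V) -> None:
        self._entries[key] = value
        self._entries.move_to_end(key)
        if len(self._entries) > self._max_size:
            # Evict from the least-recently-used end.
            self._entries.popitem(last=False)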
@@ -3136,6 +3165,8 @@ def from_callable(
class _Atomic:
__slots__ = ('_acquired',)

_acquired: int

def __init__(self) -> None:
self._acquired = 0

7 changes: 5 additions & 2 deletions asyncpg/connresource.py
@@ -13,7 +13,7 @@


if typing.TYPE_CHECKING:
from . import connection as _connection
from . import connection as _conn


_Callable = typing.TypeVar('_Callable', bound=typing.Callable[..., typing.Any])
@@ -35,8 +35,11 @@ def _check(self: 'ConnectionResource',
class ConnectionResource:
__slots__ = ('_connection', '_con_release_ctr')

_connection: '_conn.Connection[typing.Any]'
_con_release_ctr: int

def __init__(
self, connection: '_connection.Connection[typing.Any]'
self, connection: '_conn.Connection[typing.Any]'
) -> None:
self._connection = connection
self._con_release_ctr = connection._pool_release_ctr
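
connresource.py shows the other half of the typing cleanup: the connection module is imported only under typing.TYPE_CHECKING (aliased to _conn), and annotations that refer to it stay quoted so nothing is evaluated at runtime and no circular import is introduced. Restating the shape of the hunk above as one self-contained sketch:

import typing

if typing.TYPE_CHECKING:
    # Imported for annotations only; this block never runs at runtime,
    # so a circular import between the two modules is avoided.
    from . import connection as _conn


class ConnectionResource:
    __slots__ = ('_connection', '_con_release_ctr')

    _connection: '_conn.Connection[typing.Any]'
    _con_release_ctr: int

    def __init__(self, connection: '_conn.Connection[typing.Any]') -> None:
        self._connection = connection
        self._con_release_ctr = connection._pool_release_ctr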
40 changes: 29 additions & 11 deletions asyncpg/cursor.py
@@ -39,12 +39,19 @@ class CursorFactory(connresource.ConnectionResource, typing.Generic[_Record]):
'_record_class',
)

_state: typing.Optional['_cprotocol.PreparedStatementState[_Record]']
_args: typing.Sequence[typing.Any]
_prefetch: typing.Optional[int]
_query: str
_timeout: typing.Optional[float]
_record_class: typing.Optional[typing.Type[_Record]]

@typing.overload
def __init__(
self: 'CursorFactory[_Record]',
connection: '_connection.Connection[_Record]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
prefetch: typing.Optional[int],
timeout: typing.Optional[float],
@@ -57,7 +64,7 @@ def __init__(
self: 'CursorFactory[_Record]',
connection: '_connection.Connection[typing.Any]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
prefetch: typing.Optional[int],
timeout: typing.Optional[float],
@@ -69,7 +76,7 @@ def __init__(
self,
connection: '_connection.Connection[typing.Any]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
prefetch: typing.Optional[int],
timeout: typing.Optional[float],
@@ -130,12 +137,19 @@ class BaseCursor(connresource.ConnectionResource, typing.Generic[_Record]):
'_record_class',
)

_state: typing.Optional['_cprotocol.PreparedStatementState[_Record]']
_args: typing.Sequence[typing.Any]
_portal_name: typing.Optional[str]
_exhausted: bool
_query: str
_record_class: typing.Optional[typing.Type[_Record]]

@typing.overload
def __init__(
self: 'BaseCursor[_Record]',
connection: '_connection.Connection[_Record]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
record_class: None
) -> None:
@@ -146,7 +160,7 @@ def __init__(
self: 'BaseCursor[_Record]',
connection: '_connection.Connection[typing.Any]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
record_class: typing.Type[_Record]
) -> None:
@@ -156,7 +170,7 @@ def __init__(
self,
connection: '_connection.Connection[typing.Any]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
record_class: typing.Optional[typing.Type[_Record]]
) -> None:
@@ -165,7 +179,7 @@ def __init__(
self._state = state
if state is not None:
state.attach()
self._portal_name: typing.Optional[str] = None
self._portal_name = None
self._exhausted = False
self._query = query
self._record_class = record_class
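
The paired typing.overload declarations above let a checker infer the record type from record_class: passing None keeps the _Record parameter of the connection, while passing an explicit typing.Type[_Record] rebinds it. A reduced sketch of that overload shape, using hypothetical Source and Reader classes instead of the cursor hierarchy:

import typing

_Row = typing.TypeVar('_Row')


class Source(typing.Generic[_Row]):
    """Stand-in for the connection: the object that knows its row type."""


class Reader(typing.Generic[_Row]):
    @typing.overload
    def __init__(
        self: 'Reader[_Row]',
        source: 'Source[_Row]',
        row_class: None,
    ) -> None:
        ...

    @typing.overload
    def __init__(
        self: 'Reader[_Row]',
        source: 'Source[typing.Any]',
        row_class: typing.Type[_Row],
    ) -> None:
        ...

    def __init__(
        self,
        source: 'Source[typing.Any]',
        row_class: typing.Optional[typing.Type[_Row]],
    ) -> None:
        # Single runtime implementation behind both overloads: with
        # row_class=None the row type comes from the source's own
        # parameter, otherwise the explicit class rebinds it.
        self._source = source
        self._row_class = row_class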
@@ -260,12 +274,16 @@ class CursorIterator(BaseCursor[_Record]):

__slots__ = ('_buffer', '_prefetch', '_timeout')

_buffer: typing.Deque[_Record]
_prefetch: int
_timeout: typing.Optional[float]

@typing.overload
def __init__(
self: 'CursorIterator[_Record]',
connection: '_connection.Connection[_Record]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
record_class: None,
prefetch: int,
@@ -278,7 +296,7 @@ def __init__(
self: 'CursorIterator[_Record]',
connection: '_connection.Connection[typing.Any]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
record_class: typing.Type[_Record],
prefetch: int,
@@ -290,7 +308,7 @@ def __init__(
self,
connection: '_connection.Connection[typing.Any]',
query: str,
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
args: typing.Sequence[typing.Any],
record_class: typing.Optional[typing.Type[_Record]],
prefetch: int,
@@ -302,7 +320,7 @@ def __init__(
raise exceptions.InterfaceError(
'prefetch argument must be greater than zero')

self._buffer: typing.Deque[_Record] = collections.deque()
self._buffer = collections.deque()
self._prefetch = prefetch
self._timeout = timeout

