Skip to content

Commit

Permalink
Updates for custom records and method updates
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed Aug 21, 2020
1 parent 86a92f9 commit 8332c03
Show file tree
Hide file tree
Showing 11 changed files with 949 additions and 190 deletions.
8 changes: 6 additions & 2 deletions asyncpg/cluster.py
Expand Up @@ -130,11 +130,15 @@ def get_status(self) -> str:

async def connect(self,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
**kwargs: typing.Any) -> 'connection.Connection':
**kwargs: typing.Any) \
-> 'connection.Connection[typing.Any]':
conn_info = typing.cast(typing.Dict[str, typing.Any],
self.get_connection_spec())
conn_info.update(kwargs)
return await asyncpg.connect(loop=loop, **conn_info)
return typing.cast(
'connection.Connection[typing.Any]',
await asyncpg.connect(loop=loop, **conn_info)
)

def init(self, **settings: str) -> str:
"""Initialize cluster."""
Expand Down
9 changes: 5 additions & 4 deletions asyncpg/connect_utils.py
Expand Up @@ -29,6 +29,7 @@

_Connection = typing.TypeVar('_Connection')
_Protocol = typing.TypeVar('_Protocol', bound=asyncio.Protocol)
_Record = typing.TypeVar('_Record', bound=protocol.Record)

_TPTupleType = typing.Tuple[asyncio.WriteTransport, _Protocol]
AddrType = typing.Union[typing.Tuple[str, int], str]
Expand Down Expand Up @@ -654,7 +655,7 @@ async def _connect_addr(
params: _ConnectionParameters,
config: _ClientConfiguration,
connection_class: typing.Type[_Connection],
record_class: typing.Any
record_class: typing.Type[_Record]
) -> _Connection:
assert loop is not None

Expand All @@ -680,7 +681,7 @@ async def _connect_addr(
assert not params.ssl
connector = typing.cast(
typing.Coroutine[typing.Any, None,
_TPTupleType[protocol.Protocol]],
_TPTupleType['protocol.Protocol[_Record]']],
loop.create_unix_connection(proto_factory, addr))
elif params.ssl:
connector = _create_ssl_connection(
Expand All @@ -689,7 +690,7 @@ async def _connect_addr(
else:
connector = typing.cast(
typing.Coroutine[typing.Any, None,
_TPTupleType[protocol.Protocol]],
_TPTupleType['protocol.Protocol[_Record]']],
loop.create_connection(proto_factory, *addr))

connector_future = asyncio.ensure_future(connector)
Expand Down Expand Up @@ -721,7 +722,7 @@ async def _connect(
loop: typing.Optional[asyncio.AbstractEventLoop],
timeout: float,
connection_class: typing.Type[_Connection],
record_class: typing.Any,
record_class: typing.Type[_Record],
**kwargs: typing.Any
) -> _Connection:
if loop is None:
Expand Down

0 comments on commit 8332c03

Please sign in to comment.