Skip to content

Commit

Permalink
Address feedback and fix runtime errors
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed Jul 19, 2020
1 parent 3bf8c02 commit 4fef52f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
16 changes: 11 additions & 5 deletions asyncpg/cluster.py
Expand Up @@ -131,9 +131,10 @@ def get_status(self) -> str:
async def connect(self,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
**kwargs: typing.Any) -> 'connection.Connection':
conn_info: typing.Optional[typing.Any] = self.get_connection_spec()
conn_info.update(kwargs) # type: ignore[union-attr]
return await asyncpg.connect(loop=loop, **conn_info) # type: ignore[misc] # noqa: E501
conn_info = typing.cast(typing.Dict[str, typing.Any],
self.get_connection_spec())
conn_info.update(kwargs)
return await asyncpg.connect(loop=loop, **conn_info)

def init(self, **settings: str) -> str:
"""Initialize cluster."""
Expand Down Expand Up @@ -307,12 +308,17 @@ def _get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:

return None

def get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:
def get_connection_spec(self) -> _ConnectionSpec:
status = self.get_status()
if status != 'running':
raise ClusterError('cluster is not running')

return self._get_connection_spec()
spec = self._get_connection_spec()

if spec is None:
raise ClusterError('cannot determine server connection address')

return spec

def override_connection_spec(self, **kwargs: str) -> None:
self._connection_spec_override = typing.cast(_ConnectionSpec, kwargs)
Expand Down
4 changes: 2 additions & 2 deletions asyncpg/connection.py
Expand Up @@ -45,11 +45,11 @@
_AnyCallable = typing.Callable[..., typing.Any]

OutputType = typing.Union[typing.AnyStr,
os.PathLike[typing.AnyStr],
os.PathLike,
typing.IO[typing.AnyStr],
_Writer]
SourceType = typing.Union[typing.AnyStr,
os.PathLike[typing.AnyStr],
os.PathLike,
typing.IO[typing.AnyStr],
typing.AsyncIterable[bytes]]

Expand Down
5 changes: 0 additions & 5 deletions tests/test_copy.py
Expand Up @@ -13,7 +13,6 @@

import asyncpg
from asyncpg import _testbase as tb
from asyncpg import compat


class TestCopyFrom(tb.ConnectedTestCase):
Expand Down Expand Up @@ -467,7 +466,6 @@ class _Source:
def __init__(self):
self.rowcount = 0

@compat.aiter_compat
def __aiter__(self):
return self

Expand Down Expand Up @@ -507,7 +505,6 @@ class _Source:
def __init__(self):
self.rowcount = 0

@compat.aiter_compat
def __aiter__(self):
return self

Expand All @@ -533,7 +530,6 @@ class _Source:
def __init__(self):
self.rowcount = 0

@compat.aiter_compat
def __aiter__(self):
return self

Expand Down Expand Up @@ -564,7 +560,6 @@ def __init__(self, loop):
self.rowcount = 0
self.loop = loop

@compat.aiter_compat
def __aiter__(self):
return self

Expand Down

0 comments on commit 4fef52f

Please sign in to comment.