Skip to content

Commit

Permalink
Use specific ignore comments and updates for mypy 0.780
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed Jun 8, 2020
1 parent 38b8c60 commit d449a6f
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Expand Up @@ -34,6 +34,6 @@ docs/_build
/.eggs
/.vscode
/.mypy_cache
/.venv
/.venv*
/.ci
/.vim
16 changes: 8 additions & 8 deletions asyncpg/cluster.py
Expand Up @@ -126,15 +126,15 @@ def get_status(self) -> str:
return self._test_connection(timeout=0)
else:
raise ClusterError(
'pg_ctl status exited with status {:d}: {}'.format(
'pg_ctl status exited with status {:d}: {}'.format( # type: ignore[str-bytes-safe] # noqa: E501
process.returncode, stderr))

async def connect(self,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
**kwargs: typing.Any) -> 'connection.Connection':
conn_info = self.get_connection_spec() # type: typing.Optional[typing.Any] # noqa: E501
conn_info.update(kwargs)
return await asyncpg.connect(loop=loop, **conn_info)
conn_info.update(kwargs) # type: ignore[union-attr]
return await asyncpg.connect(loop=loop, **conn_info) # type: ignore[misc] # noqa: E501

def init(self, **settings: str) -> str:
"""Initialize cluster."""
Expand Down Expand Up @@ -301,7 +301,7 @@ def _get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:
if self._connection_addr is not None:
if self._connection_spec_override:
args = self._connection_addr.copy()
args.update(self._connection_spec_override) # type: ignore
args.update(self._connection_spec_override) # type: ignore[arg-type] # noqa: E501
return args
else:
return self._connection_addr
Expand Down Expand Up @@ -401,7 +401,7 @@ def add_hba_entry(self, *, type: str = 'host',

if auth_options is not None:
record += ' ' + ' '.join(
'{}={}'.format(k, v) for k, v in auth_options)
'{}={}'.format(k, v) for k, v in auth_options.items())

try:
with open(pg_hba, 'a') as f:
Expand Down Expand Up @@ -516,7 +516,7 @@ def _test_connection(self, timeout: int = 60) -> str:

try:
con = loop.run_until_complete(
asyncpg.connect(database='postgres',
asyncpg.connect(database='postgres', # type: ignore[misc] # noqa: E501
user='postgres',
timeout=5, loop=loop,
**self._connection_addr))
Expand Down Expand Up @@ -544,7 +544,7 @@ def _run_pg_config(self, pg_config_path: str) -> typing.Dict[str, str]:
stdout, stderr = process.stdout, process.stderr

if process.returncode != 0:
raise ClusterError('pg_config exited with status {:d}: {}'.format(
raise ClusterError('pg_config exited with status {:d}: {}'.format( # type: ignore[str-bytes-safe] # noqa: E501
process.returncode, stderr))
else:
config = {}
Expand Down Expand Up @@ -601,7 +601,7 @@ def _get_pg_version(self) -> 'types.ServerVersion':

if process.returncode != 0:
raise ClusterError(
'postgres --version exited with status {:d}: {}'.format(
'postgres --version exited with status {:d}: {}'.format( # type: ignore[str-bytes-safe] # noqa: E501
process.returncode, stderr))

version_string = stdout.decode('utf-8').strip(' \n')
Expand Down
4 changes: 2 additions & 2 deletions asyncpg/compat.py
Expand Up @@ -34,7 +34,7 @@ async def wrapper(self: typing.Any) -> typing.Any:
return func(self)
return typing.cast(_F_35, wrapper)
else:
def aiter_compat(func: _F) -> _F: # type: ignore
def aiter_compat(func: _F) -> _F: # type: ignore[misc]
return func


Expand Down Expand Up @@ -88,7 +88,7 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf) # type: ignore # noqa: E501
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf) # type: ignore[attr-defined] # noqa: E501
if r:
return None
else:
Expand Down
18 changes: 9 additions & 9 deletions asyncpg/connect_utils.py
Expand Up @@ -317,7 +317,7 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
if 'sslmode' in query_str:
val_str = query_str.pop('sslmode')
if ssl is None:
ssl = val_str
ssl = val_str # type: ignore[assignment]

if query_str:
if server_settings is None:
Expand Down Expand Up @@ -392,17 +392,17 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
if passfile is None:
homedir = compat.get_pg_home_directory()
if homedir:
passfile = homedir / PGPASSFILE
passfile = homedir / PGPASSFILE # type: ignore[assignment]
else:
passfile = None
else:
passfile = pathlib.Path(passfile)
passfile = pathlib.Path(passfile) # type: ignore[assignment]

if passfile is not None:
password = _read_password_from_pgpass(
hosts=auth_hosts, ports=port,
database=database, user=user,
passfile=passfile)
passfile=passfile) # type: ignore[arg-type]

addrs = [] # type: typing.List[AddrType]
for h, p in zip(host, port):
Expand All @@ -420,7 +420,7 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
'could not determine the database address to connect to')

if ssl is None:
ssl = os.getenv('PGSSLMODE')
ssl = os.getenv('PGSSLMODE') # type: ignore[assignment]

# ssl_is_advisory is only allowed to come from the sslmode parameter.
ssl_is_advisory = None
Expand Down Expand Up @@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory: typing.Callable[[],
typing.Tuple[asyncio.WriteTransport, TLSUpgradeProto],
await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
typing.cast(ssl_module.SSLContext,
typing.cast(typing.Any,
ssl_context),
ssl_is_advisory),
host, port))
Expand All @@ -614,7 +614,7 @@ async def _create_ssl_connection(protocol_factory: typing.Callable[[],
asyncio.WriteTransport,
await typing.cast(typing.Any, loop).start_tls(
tr, pr,
typing.cast(ssl_module.SSLContext, ssl_context),
typing.cast(typing.Any, ssl_context),
server_hostname=host))
except (Exception, asyncio.CancelledError):
tr.close()
Expand Down Expand Up @@ -711,7 +711,7 @@ async def _connect_addr(*, addr: AddrType,
tr.close()
raise

con = connection_class(pr, tr, loop, addr, config, # type: ignore
con = connection_class(pr, tr, loop, addr, config, # type: ignore[call-arg] # noqa: E501
params_input)
pr.set_connection(con)
return con
Expand Down Expand Up @@ -805,7 +805,7 @@ def _set_nodelay(sock: typing.Any) -> None:
def _create_future(loop: typing.Optional[asyncio.AbstractEventLoop]) \
-> 'asyncio.Future[typing.Any]':
try:
create_future = loop.create_future # type: ignore
create_future = loop.create_future # type: ignore[union-attr]
except AttributeError:
return asyncio.Future(loop=loop)
else:
Expand Down
28 changes: 15 additions & 13 deletions asyncpg/connection.py
Expand Up @@ -38,6 +38,7 @@
_Connection = typing.TypeVar('_Connection', bound='Connection')
_Writer = typing.Callable[[bytes],
typing.Coroutine[typing.Any, typing.Any, None]]
_Record = typing.TypeVar('_Record', bound='_cprotocol.Record')
_RecordsType = typing.List['_cprotocol.Record']
_RecordsExtraType = typing.Tuple[_RecordsType, bytes, bool]
_AnyCallable = typing.Callable[..., typing.Any]
Expand Down Expand Up @@ -447,7 +448,8 @@ async def _introspect_types(self, typeoids: typing.Set[int],

def cursor(self, query: str, *args: typing.Any,
prefetch: typing.Optional[int] = None,
timeout: typing.Optional[float] = None) -> cursor.CursorFactory:
timeout: typing.Optional[float] = None) \
-> 'cursor.CursorFactory[_cprotocol.Record]':
"""Return a *cursor factory* for the specified query.
:param args: Query arguments.
Expand All @@ -463,7 +465,7 @@ def cursor(self, query: str, *args: typing.Any,

async def prepare(self, query: str, *,
timeout: typing.Optional[float] = None) \
-> prepared_stmt.PreparedStatement:
-> prepared_stmt.PreparedStatement['_cprotocol.Record']:
"""Create a *prepared statement* for the specified query.
:param str query: Text of the query to create a prepared statement for.
Expand All @@ -476,7 +478,7 @@ async def prepare(self, query: str, *,
async def _prepare(self, query: str, *,
timeout: typing.Optional[float] = None,
use_cache: bool = False) \
-> prepared_stmt.PreparedStatement:
-> prepared_stmt.PreparedStatement['_cprotocol.Record']:
self._check_open()
stmt = await self._get_statement(query, timeout, named=True,
use_cache=use_cache)
Expand Down Expand Up @@ -886,7 +888,7 @@ async def _copy_out(self, copy_stmt: str,
output: OutputType[typing.AnyStr],
timeout: typing.Optional[float]) -> str:
try:
path = compat.fspath(output) # type: typing.Optional[typing.AnyStr] # type: ignore # noqa: E501
path = compat.fspath(output) # type: typing.Optional[typing.AnyStr] # type: ignore[arg-type] # noqa: E501
except TypeError:
# output is not a path-like object
path = None
Expand All @@ -913,7 +915,7 @@ async def _copy_out(self, copy_stmt: str,
)

if writer is None:
async def _writer(data: bytes) -> None: # type: ignore
async def _writer(data: bytes) -> None: # type: ignore[return]
await run_in_executor(None, f.write, data)

writer = _writer
Expand All @@ -928,7 +930,7 @@ async def _copy_in(self, copy_stmt: str,
source: SourceType[typing.AnyStr],
timeout: typing.Optional[float]) -> str:
try:
path = compat.fspath(source) # type: typing.Optional[typing.AnyStr] # type: ignore # noqa: E501
path = compat.fspath(source) # type: typing.Optional[typing.AnyStr] # type: ignore[arg-type] # noqa: E501
except TypeError:
# source is not a path-like object
path = None
Expand Down Expand Up @@ -967,7 +969,7 @@ async def __anext__(self) -> bytes:
if len(data) == 0:
raise StopAsyncIteration
else:
return data # type: ignore
return data # type: ignore[return-value]

reader = _Reader()

Expand Down Expand Up @@ -1259,7 +1261,7 @@ def _abort(self) -> None:
# Put the connection into the aborted state.
self._aborted = True
self._protocol.abort()
self._protocol = None # type: ignore
self._protocol = None # type: ignore[assignment]

def _cleanup(self) -> None:
# Free the resources associated with this connection.
Expand Down Expand Up @@ -1352,7 +1354,7 @@ async def _cancel(self, waiter: 'asyncio.Future[None]') -> None:
waiter.set_exception(ex)
finally:
self._cancellations.discard(
compat.current_asyncio_task(self._loop))
compat.current_asyncio_task(self._loop)) # type: ignore[arg-type] # noqa: E501
if not waiter.done():
waiter.set_result(None)

Expand Down Expand Up @@ -1747,7 +1749,7 @@ async def connect(dsn: typing.Optional[str] = None, *,
max_cacheable_statement_size: int = 1024 * 15,
command_timeout: typing.Optional[float] = None,
ssl: typing.Optional[connect_utils.SSLType] = None,
connection_class: typing.Type[_Connection] = Connection, # type: ignore # noqa: E501
connection_class: typing.Type[_Connection] = Connection, # type: ignore[assignment] # noqa: E501
server_settings: typing.Optional[
typing.Dict[str, str]] = None) -> _Connection:
r"""A coroutine to establish a connection to a PostgreSQL server.
Expand Down Expand Up @@ -2180,15 +2182,15 @@ def _extract_stack(limit: int = 10) -> str:
frame = sys._getframe().f_back
try:
stack = traceback.StackSummary.extract(
traceback.walk_stack(frame), lookup_lines=False) # type: typing.Union[traceback.StackSummary, typing.List[traceback.FrameSummary]] # noqa: E501
traceback.walk_stack(frame), lookup_lines=False) # type: ignore[arg-type] # noqa: E501
finally:
del frame

apg_path = asyncpg.__path__[0]
apg_path = asyncpg.__path__[0] # type: ignore[attr-defined]
i = 0
while i < len(stack) and stack[i][0].startswith(apg_path):
i += 1
stack = stack[i:i + limit]
stack = stack[i:i + limit] # type: ignore[assignment]

stack.reverse()
return ''.join(traceback.format_list(stack))
Expand Down
10 changes: 5 additions & 5 deletions asyncpg/cursor.py
Expand Up @@ -24,7 +24,7 @@
_Record = typing.TypeVar('_Record', bound='_cprotocol.Record')


class CursorFactory(connresource.ConnectionResource):
class CursorFactory(connresource.ConnectionResource, typing.Generic[_Record]):
"""A cursor interface for the results of a query.
A cursor interface can be used to initiate efficient traversal of the
Expand All @@ -49,7 +49,7 @@ def __init__(self, connection: '_connection.Connection', query: str,

@compat.aiter_compat
@connresource.guarded
def __aiter__(self) -> 'CursorIterator[_cprotocol.Record]':
def __aiter__(self) -> 'CursorIterator[_Record]':
prefetch = 50 if self._prefetch is None else self._prefetch
return CursorIterator(self._connection,
self._query, self._state,
Expand All @@ -58,13 +58,13 @@ def __aiter__(self) -> 'CursorIterator[_cprotocol.Record]':

@connresource.guarded
def __await__(self) -> typing.Generator[
typing.Any, None, 'Cursor[_cprotocol.Record]']:
typing.Any, None, 'Cursor[_Record]']:
if self._prefetch is not None:
raise exceptions.InterfaceError(
'prefetch argument can only be specified for iterable cursor')
cursor = Cursor(self._connection, self._query,
self._state,
self._args) # type: Cursor[_cprotocol.Record]
self._args) # type: Cursor[_Record]
return cursor._init(self._timeout).__await__()

def __del__(self) -> None:
Expand Down Expand Up @@ -166,7 +166,7 @@ def __repr__(self) -> str:

return '<{}.{} "{!s:.30}" {}{:#x}>'.format(
mod, self.__class__.__name__,
self._state.query,
self._state.query, # type: ignore[union-attr]
' '.join(attrs), id(self))

def __del__(self) -> None:
Expand Down
5 changes: 3 additions & 2 deletions asyncpg/pool.pyi
Expand Up @@ -43,10 +43,11 @@ class PoolConnectionProxy(connection._ConnectionProxy,
*, timeout: typing.Optional[float] = ...) -> None: ...
def cursor(self, query: str, *args: typing.Any,
prefetch: typing.Optional[int] = ...,
timeout: typing.Optional[float] = ...) -> cursor.CursorFactory: ...
timeout: typing.Optional[float] = ...) \
-> cursor.CursorFactory[_cprotocol.Record]: ...
async def prepare(self, query: str, *,
timeout: typing.Optional[float] = ...) \
-> prepared_stmt.PreparedStatement: ...
-> prepared_stmt.PreparedStatement[_cprotocol.Record]: ...
async def fetch(self, query: str, *args: typing.Any,
timeout: typing.Optional[float] = ...) \
-> typing.List[_cprotocol.Record]: ...
Expand Down
17 changes: 11 additions & 6 deletions asyncpg/prepared_stmt.py
Expand Up @@ -19,7 +19,11 @@
from . import connection as _connection


class PreparedStatement(connresource.ConnectionResource):
_Record = typing.TypeVar('_Record', bound='_cprotocol.Record')


class PreparedStatement(connresource.ConnectionResource,
typing.Generic[_Record]):
"""A representation of a prepared statement."""

__slots__ = ('_state', '_query', '_last_status')
Expand Down Expand Up @@ -101,7 +105,8 @@ def get_attributes(self) -> typing.Tuple[types.Attribute, ...]:

@connresource.guarded
def cursor(self, *args: typing.Any, prefetch: typing.Optional[int] = None,
timeout: typing.Optional[float] = None) -> cursor.CursorFactory:
timeout: typing.Optional[float] = None) \
-> cursor.CursorFactory[_Record]:
"""Return a *cursor factory* for the prepared statement.
:param args: Query arguments.
Expand Down Expand Up @@ -161,7 +166,7 @@ async def explain(self, *args: typing.Any,
@connresource.guarded
async def fetch(self, *args: typing.Any,
timeout: typing.Optional[float] = None) \
-> typing.List['_cprotocol.Record']:
-> typing.List[_Record]:
r"""Execute the statement and return a list of :class:`Record` objects.
:param str query: Query text
Expand Down Expand Up @@ -196,7 +201,7 @@ async def fetchval(self, *args: typing.Any, column: int = 0,
@connresource.guarded
async def fetchrow(self, *args: typing.Any,
timeout: typing.Optional[float] = None) \
-> typing.Optional['_cprotocol.Record']:
-> typing.Optional[_Record]:
"""Execute the statement and return the first row.
:param str query: Query text
Expand All @@ -213,7 +218,7 @@ async def fetchrow(self, *args: typing.Any,
async def __bind_execute(self, args: typing.Tuple[typing.Any, ...],
limit: int,
timeout: typing.Optional[float]) \
-> typing.List['_cprotocol.Record']:
-> typing.List[_Record]:
protocol = self._connection._protocol
try:
data, status, _ = await protocol.bind_execute(
Expand All @@ -227,7 +232,7 @@ async def __bind_execute(self, args: typing.Tuple[typing.Any, ...],
self._state.mark_closed()
raise
self._last_status = status
return data
return data # type: ignore[return-value]

def _check_open(self, meth_name: str) -> None:
if self._state.closed:
Expand Down

0 comments on commit d449a6f

Please sign in to comment.