Skip to content

Commit

Permalink
Tweaks for pyright and mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanforbes committed Jul 26, 2022
1 parent 292290d commit 96c0a4c
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 341 deletions.
7 changes: 3 additions & 4 deletions asyncpg/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from asyncpg import exceptions

if typing.TYPE_CHECKING:
import _typeshed
from . import connection
from . import types

Expand Down Expand Up @@ -652,8 +653,7 @@ class TempCluster(Cluster):
def __init__(self, *,
data_dir_suffix: typing.Optional[str] = None,
data_dir_prefix: typing.Optional[str] = None,
data_dir_parent: typing.Optional[
'tempfile._DirT[str]'] = None,
data_dir_parent: typing.Optional['_typeshed.StrPath'] = None,
pg_config_path: typing.Optional[str] = None) -> None:
self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
prefix=data_dir_prefix,
Expand All @@ -669,8 +669,7 @@ def __init__(self, *,
master: _ConnectionSpec, replication_user: str,
data_dir_suffix: typing.Optional[str] = None,
data_dir_prefix: typing.Optional[str] = None,
data_dir_parent: typing.Optional[
'tempfile._DirT[str]'] = None,
data_dir_parent: typing.Optional['_typeshed.StrPath'] = None,
pg_config_path: typing.Optional[str] = None) -> None:
self._master = master
self._repl_user = replication_user
Expand Down
16 changes: 12 additions & 4 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import enum
import functools
import getpass
import inspect
import os
import pathlib
import platform
Expand Down Expand Up @@ -268,7 +269,9 @@ def _parse_hostlist(hostlist: str,
hostspec_port = urllib.parse.unquote(hostspec_port)
hostlist_ports.append(int(hostspec_port))
else:
hostlist_ports.append(default_port[i])
hostlist_ports.append(
default_port[i] # pyright: ignore [reportUnknownArgumentType, reportUnboundVariable] # noqa: E501
)

if not ports:
ports = hostlist_ports
Expand Down Expand Up @@ -892,7 +895,7 @@ async def _connect_addr(
if inspect.isawaitable(password):
password = await password

params = params._replace(password=password)
params = params._replace(password=typing.cast(str, password))
args = (addr, loop, config, connection_class, record_class, params_input)

# prepare the params (which attempt has ssl) for the 2 attempts
Expand Down Expand Up @@ -954,8 +957,13 @@ async def __connect_addr(
elif params.ssl and params.direct_tls:
# if ssl and direct_tls are given, skip STARTTLS and perform direct
# SSL connection
connector = loop.create_connection(
proto_factory, *addr, ssl=params.ssl
connector = typing.cast(
typing.Coroutine[
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
],
loop.create_connection(
proto_factory, *addr, ssl=params.ssl
)
)

elif params.ssl:
Expand Down

0 comments on commit 96c0a4c

Please sign in to comment.