Skip to content

Commit

Permalink
lint: fix type hints with disabled bytearray/memoryview/bytes equival…
Browse files Browse the repository at this point in the history
…ience

This changeset makes the code compatible with the current Mypy 0.981,
but passes most of the checks that Mypy 0.990 enforces if the byte
strings equivalence is disabled.
  • Loading branch information
dvarrazzo committed Nov 4, 2022
1 parent b290b2e commit 5acb614
Show file tree
Hide file tree
Showing 18 changed files with 89 additions and 78 deletions.
4 changes: 2 additions & 2 deletions psycopg/psycopg/_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from ._compat import Protocol, TypeAlias

PackInt: TypeAlias = Callable[[int], bytes]
UnpackInt: TypeAlias = Callable[[bytes], Tuple[int]]
UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
PackFloat: TypeAlias = Callable[[float], bytes]
UnpackFloat: TypeAlias = Callable[[bytes], Tuple[float]]
UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]


class UnpackLen(Protocol):
Expand Down
6 changes: 4 additions & 2 deletions psycopg/psycopg/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def dump_sequence(

return out

def as_literal(self, obj: Any) -> Buffer:
def as_literal(self, obj: Any) -> bytes:
dumper = self.get_dumper(obj, PY_TEXT)
rv = dumper.quote(obj)
# If the result is quoted, and the oid not unknown or text,
Expand All @@ -221,6 +221,8 @@ def as_literal(self, obj: Any) -> Buffer:
if type_sql:
rv = b"%s::%s" % (rv, type_sql)

if not isinstance(rv, bytes):
rv = bytes(rv)
return rv

def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
Expand Down Expand Up @@ -321,7 +323,7 @@ def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:

return make_row(record)

def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
if len(self._row_loaders) != len(record):
raise e.ProgrammingError(
f"cannot load sequence of {len(record)} items:"
Expand Down
8 changes: 4 additions & 4 deletions psycopg/psycopg/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@

# Adaptation types

DumpFunc: TypeAlias = Callable[[Any], bytes]
LoadFunc: TypeAlias = Callable[[bytes], Any]
DumpFunc: TypeAlias = Callable[[Any], Buffer]
LoadFunc: TypeAlias = Callable[[Buffer], Any]


class AdaptContext(Protocol):
Expand Down Expand Up @@ -238,7 +238,7 @@ def dump_sequence(
) -> Sequence[Optional[Buffer]]:
...

def as_literal(self, obj: Any) -> Buffer:
def as_literal(self, obj: Any) -> bytes:
...

def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
Expand All @@ -250,7 +250,7 @@ def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Ro
def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
...

def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
...

def get_loader(self, oid: int, format: pq.Format) -> Loader:
Expand Down
34 changes: 17 additions & 17 deletions psycopg/psycopg/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ class QueuedLibpqDriver(LibpqWriter):
def __init__(self, cursor: "Cursor[Any]"):
super().__init__(cursor)

self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[threading.Thread] = None
self._worker_error: Optional[BaseException] = None

Expand Down Expand Up @@ -599,7 +599,7 @@ class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
def __init__(self, cursor: "AsyncCursor[Any]"):
super().__init__(cursor)

self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[asyncio.Future[None]] = None

async def worker(self) -> None:
Expand Down Expand Up @@ -652,19 +652,19 @@ def __init__(self, transformer: Transformer):
self._row_mode = False # true if the user is using write_row()

@abstractmethod
def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
...

@abstractmethod
def write(self, buffer: Union[Buffer, str]) -> bytes:
def write(self, buffer: Union[Buffer, str]) -> Buffer:
...

@abstractmethod
def write_row(self, row: Sequence[Any]) -> bytes:
def write_row(self, row: Sequence[Any]) -> Buffer:
...

@abstractmethod
def end(self) -> bytes:
def end(self) -> Buffer:
...


Expand All @@ -676,7 +676,7 @@ def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
super().__init__(transformer)
self._encoding = encoding

def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
if data:
return parse_row_text(data, self.transformer)
else:
Expand All @@ -687,7 +687,7 @@ def write(self, buffer: Union[Buffer, str]) -> Buffer:
self._signature_sent = True
return data

def write_row(self, row: Sequence[Any]) -> bytes:
def write_row(self, row: Sequence[Any]) -> Buffer:
# Note down that we are writing in row mode: it means we will have
# to take care of the end-of-copy marker too
self._row_mode = True
Expand All @@ -699,7 +699,7 @@ def write_row(self, row: Sequence[Any]) -> bytes:
else:
return b""

def end(self) -> bytes:
def end(self) -> Buffer:
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer

Expand All @@ -721,7 +721,7 @@ def __init__(self, transformer: Transformer):
super().__init__(transformer)
self._signature_sent = False

def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
if not self._signature_sent:
if data[: len(_binary_signature)] != _binary_signature:
raise e.DataError(
Expand All @@ -740,7 +740,7 @@ def write(self, buffer: Union[Buffer, str]) -> Buffer:
self._signature_sent = True
return data

def write_row(self, row: Sequence[Any]) -> bytes:
def write_row(self, row: Sequence[Any]) -> Buffer:
# Note down that we are writing in row mode: it means we will have
# to take care of the end-of-copy marker too
self._row_mode = True
Expand All @@ -756,7 +756,7 @@ def write_row(self, row: Sequence[Any]) -> bytes:
else:
return b""

def end(self) -> bytes:
def end(self) -> Buffer:
# If we have sent no data we need to send the signature
# and the trailer
if not self._signature_sent:
Expand Down Expand Up @@ -828,17 +828,17 @@ def _format_row_binary(
return out


def _parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
if not isinstance(data, bytes):
data = bytes(data)
fields = data.split(b"\t")
fields[-1] = fields[-1][:-1] # drop \n
fields = data.split(b"\t") # type: ignore
fields[-1] = fields[-1][:-1] # type: ignore # drop \n
row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
return tx.load_sequence(row)


def _parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
row: List[Optional[bytes]] = []
def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
row: List[Optional[Buffer]] = []
nfields = _unpack_int2(data, 0)[0]
pos = 2
for i in range(nfields):
Expand Down
4 changes: 2 additions & 2 deletions psycopg/psycopg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from . import pq
from . import errors as e
from .abc import PipelineCommand, PQGen, PQGenConn
from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
from .pq.abc import PGconn, PGresult
from .waiting import Wait, Ready
from ._compat import Deque
Expand Down Expand Up @@ -271,7 +271,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
return result


def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]:
def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
# Retry enqueuing data until successful.
#
# WARNING! This can cause an infinite loop if the buffer is too large. (see
Expand Down
12 changes: 6 additions & 6 deletions psycopg/psycopg/pq/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def send_query(self, command: bytes) -> None:
def exec_params(
self,
command: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional[Buffer]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
Expand All @@ -143,7 +143,7 @@ def exec_params(
def send_query_params(
self,
command: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional[Buffer]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
Expand All @@ -161,7 +161,7 @@ def send_prepare(
def send_query_prepared(
self,
name: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional[Buffer]]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
) -> None:
Expand All @@ -178,7 +178,7 @@ def prepare(
def exec_prepared(
self,
name: bytes,
param_values: Optional[Sequence[bytes]],
param_values: Optional[Sequence[Buffer]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = 0,
) -> "PGresult":
Expand Down Expand Up @@ -225,7 +225,7 @@ def get_cancel(self) -> "PGcancel":
def notifies(self) -> Optional["PGnotify"]:
...

def put_copy_data(self, buffer: bytes) -> int:
def put_copy_data(self, buffer: Buffer) -> int:
...

def put_copy_end(self, error: Optional[bytes] = None) -> int:
Expand Down Expand Up @@ -380,5 +380,5 @@ def escape_string(self, data: Buffer) -> bytes:
def escape_bytea(self, data: Buffer) -> bytes:
...

def unescape_bytea(self, data: bytes) -> bytes:
def unescape_bytea(self, data: Buffer) -> bytes:
...
23 changes: 15 additions & 8 deletions psycopg/psycopg/pq/pq_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def send_query(self, command: bytes) -> None:
def exec_params(
self,
command: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional["abc.Buffer"]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
Expand All @@ -292,7 +292,7 @@ def exec_params(
def send_query_params(
self,
command: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional["abc.Buffer"]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
Expand Down Expand Up @@ -329,7 +329,7 @@ def send_prepare(
def send_query_prepared(
self,
name: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional["abc.Buffer"]]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
) -> None:
Expand All @@ -349,7 +349,7 @@ def send_query_prepared(
def _query_params_args(
self,
command: bytes,
param_values: Optional[Sequence[Optional[bytes]]],
param_values: Optional[Sequence[Optional["abc.Buffer"]]],
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
Expand All @@ -364,7 +364,6 @@ def _query_params_args(
aparams = (c_char_p * nparams)(
*(
# convert bytearray/memoryview to bytes
# TODO: avoid copy, at least in the C implementation.
b
if b is None or isinstance(b, bytes)
else bytes(b) # type: ignore[arg-type]
Expand Down Expand Up @@ -436,7 +435,7 @@ def prepare(
def exec_prepared(
self,
name: bytes,
param_values: Optional[Sequence[bytes]],
param_values: Optional[Sequence["abc.Buffer"]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = 0,
) -> "PGresult":
Expand All @@ -447,7 +446,13 @@ def exec_prepared(
alenghts: Optional[Array[c_int]]
if param_values:
nparams = len(param_values)
aparams = (c_char_p * nparams)(*param_values)
aparams = (c_char_p * nparams)(
*(
# convert bytearray/memoryview to bytes
b if b is None or isinstance(b, bytes) else bytes(b)
for b in param_values
)
)
alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
else:
nparams = 0
Expand Down Expand Up @@ -1050,13 +1055,15 @@ def escape_bytea(self, data: "abc.Buffer") -> bytes:
impl.PQfreemem(out)
return rv

def unescape_bytea(self, data: bytes) -> bytes:
def unescape_bytea(self, data: "abc.Buffer") -> bytes:
# not needed, but let's keep it symmetric with the escaping:
# if a connection is passed in, it must be valid.
if self.conn:
self.conn._ensure_pgconn()

len_out = c_size_t()
if not isinstance(data, bytes):
data = bytes(data)
out = impl.PQunescapeBytea(
data,
byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
Expand Down
8 changes: 4 additions & 4 deletions psycopg/psycopg/types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
_unpack_head = cast(Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from)
_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
_struct_dim = struct.Struct("!II") # dim, lower bound
_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
_unpack_dim = cast(Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from)
_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)

TEXT_ARRAY_OID = postgres.types["text"].array_oid

Expand Down Expand Up @@ -153,7 +153,7 @@ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
_re_esc = re.compile(rb'(["\\])')

def dump(self, obj: List[Any]) -> bytes:
tokens: List[bytes] = []
tokens: List[Buffer] = []
needs_quotes = _get_needs_quotes_regexp(self.delimiter).search

def dump_list(obj: List[Any]) -> None:
Expand Down Expand Up @@ -249,7 +249,7 @@ def dump(self, obj: List[Any]) -> bytes:
if not obj:
return _pack_head(0, 0, sub_oid)

data: List[bytes] = [b"", b""] # placeholders to avoid a resize
data: List[Buffer] = [b"", b""] # placeholders to avoid a resize
dims: List[int] = []
hasnull = 0

Expand Down

0 comments on commit 5acb614

Please sign in to comment.