Skip to content

Commit

Permalink
#11710 Fix more AMP errors in disttrial (#11711)
Browse files Browse the repository at this point in the history
Fix more cases where disttrial has problems sending test results
back to the manager process.  Specifically, chunk strings containing
non-ASCII characters into sizes legal for transport by AMP.
  • Loading branch information
exarkun committed Oct 11, 2022
2 parents ca4ca07 + ce7252f commit 915c248
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 36 deletions.
35 changes: 18 additions & 17 deletions src/twisted/trial/_dist/stream.py
@@ -1,13 +1,13 @@
"""
Buffer string streams
Buffer byte streams.
"""

from itertools import count
from typing import Dict, Iterator, List, TypeVar

from attrs import Factory, define

from twisted.protocols.amp import AMP, Command, Integer, Unicode
from twisted.protocols.amp import AMP, Command, Integer, String as Bytes

T = TypeVar("T")

Expand All @@ -27,18 +27,18 @@ class StreamWrite(Command):

arguments = [
(b"streamId", Integer()),
(b"data", Unicode()),
(b"data", Bytes()),
]


@define
class StreamReceiver:
"""
Buffering de-multiplexing string stream receiver.
Buffering de-multiplexing byte stream receiver.
"""

_counter: Iterator[int] = count()
_streams: Dict[int, List[str]] = Factory(dict)
_streams: Dict[int, List[bytes]] = Factory(dict)

def open(self) -> int:
"""
Expand All @@ -48,52 +48,53 @@ def open(self) -> int:
self._streams[newId] = []
return newId

def write(self, streamId: int, chunk: str) -> None:
def write(self, streamId: int, chunk: bytes) -> None:
"""
Write to an open stream using its unique identifier.
:raise KeyError: If there is no such open stream.
@raise KeyError: If there is no such open stream.
"""
self._streams[streamId].append(chunk)

def finish(self, streamId: int) -> List[str]:
def finish(self, streamId: int) -> List[bytes]:
"""
Indicate an open stream may receive no further data and return all of
its current contents.
:raise KeyError: If there is no such open stream.
@raise KeyError: If there is no such open stream.
"""
return self._streams.pop(streamId)


def chunk(data: str, chunkSize: int) -> Iterator[str]:
def chunk(data: bytes, chunkSize: int) -> Iterator[bytes]:
"""
Break a string into pieces of no more than ``chunkSize`` length.
Break a byte string into pieces of no more than ``chunkSize`` length.
:param data: The string.
@param data: The byte string.
:param chunkSize: The maximum length of the resulting pieces. All pieces
@param chunkSize: The maximum length of the resulting pieces. All pieces
except possibly the last will be this length.
:return: The pieces.
@return: The pieces.
"""
pos = 0
while pos < len(data):
yield data[pos : pos + chunkSize]
pos += chunkSize


async def stream(amp: AMP, chunks: Iterator[str]) -> int:
async def stream(amp: AMP, chunks: Iterator[bytes]) -> int:
"""
Send the given stream chunks, one by one, over the given connection.
The chunks are sent using L{StreamWrite} over a stream opened using
L{StreamOpen}.
:return: The identifier of the stream over which the chunks were sent.
@return: The identifier of the stream over which the chunks were sent.
"""
streamId = (await amp.callRemote(StreamOpen))["streamId"]
assert isinstance(streamId, int)

for oneChunk in chunks:
await amp.callRemote(StreamWrite, streamId=streamId, data=oneChunk)
return streamId # type: ignore[no-any-return]
return streamId
18 changes: 10 additions & 8 deletions src/twisted/trial/_dist/test/test_stream.py
Expand Up @@ -16,7 +16,7 @@
raises,
)
from hypothesis import given
from hypothesis.strategies import integers, just, lists, randoms, text
from hypothesis.strategies import binary, integers, just, lists, randoms, text

from twisted.internet.defer import Deferred, fail
from twisted.internet.interfaces import IProtocol
Expand All @@ -36,11 +36,11 @@ class StreamReceiverTests(SynchronousTestCase):
Tests for L{StreamReceiver}
"""

@given(lists(lists(text())), randoms())
def test_streamReceived(self, streams: List[str], random: Random) -> None:
@given(lists(lists(binary())), randoms())
def test_streamReceived(self, streams: List[List[bytes]], random: Random) -> None:
"""
All data passed to L{StreamReceiver.write} is returned by a call to
L{StreamReceiver.finish} with a matching C{streamId} .
L{StreamReceiver.finish} with a matching C{streamId}.
"""
receiver = StreamReceiver()
streamIds = [receiver.open() for _ in streams]
Expand Down Expand Up @@ -125,7 +125,7 @@ def streamOpen(self) -> Dict[str, object]:
return {"streamId": self.streams.open()}

@StreamWrite.responder
def streamWrite(self, streamId: int, data: str) -> Dict[str, object]:
def streamWrite(self, streamId: int, data: bytes) -> Dict[str, object]:
self.streams.write(streamId, data)
return {}

Expand Down Expand Up @@ -194,13 +194,15 @@ class StreamTests(SynchronousTestCase):
Tests for L{stream}.
"""

@given(lists(text()))
def test_stream(self, chunks):
@given(lists(binary()))
def test_stream(self, chunks: List[bytes]) -> None:
"""
All of the chunks passed to L{stream} are sent in order over a
stream using the given AMP connection.
"""
sender = AMP()
streams = StreamReceiver()
streamId = interact(AMPStreamReceiver(streams), sender, stream(sender, chunks))
streamId = interact(
AMPStreamReceiver(streams), sender, stream(sender, iter(chunks))
)
assert_that(streams.finish(streamId), is_(equal_to(chunks)))
11 changes: 11 additions & 0 deletions src/twisted/trial/_dist/test/test_workerreporter.py
Expand Up @@ -81,6 +81,17 @@ def test_addErrorGreaterThan64k(self) -> None:
errors=has_length(1),
)

def test_addErrorGreaterThan64kEncoded(self) -> None:
"""
L{WorkerReporter} propagates errors with a string representation that
is smaller than an implementation-specific limit but which encode to a
byte representation that exceeds this limit.
"""
self.assertTestRun(
erroneous.TestAsynchronousFail("test_exceptionGreaterThan64kEncoded"),
errors=has_length(1),
)

def test_addErrorTuple(self) -> None:
"""
L{WorkerReporter} propagates errors from pyunit's TestCases.
Expand Down
14 changes: 9 additions & 5 deletions src/twisted/trial/_dist/worker.py
Expand Up @@ -209,8 +209,10 @@ def addError(
@param error: A message describing the error.
"""
error = "".join(self._streams.finish(errorStreamId))
frames = self._streams.finish(framesStreamId)
error = b"".join(self._streams.finish(errorStreamId)).decode("utf-8")
frames = [
frame.decode("utf-8") for frame in self._streams.finish(framesStreamId)
]
# Wrap the error message in ``WorkerException`` because it is not
# possible to transfer arbitrary exception values over the AMP
# connection to the main process but we must give *some* Exception
Expand All @@ -237,8 +239,10 @@ def addFailure(
of the traceback for this error were previously completely sent to the
peer.
"""
fail = "".join(self._streams.finish(failStreamId))
frames = self._streams.finish(framesStreamId)
fail = b"".join(self._streams.finish(failStreamId)).decode("utf-8")
frames = [
frame.decode("utf-8") for frame in self._streams.finish(framesStreamId)
]
# See addError for info about use of WorkerException here.
failure = self._buildFailure(WorkerException(fail), failClass, frames)
self._result.addFailure(self._testCase, failure)
Expand All @@ -262,7 +266,7 @@ def addExpectedFailure(
@param errorStreamId: The identifier of a stream over which the text
of this error was previously completely sent to the peer.
"""
error = "".join(self._streams.finish(errorStreamId))
error = b"".join(self._streams.finish(errorStreamId)).decode("utf-8")
_todo = Todo("<unknown>" if todo is None else todo)
self._result.addExpectedFailure(self._testCase, error, _todo)
return {"success": True}
Expand Down
12 changes: 6 additions & 6 deletions src/twisted/trial/_dist/workerreporter.py
Expand Up @@ -17,7 +17,7 @@
from typing_extensions import Literal

from twisted.internet.defer import Deferred, maybeDeferred
from twisted.protocols.amp import AMP
from twisted.protocols.amp import AMP, MAX_VALUE_LENGTH
from twisted.python.failure import Failure
from twisted.python.reflect import qual
from twisted.trial._dist import managercommands
Expand Down Expand Up @@ -45,8 +45,8 @@ async def addError(
:param frames: The lines of the traceback associated with the error.
"""

errorStreamId = await stream(amp, chunk(error, 2 ** 16 - 1))
framesStreamId = await stream(amp, iter(frames))
errorStreamId = await stream(amp, chunk(error.encode("utf-8"), MAX_VALUE_LENGTH))
framesStreamId = await stream(amp, (frame.encode("utf-8") for frame in frames))

await amp.callRemote(
managercommands.AddError,
Expand All @@ -70,8 +70,8 @@ async def addFailure(
:param fail: The string representation of the failure.
:param frames: The lines of the traceback associated with the error.
"""
failStreamId = await stream(amp, chunk(fail, 2 ** 16 - 1))
framesStreamId = await stream(amp, iter(frames))
failStreamId = await stream(amp, chunk(fail.encode("utf-8"), MAX_VALUE_LENGTH))
framesStreamId = await stream(amp, (frame.encode("utf-8") for frame in frames))

await amp.callRemote(
managercommands.AddFailure,
Expand All @@ -91,7 +91,7 @@ async def addExpectedFailure(amp: AMP, testName: str, error: str, todo: str) ->
:param error: The string representation of the expected failure.
:param todo: The string description of the expectation.
"""
errorStreamId = await stream(amp, chunk(error, 2 ** 16 - 1))
errorStreamId = await stream(amp, chunk(error.encode("utf-8"), MAX_VALUE_LENGTH))

await amp.callRemote(
managercommands.AddExpectedFailure,
Expand Down
Empty file.
11 changes: 11 additions & 0 deletions src/twisted/trial/test/erroneous.py
Expand Up @@ -146,6 +146,17 @@ def test_exceptionGreaterThan64k(self) -> None:
"""
raise LargeError(2 ** 16)

def test_exceptionGreaterThan64kEncoded(self) -> None:
"""
A test which synchronously raises an exception with a long string
representation including non-ascii content.
"""
# The exception text itself is not greater than 64k but SNOWMAN
# encodes to 3 bytes with UTF-8 so the length of the UTF-8 encoding of
# the string representation of this exception will be greater than 2
# ** 16.
raise Exception("\N{SNOWMAN}" * 2 ** 15)


class ErrorTest(unittest.SynchronousTestCase):
"""
Expand Down

0 comments on commit 915c248

Please sign in to comment.