From 6a34de62b8dd7066c6ec96ea95aab920fa17e146 Mon Sep 17 00:00:00 2001 From: Jean-Paul Calderone Date: Sun, 9 Oct 2022 07:48:42 -0400 Subject: [PATCH 1/4] Fix the non-ASCII case --- src/twisted/trial/_dist/stream.py | 35 ++++++++++--------- src/twisted/trial/_dist/test/test_stream.py | 16 ++++----- .../trial/_dist/test/test_workerreporter.py | 11 ++++++ src/twisted/trial/_dist/worker.py | 10 +++--- src/twisted/trial/_dist/workerreporter.py | 10 +++--- src/twisted/trial/test/erroneous.py | 11 ++++++ 6 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/twisted/trial/_dist/stream.py b/src/twisted/trial/_dist/stream.py index 3a87c2ee1d8..a53fd4ab214 100644 --- a/src/twisted/trial/_dist/stream.py +++ b/src/twisted/trial/_dist/stream.py @@ -1,5 +1,5 @@ """ -Buffer string streams +Buffer byte streams. """ from itertools import count @@ -7,7 +7,7 @@ from attrs import Factory, define -from twisted.protocols.amp import AMP, Command, Integer, Unicode +from twisted.protocols.amp import AMP, Command, Integer, String as Bytes T = TypeVar("T") @@ -27,18 +27,18 @@ class StreamWrite(Command): arguments = [ (b"streamId", Integer()), - (b"data", Unicode()), + (b"data", Bytes()), ] @define class StreamReceiver: """ - Buffering de-multiplexing string stream receiver. + Buffering de-multiplexing byte stream receiver. """ _counter: Iterator[int] = count() - _streams: Dict[int, List[str]] = Factory(dict) + _streams: Dict[int, List[bytes]] = Factory(dict) def open(self) -> int: """ @@ -48,34 +48,34 @@ def open(self) -> int: self._streams[newId] = [] return newId - def write(self, streamId: int, chunk: str) -> None: + def write(self, streamId: int, chunk: bytes) -> None: """ Write to an open stream using its unique identifier. - :raise KeyError: If there is no such open stream. + @raise KeyError: If there is no such open stream. """ self._streams[streamId].append(chunk) - def finish(self, streamId: int) -> List[str]: + def finish(self, streamId: int) -> List[bytes]: """ Indicate an open stream may receive no further data and return all of its current contents. - :raise KeyError: If there is no such open stream. + @raise KeyError: If there is no such open stream. """ return self._streams.pop(streamId) -def chunk(data: str, chunkSize: int) -> Iterator[str]: +def chunk(data: bytes, chunkSize: int) -> Iterator[bytes]: """ - Break a string into pieces of no more than ``chunkSize`` length. + Break a byte string into pieces of no more than ``chunkSize`` length. - :param data: The string. + @param data: The byte string. - :param chunkSize: The maximum length of the resulting pieces. All pieces + @param chunkSize: The maximum length of the resulting pieces. All pieces except possibly the last will be this length. - :return: The pieces. + @return: The pieces. """ pos = 0 while pos < len(data): @@ -83,17 +83,18 @@ def chunk(data: str, chunkSize: int) -> Iterator[str]: pos += chunkSize -async def stream(amp: AMP, chunks: Iterator[str]) -> int: +async def stream(amp: AMP, chunks: Iterator[bytes]) -> int: """ Send the given stream chunks, one by one, over the given connection. The chunks are sent using L{StreamWrite} over a stream opened using L{StreamOpen}. - :return: The identifier of the stream over which the chunks were sent. + @return: The identifier of the stream over which the chunks were sent. """ streamId = (await amp.callRemote(StreamOpen))["streamId"] + assert isinstance(streamId, int) for oneChunk in chunks: await amp.callRemote(StreamWrite, streamId=streamId, data=oneChunk) - return streamId # type: ignore[no-any-return] + return streamId diff --git a/src/twisted/trial/_dist/test/test_stream.py b/src/twisted/trial/_dist/test/test_stream.py index 2c506e08a05..08e28b84ccd 100644 --- a/src/twisted/trial/_dist/test/test_stream.py +++ b/src/twisted/trial/_dist/test/test_stream.py @@ -16,7 +16,7 @@ raises, ) from hypothesis import given -from hypothesis.strategies import integers, just, lists, randoms, text +from hypothesis.strategies import integers, just, lists, randoms, text, binary from twisted.internet.defer import Deferred, fail from twisted.internet.interfaces import IProtocol @@ -36,11 +36,11 @@ class StreamReceiverTests(SynchronousTestCase): Tests for L{StreamReceiver} """ - @given(lists(lists(text())), randoms()) - def test_streamReceived(self, streams: List[str], random: Random) -> None: + @given(lists(lists(binary())), randoms()) + def test_streamReceived(self, streams: List[List[bytes]], random: Random) -> None: """ All data passed to L{StreamReceiver.write} is returned by a call to - L{StreamReceiver.finish} with a matching C{streamId} . + L{StreamReceiver.finish} with a matching C{streamId}. """ receiver = StreamReceiver() streamIds = [receiver.open() for _ in streams] @@ -125,7 +125,7 @@ def streamOpen(self) -> Dict[str, object]: return {"streamId": self.streams.open()} @StreamWrite.responder - def streamWrite(self, streamId: int, data: str) -> Dict[str, object]: + def streamWrite(self, streamId: int, data: bytes) -> Dict[str, object]: self.streams.write(streamId, data) return {} @@ -194,13 +194,13 @@ class StreamTests(SynchronousTestCase): Tests for L{stream}. """ - @given(lists(text())) - def test_stream(self, chunks): + @given(lists(binary())) + def test_stream(self, chunks: List[bytes]) -> None: """ All of the chunks passed to L{stream} are sent in order over a stream using the given AMP connection. """ sender = AMP() streams = StreamReceiver() - streamId = interact(AMPStreamReceiver(streams), sender, stream(sender, chunks)) + streamId = interact(AMPStreamReceiver(streams), sender, stream(sender, iter(chunks))) assert_that(streams.finish(streamId), is_(equal_to(chunks))) diff --git a/src/twisted/trial/_dist/test/test_workerreporter.py b/src/twisted/trial/_dist/test/test_workerreporter.py index 42f8480bb09..7096ee9e088 100644 --- a/src/twisted/trial/_dist/test/test_workerreporter.py +++ b/src/twisted/trial/_dist/test/test_workerreporter.py @@ -81,6 +81,17 @@ def test_addErrorGreaterThan64k(self) -> None: errors=has_length(1), ) + def test_addErrorGreaterThan64kEncoded(self) -> None: + """ + L{WorkerReporter} propagates errors with a string representation that + is smaller than an implementation-specific limit but which encode to a + byte representation that exceeds this limit. + """ + self.assertTestRun( + erroneous.TestAsynchronousFail("test_exceptionGreaterThan64kEncoded"), + errors=has_length(1), + ) + def test_addErrorTuple(self) -> None: """ L{WorkerReporter} propagates errors from pyunit's TestCases. diff --git a/src/twisted/trial/_dist/worker.py b/src/twisted/trial/_dist/worker.py index 0011169fc96..ff413bb6ca5 100644 --- a/src/twisted/trial/_dist/worker.py +++ b/src/twisted/trial/_dist/worker.py @@ -209,8 +209,8 @@ def addError( @param error: A message describing the error. """ - error = "".join(self._streams.finish(errorStreamId)) - frames = self._streams.finish(framesStreamId) + error = b"".join(self._streams.finish(errorStreamId)).decode("utf-8") + frames = [frame.decode("utf-8") for frame in self._streams.finish(framesStreamId)] # Wrap the error message in ``WorkerException`` because it is not # possible to transfer arbitrary exception values over the AMP # connection to the main process but we must give *some* Exception @@ -237,8 +237,8 @@ def addFailure( of the traceback for this error were previously completely sent to the peer. """ - fail = "".join(self._streams.finish(failStreamId)) - frames = self._streams.finish(framesStreamId) + fail = b"".join(self._streams.finish(failStreamId)).decode("utf-8") + frames = [frame.decode("utf-8") for frame in self._streams.finish(framesStreamId)] # See addError for info about use of WorkerException here. failure = self._buildFailure(WorkerException(fail), failClass, frames) self._result.addFailure(self._testCase, failure) @@ -262,7 +262,7 @@ def addExpectedFailure( @param errorStreamId: The identifier of a stream over which the text of this error was previously completely sent to the peer. """ - error = "".join(self._streams.finish(errorStreamId)) + error = b"".join(self._streams.finish(errorStreamId)).decode("utf-8") _todo = Todo("" if todo is None else todo) self._result.addExpectedFailure(self._testCase, error, _todo) return {"success": True} diff --git a/src/twisted/trial/_dist/workerreporter.py b/src/twisted/trial/_dist/workerreporter.py index 3af15100456..637ae6f6100 100644 --- a/src/twisted/trial/_dist/workerreporter.py +++ b/src/twisted/trial/_dist/workerreporter.py @@ -45,8 +45,8 @@ async def addError( :param frames: The lines of the traceback associated with the error. """ - errorStreamId = await stream(amp, chunk(error, 2 ** 16 - 1)) - framesStreamId = await stream(amp, iter(frames)) + errorStreamId = await stream(amp, chunk(error.encode("utf-8"), 2 ** 16 - 1)) + framesStreamId = await stream(amp, (frame.encode("utf-8") for frame in frames)) await amp.callRemote( managercommands.AddError, @@ -70,8 +70,8 @@ async def addFailure( :param fail: The string representation of the failure. :param frames: The lines of the traceback associated with the error. """ - failStreamId = await stream(amp, chunk(fail, 2 ** 16 - 1)) - framesStreamId = await stream(amp, iter(frames)) + failStreamId = await stream(amp, chunk(fail.encode("utf-8"), 2 ** 16 - 1)) + framesStreamId = await stream(amp, (frame.encode("utf-8") for frame in frames)) await amp.callRemote( managercommands.AddFailure, @@ -91,7 +91,7 @@ async def addExpectedFailure(amp: AMP, testName: str, error: str, todo: str) -> :param error: The string representation of the expected failure. :param todo: The string description of the expectation. """ - errorStreamId = await stream(amp, chunk(error, 2 ** 16 - 1)) + errorStreamId = await stream(amp, chunk(error.encode("utf-8"), 2 ** 16 - 1)) await amp.callRemote( managercommands.AddExpectedFailure, diff --git a/src/twisted/trial/test/erroneous.py b/src/twisted/trial/test/erroneous.py index fcf04558d1f..f93ac343b74 100644 --- a/src/twisted/trial/test/erroneous.py +++ b/src/twisted/trial/test/erroneous.py @@ -146,6 +146,17 @@ def test_exceptionGreaterThan64k(self) -> None: """ raise LargeError(2 ** 16) + def test_exceptionGreaterThan64kEncoded(self) -> None: + """ + A test which synchronously raises an exception with a long string + representation including non-ascii content. + """ + # The exception text itself is not greater than 64k but SNOWMAN + # encodes to 3 bytes with UTF-8 so the length of the UTF-8 encoding of + # the string representation of this exception will be greater than 2 + # ** 16. + raise Exception("\N{SNOWMAN}" * 2 ** 15) + class ErrorTest(unittest.SynchronousTestCase): """ From b23a961ad46f20f8308bdaa1ba46f47323e09066 Mon Sep 17 00:00:00 2001 From: Jean-Paul Calderone Date: Sun, 9 Oct 2022 07:48:54 -0400 Subject: [PATCH 2/4] news fragment --- src/twisted/trial/newsfragments/11710.misc | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/twisted/trial/newsfragments/11710.misc diff --git a/src/twisted/trial/newsfragments/11710.misc b/src/twisted/trial/newsfragments/11710.misc new file mode 100644 index 00000000000..e69de29bb2d From 2a0b4b19fb63b3748fb32fe2c71d507eb401a459 Mon Sep 17 00:00:00 2001 From: Jean-Paul Calderone Date: Sun, 9 Oct 2022 07:49:18 -0400 Subject: [PATCH 3/4] formatting --- src/twisted/trial/_dist/test/test_stream.py | 6 ++++-- src/twisted/trial/_dist/worker.py | 8 ++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/twisted/trial/_dist/test/test_stream.py b/src/twisted/trial/_dist/test/test_stream.py index 08e28b84ccd..4441ada3bcc 100644 --- a/src/twisted/trial/_dist/test/test_stream.py +++ b/src/twisted/trial/_dist/test/test_stream.py @@ -16,7 +16,7 @@ raises, ) from hypothesis import given -from hypothesis.strategies import integers, just, lists, randoms, text, binary +from hypothesis.strategies import binary, integers, just, lists, randoms, text from twisted.internet.defer import Deferred, fail from twisted.internet.interfaces import IProtocol @@ -202,5 +202,7 @@ def test_stream(self, chunks: List[bytes]) -> None: """ sender = AMP() streams = StreamReceiver() - streamId = interact(AMPStreamReceiver(streams), sender, stream(sender, iter(chunks))) + streamId = interact( + AMPStreamReceiver(streams), sender, stream(sender, iter(chunks)) + ) assert_that(streams.finish(streamId), is_(equal_to(chunks))) diff --git a/src/twisted/trial/_dist/worker.py b/src/twisted/trial/_dist/worker.py index ff413bb6ca5..04c39fc8551 100644 --- a/src/twisted/trial/_dist/worker.py +++ b/src/twisted/trial/_dist/worker.py @@ -210,7 +210,9 @@ def addError( @param error: A message describing the error. """ error = b"".join(self._streams.finish(errorStreamId)).decode("utf-8") - frames = [frame.decode("utf-8") for frame in self._streams.finish(framesStreamId)] + frames = [ + frame.decode("utf-8") for frame in self._streams.finish(framesStreamId) + ] # Wrap the error message in ``WorkerException`` because it is not # possible to transfer arbitrary exception values over the AMP # connection to the main process but we must give *some* Exception @@ -238,7 +240,9 @@ def addFailure( peer. """ fail = b"".join(self._streams.finish(failStreamId)).decode("utf-8") - frames = [frame.decode("utf-8") for frame in self._streams.finish(framesStreamId)] + frames = [ + frame.decode("utf-8") for frame in self._streams.finish(framesStreamId) + ] # See addError for info about use of WorkerException here. failure = self._buildFailure(WorkerException(fail), failClass, frames) self._result.addFailure(self._testCase, failure) From ce7252ffea7295cc081d0fa3b17b9c5e31be9bb6 Mon Sep 17 00:00:00 2001 From: Jean-Paul Calderone Date: Tue, 11 Oct 2022 16:43:54 -0400 Subject: [PATCH 4/4] Use the AMP-defined limit instead of inventing our own. --- src/twisted/trial/_dist/workerreporter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/twisted/trial/_dist/workerreporter.py b/src/twisted/trial/_dist/workerreporter.py index 637ae6f6100..436f6f496a3 100644 --- a/src/twisted/trial/_dist/workerreporter.py +++ b/src/twisted/trial/_dist/workerreporter.py @@ -17,7 +17,7 @@ from typing_extensions import Literal from twisted.internet.defer import Deferred, maybeDeferred -from twisted.protocols.amp import AMP +from twisted.protocols.amp import AMP, MAX_VALUE_LENGTH from twisted.python.failure import Failure from twisted.python.reflect import qual from twisted.trial._dist import managercommands @@ -45,7 +45,7 @@ async def addError( :param frames: The lines of the traceback associated with the error. """ - errorStreamId = await stream(amp, chunk(error.encode("utf-8"), 2 ** 16 - 1)) + errorStreamId = await stream(amp, chunk(error.encode("utf-8"), MAX_VALUE_LENGTH)) framesStreamId = await stream(amp, (frame.encode("utf-8") for frame in frames)) await amp.callRemote( @@ -70,7 +70,7 @@ async def addFailure( :param fail: The string representation of the failure. :param frames: The lines of the traceback associated with the error. """ - failStreamId = await stream(amp, chunk(fail.encode("utf-8"), 2 ** 16 - 1)) + failStreamId = await stream(amp, chunk(fail.encode("utf-8"), MAX_VALUE_LENGTH)) framesStreamId = await stream(amp, (frame.encode("utf-8") for frame in frames)) await amp.callRemote( @@ -91,7 +91,7 @@ async def addExpectedFailure(amp: AMP, testName: str, error: str, todo: str) -> :param error: The string representation of the expected failure. :param todo: The string description of the expectation. """ - errorStreamId = await stream(amp, chunk(error.encode("utf-8"), 2 ** 16 - 1)) + errorStreamId = await stream(amp, chunk(error.encode("utf-8"), MAX_VALUE_LENGTH)) await amp.callRemote( managercommands.AddExpectedFailure,