Skip to content

Commit

Permalink
#11661 Add an AMP stream helper for disttrial (#11662)
Browse files Browse the repository at this point in the history
  • Loading branch information
exarkun committed Sep 13, 2022
2 parents 7c00738 + 18aeac0 commit 398b603
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 0 deletions.
99 changes: 99 additions & 0 deletions src/twisted/trial/_dist/stream.py
@@ -0,0 +1,99 @@
"""
Buffer string streams
"""

from itertools import count
from typing import Dict, Iterator, List, TypeVar

from attrs import Factory, define

from twisted.protocols.amp import AMP, Command, Integer, Unicode

T = TypeVar("T")


class StreamOpen(Command):
"""
Open a new stream.
"""

response = [(b"streamId", Integer())]


class StreamWrite(Command):
"""
Write a chunk of data to a stream.
"""

arguments = [
(b"streamId", Integer()),
(b"data", Unicode()),
]


@define
class StreamReceiver:
"""
Buffering de-multiplexing string stream receiver.
"""

_counter: Iterator[int] = count()
_streams: Dict[int, List[str]] = Factory(dict)

def open(self) -> int:
"""
Open a new stream and return its unique identifier.
"""
newId = next(self._counter)
self._streams[newId] = []
return newId

def write(self, streamId: int, chunk: str) -> None:
"""
Write to an open stream using its unique identifier.
:raise KeyError: If there is no such open stream.
"""
self._streams[streamId].append(chunk)

def finish(self, streamId: int) -> List[str]:
"""
Indicate an open stream may receive no further data and return all of
its current contents.
:raise KeyError: If there is no such open stream.
"""
return self._streams.pop(streamId)


def chunk(data: str, chunkSize: int) -> Iterator[str]:
"""
Break a string into pieces of no more than ``chunkSize`` length.
:param data: The string.
:param chunkSize: The maximum length of the resulting pieces. All pieces
except possibly the last will be this length.
:return: The pieces.
"""
pos = 0
while pos < len(data):
yield data[pos : pos + chunkSize]
pos += chunkSize


async def stream(amp: AMP, chunks: Iterator[str]) -> int:
"""
Send the given stream chunks, one by one, over the given connection.
The chunks are sent using L{StreamWrite} over a stream opened using
L{StreamOpen}.
:return: The identifier of the stream over which the chunks were sent.
"""
streamId = (await amp.callRemote(StreamOpen))["streamId"]

for oneChunk in chunks:
await amp.callRemote(StreamWrite, streamId=streamId, data=oneChunk)
return streamId # type: ignore[no-any-return]
206 changes: 206 additions & 0 deletions src/twisted/trial/_dist/test/test_stream.py
@@ -0,0 +1,206 @@
"""
Tests for L{twisted.trial._dist.stream}.
"""

from random import Random
from typing import Awaitable, Dict, List, TypeVar, Union

from hamcrest import (
all_of,
assert_that,
calling,
equal_to,
has_length,
is_,
less_than_or_equal_to,
raises,
)
from hypothesis import given
from hypothesis.strategies import integers, just, lists, randoms, text

from twisted.internet.defer import Deferred, fail
from twisted.internet.interfaces import IProtocol
from twisted.internet.protocol import Protocol
from twisted.protocols.amp import AMP
from twisted.python.failure import Failure
from twisted.test.iosim import FakeTransport, connect
from twisted.trial.unittest import SynchronousTestCase
from ..stream import StreamOpen, StreamReceiver, StreamWrite, chunk, stream
from .matchers import HasSum, IsSequenceOf

T = TypeVar("T")


class StreamReceiverTests(SynchronousTestCase):
"""
Tests for L{StreamReceiver}
"""

@given(lists(lists(text())), randoms())
def test_streamReceived(self, streams: List[str], random: Random) -> None:
"""
All data passed to L{StreamReceiver.write} is returned by a call to
L{StreamReceiver.finish} with a matching C{streamId} .
"""
receiver = StreamReceiver()
streamIds = [receiver.open() for _ in streams]

# uncorrelate the results with open() order
random.shuffle(streamIds)

expectedData = dict(zip(streamIds, streams))
for streamId, strings in expectedData.items():
for s in strings:
receiver.write(streamId, s)

# uncorrelate the results with write() order
random.shuffle(streamIds)

actualData = {streamId: receiver.finish(streamId) for streamId in streamIds}

assert_that(actualData, is_(equal_to(expectedData)))

@given(integers(), just("data"))
def test_writeBadStreamId(self, streamId: int, data: str) -> None:
"""
L{StreamReceiver.write} raises L{KeyError} if called with a
streamId not associated with an open stream.
"""
receiver = StreamReceiver()
assert_that(calling(receiver.write).with_args(streamId, data), raises(KeyError))

@given(integers())
def test_badFinishStreamId(self, streamId: int) -> None:
"""
L{StreamReceiver.finish} raises L{KeyError} if called with a
streamId not associated with an open stream.
"""
receiver = StreamReceiver()
assert_that(calling(receiver.finish).with_args(streamId), raises(KeyError))

def test_finishRemovesStream(self) -> None:
"""
L{StreamReceiver.finish} removes the identified stream.
"""
receiver = StreamReceiver()
streamId = receiver.open()
receiver.finish(streamId)
assert_that(calling(receiver.finish).with_args(streamId), raises(KeyError))


class ChunkTests(SynchronousTestCase):
"""
Tests for ``chunk``.
"""

@given(data=text(), chunkSize=integers(min_value=1))
def test_chunk(self, data, chunkSize):
"""
L{chunk} returns an iterable of L{str} where each element is no
longer than the given limit. The concatenation of the strings is also
equal to the original input string.
"""
chunks = list(chunk(data, chunkSize))
assert_that(
chunks,
all_of(
IsSequenceOf(
has_length(less_than_or_equal_to(chunkSize)),
),
HasSum(equal_to(data), data[:0]),
),
)


class AMPStreamReceiver(AMP):
"""
A simple AMP interface to L{StreamReceiver}.
"""

def __init__(self, streams: StreamReceiver) -> None:
self.streams = streams

@StreamOpen.responder
def streamOpen(self) -> Dict[str, object]:
return {"streamId": self.streams.open()}

@StreamWrite.responder
def streamWrite(self, streamId: int, data: str) -> Dict[str, object]:
self.streams.write(streamId, data)
return {}


def interact(server: IProtocol, client: IProtocol, interaction: Awaitable[T]) -> T:
"""
Let C{server} and C{client} exchange bytes while C{interaction} runs.
"""
finished = False
result: Union[Failure, T]

async def to_coroutine() -> T:
return await interaction

def collect_result(r: Union[Failure, T]) -> None:
nonlocal result, finished
finished = True
result = r

pump = connect(
server,
FakeTransport(server, isServer=True),
client,
FakeTransport(client, isServer=False),
)
interacting = Deferred.fromCoroutine(to_coroutine())
interacting.addBoth(collect_result)

pump.flush()

if finished:
if isinstance(result, Failure):
result.raiseException()
return result
raise Exception("Interaction failed to produce a result.")


class InteractTests(SynchronousTestCase):
"""
Tests for the test helper L{interact}.
"""

def test_failure(self):
"""
If the interaction results in a failure then L{interact} raises an
exception.
"""

class ArbitraryException(Exception):
pass

with self.assertRaises(ArbitraryException):
interact(Protocol(), Protocol(), fail(ArbitraryException()))

def test_incomplete(self):
"""
If the interaction fails to produce a result then L{interact} raises
an exception.
"""
with self.assertRaises(Exception):
interact(Protocol(), Protocol(), Deferred())


class StreamTests(SynchronousTestCase):
"""
Tests for L{stream}.
"""

@given(lists(text()))
def test_stream(self, chunks):
"""
All of the chunks passed to L{stream} are sent in order over a
stream using the given AMP connection.
"""
sender = AMP()
streams = StreamReceiver()
streamId = interact(AMPStreamReceiver(streams), sender, stream(sender, chunks))
assert_that(streams.finish(streamId), is_(equal_to(chunks)))
Empty file.

0 comments on commit 398b603

Please sign in to comment.