Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
3 changed files
with
305 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
""" | ||
Buffer string streams | ||
""" | ||
|
||
from itertools import count | ||
from typing import Dict, Iterator, List, TypeVar | ||
|
||
from attrs import Factory, define | ||
|
||
from twisted.protocols.amp import AMP, Command, Integer, Unicode | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class StreamOpen(Command): | ||
""" | ||
Open a new stream. | ||
""" | ||
|
||
response = [(b"streamId", Integer())] | ||
|
||
|
||
class StreamWrite(Command): | ||
""" | ||
Write a chunk of data to a stream. | ||
""" | ||
|
||
arguments = [ | ||
(b"streamId", Integer()), | ||
(b"data", Unicode()), | ||
] | ||
|
||
|
||
@define | ||
class StreamReceiver: | ||
""" | ||
Buffering de-multiplexing string stream receiver. | ||
""" | ||
|
||
_counter: Iterator[int] = count() | ||
_streams: Dict[int, List[str]] = Factory(dict) | ||
|
||
def open(self) -> int: | ||
""" | ||
Open a new stream and return its unique identifier. | ||
""" | ||
newId = next(self._counter) | ||
self._streams[newId] = [] | ||
return newId | ||
|
||
def write(self, streamId: int, chunk: str) -> None: | ||
""" | ||
Write to an open stream using its unique identifier. | ||
:raise KeyError: If there is no such open stream. | ||
""" | ||
self._streams[streamId].append(chunk) | ||
|
||
def finish(self, streamId: int) -> List[str]: | ||
""" | ||
Indicate an open stream may receive no further data and return all of | ||
its current contents. | ||
:raise KeyError: If there is no such open stream. | ||
""" | ||
return self._streams.pop(streamId) | ||
|
||
|
||
def chunk(data: str, chunkSize: int) -> Iterator[str]: | ||
""" | ||
Break a string into pieces of no more than ``chunkSize`` length. | ||
:param data: The string. | ||
:param chunkSize: The maximum length of the resulting pieces. All pieces | ||
except possibly the last will be this length. | ||
:return: The pieces. | ||
""" | ||
pos = 0 | ||
while pos < len(data): | ||
yield data[pos : pos + chunkSize] | ||
pos += chunkSize | ||
|
||
|
||
async def stream(amp: AMP, chunks: Iterator[str]) -> int: | ||
""" | ||
Send the given stream chunks, one by one, over the given connection. | ||
The chunks are sent using L{StreamWrite} over a stream opened using | ||
L{StreamOpen}. | ||
:return: The identifier of the stream over which the chunks were sent. | ||
""" | ||
streamId = (await amp.callRemote(StreamOpen))["streamId"] | ||
|
||
for oneChunk in chunks: | ||
await amp.callRemote(StreamWrite, streamId=streamId, data=oneChunk) | ||
return streamId # type: ignore[no-any-return] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
""" | ||
Tests for L{twisted.trial._dist.stream}. | ||
""" | ||
|
||
from random import Random | ||
from typing import Awaitable, Dict, List, TypeVar, Union | ||
|
||
from hamcrest import ( | ||
all_of, | ||
assert_that, | ||
calling, | ||
equal_to, | ||
has_length, | ||
is_, | ||
less_than_or_equal_to, | ||
raises, | ||
) | ||
from hypothesis import given | ||
from hypothesis.strategies import integers, just, lists, randoms, text | ||
|
||
from twisted.internet.defer import Deferred, fail | ||
from twisted.internet.interfaces import IProtocol | ||
from twisted.internet.protocol import Protocol | ||
from twisted.protocols.amp import AMP | ||
from twisted.python.failure import Failure | ||
from twisted.test.iosim import FakeTransport, connect | ||
from twisted.trial.unittest import SynchronousTestCase | ||
from ..stream import StreamOpen, StreamReceiver, StreamWrite, chunk, stream | ||
from .matchers import HasSum, IsSequenceOf | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class StreamReceiverTests(SynchronousTestCase): | ||
""" | ||
Tests for L{StreamReceiver} | ||
""" | ||
|
||
@given(lists(lists(text())), randoms()) | ||
def test_streamReceived(self, streams: List[str], random: Random) -> None: | ||
""" | ||
All data passed to L{StreamReceiver.write} is returned by a call to | ||
L{StreamReceiver.finish} with a matching C{streamId} . | ||
""" | ||
receiver = StreamReceiver() | ||
streamIds = [receiver.open() for _ in streams] | ||
|
||
# uncorrelate the results with open() order | ||
random.shuffle(streamIds) | ||
|
||
expectedData = dict(zip(streamIds, streams)) | ||
for streamId, strings in expectedData.items(): | ||
for s in strings: | ||
receiver.write(streamId, s) | ||
|
||
# uncorrelate the results with write() order | ||
random.shuffle(streamIds) | ||
|
||
actualData = {streamId: receiver.finish(streamId) for streamId in streamIds} | ||
|
||
assert_that(actualData, is_(equal_to(expectedData))) | ||
|
||
@given(integers(), just("data")) | ||
def test_writeBadStreamId(self, streamId: int, data: str) -> None: | ||
""" | ||
L{StreamReceiver.write} raises L{KeyError} if called with a | ||
streamId not associated with an open stream. | ||
""" | ||
receiver = StreamReceiver() | ||
assert_that(calling(receiver.write).with_args(streamId, data), raises(KeyError)) | ||
|
||
@given(integers()) | ||
def test_badFinishStreamId(self, streamId: int) -> None: | ||
""" | ||
L{StreamReceiver.finish} raises L{KeyError} if called with a | ||
streamId not associated with an open stream. | ||
""" | ||
receiver = StreamReceiver() | ||
assert_that(calling(receiver.finish).with_args(streamId), raises(KeyError)) | ||
|
||
def test_finishRemovesStream(self) -> None: | ||
""" | ||
L{StreamReceiver.finish} removes the identified stream. | ||
""" | ||
receiver = StreamReceiver() | ||
streamId = receiver.open() | ||
receiver.finish(streamId) | ||
assert_that(calling(receiver.finish).with_args(streamId), raises(KeyError)) | ||
|
||
|
||
class ChunkTests(SynchronousTestCase): | ||
""" | ||
Tests for ``chunk``. | ||
""" | ||
|
||
@given(data=text(), chunkSize=integers(min_value=1)) | ||
def test_chunk(self, data, chunkSize): | ||
""" | ||
L{chunk} returns an iterable of L{str} where each element is no | ||
longer than the given limit. The concatenation of the strings is also | ||
equal to the original input string. | ||
""" | ||
chunks = list(chunk(data, chunkSize)) | ||
assert_that( | ||
chunks, | ||
all_of( | ||
IsSequenceOf( | ||
has_length(less_than_or_equal_to(chunkSize)), | ||
), | ||
HasSum(equal_to(data), data[:0]), | ||
), | ||
) | ||
|
||
|
||
class AMPStreamReceiver(AMP): | ||
""" | ||
A simple AMP interface to L{StreamReceiver}. | ||
""" | ||
|
||
def __init__(self, streams: StreamReceiver) -> None: | ||
self.streams = streams | ||
|
||
@StreamOpen.responder | ||
def streamOpen(self) -> Dict[str, object]: | ||
return {"streamId": self.streams.open()} | ||
|
||
@StreamWrite.responder | ||
def streamWrite(self, streamId: int, data: str) -> Dict[str, object]: | ||
self.streams.write(streamId, data) | ||
return {} | ||
|
||
|
||
def interact(server: IProtocol, client: IProtocol, interaction: Awaitable[T]) -> T: | ||
""" | ||
Let C{server} and C{client} exchange bytes while C{interaction} runs. | ||
""" | ||
finished = False | ||
result: Union[Failure, T] | ||
|
||
async def to_coroutine() -> T: | ||
return await interaction | ||
|
||
def collect_result(r: Union[Failure, T]) -> None: | ||
nonlocal result, finished | ||
finished = True | ||
result = r | ||
|
||
pump = connect( | ||
server, | ||
FakeTransport(server, isServer=True), | ||
client, | ||
FakeTransport(client, isServer=False), | ||
) | ||
interacting = Deferred.fromCoroutine(to_coroutine()) | ||
interacting.addBoth(collect_result) | ||
|
||
pump.flush() | ||
|
||
if finished: | ||
if isinstance(result, Failure): | ||
result.raiseException() | ||
return result | ||
raise Exception("Interaction failed to produce a result.") | ||
|
||
|
||
class InteractTests(SynchronousTestCase): | ||
""" | ||
Tests for the test helper L{interact}. | ||
""" | ||
|
||
def test_failure(self): | ||
""" | ||
If the interaction results in a failure then L{interact} raises an | ||
exception. | ||
""" | ||
|
||
class ArbitraryException(Exception): | ||
pass | ||
|
||
with self.assertRaises(ArbitraryException): | ||
interact(Protocol(), Protocol(), fail(ArbitraryException())) | ||
|
||
def test_incomplete(self): | ||
""" | ||
If the interaction fails to produce a result then L{interact} raises | ||
an exception. | ||
""" | ||
with self.assertRaises(Exception): | ||
interact(Protocol(), Protocol(), Deferred()) | ||
|
||
|
||
class StreamTests(SynchronousTestCase): | ||
""" | ||
Tests for L{stream}. | ||
""" | ||
|
||
@given(lists(text())) | ||
def test_stream(self, chunks): | ||
""" | ||
All of the chunks passed to L{stream} are sent in order over a | ||
stream using the given AMP connection. | ||
""" | ||
sender = AMP() | ||
streams = StreamReceiver() | ||
streamId = interact(AMPStreamReceiver(streams), sender, stream(sender, chunks)) | ||
assert_that(streams.finish(streamId), is_(equal_to(chunks))) |
Empty file.