diff --git a/changelog.d/327.bugfix.rst b/changelog.d/327.bugfix.rst new file mode 100644 index 00000000..e6391df7 --- /dev/null +++ b/changelog.d/327.bugfix.rst @@ -0,0 +1 @@ +``treq.testing.StubTreq`` now persists ``twisted.web.server.Session`` instances between requests. diff --git a/docs/testing.rst b/docs/testing.rst index 2f292b51..855bc69c 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -57,3 +57,4 @@ This is superior to calling your resource's methods directly or passing mock obj Thus, the ``request`` object your code interacts with is a *real* :class:`twisted.web.server.Request` and behaves the same as it would in production. Note that if your resource returns :data:`~twisted.web.server.NOT_DONE_YET` you must keep a reference to the :class:`~treq.testing.RequestTraversalAgent` and call its :meth:`~treq.testing.RequestTraversalAgent.flush()` method to spin the memory reactor once the server writes additional data before the client will receive it. + diff --git a/src/treq/test/test_testing.py b/src/treq/test/test_testing.py index eb804158..1615472a 100644 --- a/src/treq/test/test_testing.py +++ b/src/treq/test/test_testing.py @@ -52,6 +52,34 @@ def render(self, request): return NOT_DONE_YET +class _SessionIdTestResource(Resource): + """ + Resource that returns the current session ID. + """ + isLeaf = True + + def __init__(self): + super().__init__() + # keep track of all sessions created, so we can manually expire them later + self.sessions = [] + + def render(self, request): + session = request.getSession() + if session not in self.sessions: + # new session, add to internal list + self.sessions.append(session) + uid = session.uid + return uid + + def expire_sessions(self): + """ + Manually expire all sessions created by this resource. + """ + for session in self.sessions: + session.expire() + self.sessions = [] + + class StubbingTests(TestCase): """ Tests for :class:`StubTreq`. @@ -242,6 +270,40 @@ def test_handles_successful_asynchronous_requests_with_streaming(self): stub.flush() self.successResultOf(d) + def test_session_persistence_between_requests(self): + """ + Calling request.getSession() in the wrapped resource will return + a session with the same ID, until the sessions are cleaned. + """ + rsrc = _SessionIdTestResource() + stub = StubTreq(rsrc) + # request 1, getting original session ID + d = stub.request("method", "http://example.com/") + resp = self.successResultOf(d) + cookies = resp.cookies() + sid_1 = self.successResultOf(resp.content()) + # request 2, ensuring session ID stays the same + d = stub.request("method", "http://example.com/", cookies=cookies) + resp = self.successResultOf(d) + sid_2 = self.successResultOf(resp.content()) + self.assertEqual(sid_1, sid_2) + # request 3, ensuring the session IDs are different after cleaning + # or expiring the sessions + + # manually expire the sessions. + rsrc.expire_sessions() + + d = stub.request("method", "http://example.com/") + resp = self.successResultOf(d) + cookies = resp.cookies() + sid_3 = self.successResultOf(resp.content()) + self.assertNotEqual(sid_1, sid_3) + # request 4, ensuring that once again the session IDs are the same + d = stub.request("method", "http://example.com/", cookies=cookies) + resp = self.successResultOf(d) + sid_4 = self.successResultOf(resp.content()) + self.assertEqual(sid_3, sid_4) + class HasHeadersTests(TestCase): """ diff --git a/src/treq/testing.py b/src/treq/testing.py index df633beb..33d8c94d 100644 --- a/src/treq/testing.py +++ b/src/treq/testing.py @@ -26,7 +26,7 @@ from twisted.web.error import SchemeNotSupported from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer from twisted.web.resource import Resource -from twisted.web.server import Site +from twisted.web.server import Session, Site from zope.interface import directlyProvides, implementer @@ -88,6 +88,12 @@ def __init__(self, rootResource): reactor=self._memoryReactor, endpointFactory=_EndpointFactory(self._memoryReactor)) self._rootResource = rootResource + self._serverFactory = Site(self._rootResource, reactor=self._memoryReactor) + self._serverFactory.sessionFactory = lambda site, uid: Session( + site, + uid, + reactor=self._memoryReactor, + ) self._pumps = set() def request(self, method, uri, headers=None, bodyProducer=None): @@ -126,8 +132,7 @@ def check_already_called(r): # Create the protocol and fake transport for the client and server, # using the factory that was passed to the MemoryReactor for the # client, and a Site around our rootResource for the server. - serverFactory = Site(self._rootResource, reactor=self._memoryReactor) - serverProtocol = serverFactory.buildProtocol(clientAddress) + serverProtocol = self._serverFactory.buildProtocol(clientAddress) serverTransport = iosim.FakeTransport( serverProtocol, isServer=True, hostAddress=serverAddress, peerAddress=clientAddress) @@ -228,8 +233,8 @@ def __init__(self, resource): :param resource: A :obj:`Resource` object that provides the fake responses """ - _agent = RequestTraversalAgent(resource) - _client = HTTPClient(agent=_agent, + self._agent = RequestTraversalAgent(resource) + _client = HTTPClient(agent=self._agent, data_to_body_producer=_SynchronousProducer) for function_name in treq.__all__: function = getattr(_client, function_name, None) @@ -239,7 +244,7 @@ def __init__(self, resource): function = _reject_files(function) setattr(self, function_name, function) - self.flush = _agent.flush + self.flush = self._agent.flush class StringStubbingResource(Resource):