Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

treq.testing.StubTreq: fix persisting twisted.web.server.Session objects between requests #328

Merged
merged 5 commits into from May 13, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/327.bugfix.rst
@@ -0,0 +1 @@
``treq.testing.StubTreq`` now persists ``twisted.web.server.Session`` instances between requests.
1 change: 1 addition & 0 deletions docs/testing.rst
Expand Up @@ -57,3 +57,4 @@ This is superior to calling your resource's methods directly or passing mock obj
Thus, the ``request`` object your code interacts with is a *real* :class:`twisted.web.server.Request` and behaves the same as it would in production.

Note that if your resource returns :data:`~twisted.web.server.NOT_DONE_YET` you must keep a reference to the :class:`~treq.testing.RequestTraversalAgent` and call its :meth:`~treq.testing.RequestTraversalAgent.flush()` method to spin the memory reactor once the server writes additional data before the client will receive it.

62 changes: 62 additions & 0 deletions src/treq/test/test_testing.py
Expand Up @@ -52,6 +52,34 @@ def render(self, request):
return NOT_DONE_YET


class _SessionIdTestResource(Resource):
"""
Resource that returns the current session ID.
"""
isLeaf = True

def __init__(self):
super().__init__()
# keep track of all sessions created, so we can manually expire them later
self.sessions = []

def render(self, request):
session = request.getSession()
if session not in self.sessions:
# new session, add to internal list
self.sessions.append(session)
uid = session.uid
return uid

def expire_sessions(self):
"""
Manually expire all sessions created by this resource.
"""
for session in self.sessions:
session.expire()
self.sessions = []


class StubbingTests(TestCase):
"""
Tests for :class:`StubTreq`.
Expand Down Expand Up @@ -242,6 +270,40 @@ def test_handles_successful_asynchronous_requests_with_streaming(self):
stub.flush()
self.successResultOf(d)

def test_session_persistence_between_requests(self):
"""
Calling request.getSession() in the wrapped resource will return
a session with the same ID, until the sessions are cleaned.
"""
rsrc = _SessionIdTestResource()
stub = StubTreq(rsrc)
# request 1, getting original session ID
d = stub.request("method", "http://example.com/")
resp = self.successResultOf(d)
cookies = resp.cookies()
sid_1 = self.successResultOf(resp.content())
# request 2, ensuring session ID stays the same
d = stub.request("method", "http://example.com/", cookies=cookies)
resp = self.successResultOf(d)
sid_2 = self.successResultOf(resp.content())
self.assertEqual(sid_1, sid_2)
# request 3, ensuring the session IDs are different after cleaning
# or expiring the sessions

# manually expire the sessions.
rsrc.expire_sessions()

d = stub.request("method", "http://example.com/")
resp = self.successResultOf(d)
cookies = resp.cookies()
sid_3 = self.successResultOf(resp.content())
self.assertNotEqual(sid_1, sid_3)
# request 4, ensuring that once again the session IDs are the same
d = stub.request("method", "http://example.com/", cookies=cookies)
resp = self.successResultOf(d)
sid_4 = self.successResultOf(resp.content())
self.assertEqual(sid_3, sid_4)


class HasHeadersTests(TestCase):
"""
Expand Down
13 changes: 7 additions & 6 deletions src/treq/testing.py
Expand Up @@ -26,7 +26,7 @@
from twisted.web.error import SchemeNotSupported
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from twisted.web.resource import Resource
from twisted.web.server import Site
from twisted.web.server import Session, Site

from zope.interface import directlyProvides, implementer

Expand Down Expand Up @@ -88,6 +88,8 @@ def __init__(self, rootResource):
reactor=self._memoryReactor,
endpointFactory=_EndpointFactory(self._memoryReactor))
self._rootResource = rootResource
self._serverFactory = Site(self._rootResource, reactor=self._memoryReactor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant that you can cause the session to use the correct reactor by overriding sessionFactory here like this:

Suggested change
self._serverFactory = Site(self._rootResource, reactor=self._memoryReactor)
self._serverFactory = Site(self._rootResource, reactor=self._memoryReactor)
self._serverFactory.sessionFactory = lambda site, uid: Session(site, uid, reactor=self._memoryReactor)

Then you don't need the cleanup code in the test you added, and more importantly you don't need cleanup code in any test that uses RequestTraversalAgent.

self._serverFactory.sessionFactory = lambda site, uid: Session(site, uid, reactor=self._memoryReactor)
self._pumps = set()

def request(self, method, uri, headers=None, bodyProducer=None):
Expand Down Expand Up @@ -126,8 +128,7 @@ def check_already_called(r):
# Create the protocol and fake transport for the client and server,
# using the factory that was passed to the MemoryReactor for the
# client, and a Site around our rootResource for the server.
serverFactory = Site(self._rootResource, reactor=self._memoryReactor)
serverProtocol = serverFactory.buildProtocol(clientAddress)
serverProtocol = self._serverFactory.buildProtocol(clientAddress)
serverTransport = iosim.FakeTransport(
serverProtocol, isServer=True,
hostAddress=serverAddress, peerAddress=clientAddress)
Expand Down Expand Up @@ -228,8 +229,8 @@ def __init__(self, resource):
:param resource: A :obj:`Resource` object that provides the fake
responses
"""
_agent = RequestTraversalAgent(resource)
_client = HTTPClient(agent=_agent,
self._agent = RequestTraversalAgent(resource)
_client = HTTPClient(agent=self._agent,
data_to_body_producer=_SynchronousProducer)
for function_name in treq.__all__:
function = getattr(_client, function_name, None)
Expand All @@ -239,7 +240,7 @@ def __init__(self, resource):
function = _reject_files(function)

setattr(self, function_name, function)
self.flush = _agent.flush
self.flush = self._agent.flush


class StringStubbingResource(Resource):
Expand Down