From 43dc6e57e9a243065a0d1f1d51fe8257ab51d7c2 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 23 May 2021 16:26:39 +0100 Subject: [PATCH] Pass auth data from client in connect event handler (Fixes #1555) --- docs/index.rst | 19 ++++++++-- flask_socketio/__init__.py | 10 ++++-- flask_socketio/test_client.py | 11 +++--- test_socketio.py | 65 +++++++++++++++++++---------------- 4 files changed, 66 insertions(+), 39 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 151148c6..27716b07 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -164,8 +164,8 @@ Custom named events can also support multiple arguments:: print('received args: ' + arg1 + arg2 + arg3) When the name of the event is a valid Python identifier that does not collide -with other defined symbols, the ``@socketio.event`` provides a more compact -syntax that takes the event name from the decorated function:: +with other defined symbols, the ``@socketio.event`` decorator provides a more +compact syntax that takes the event name from the decorated function:: @socketio.event def my_custom_event(arg1, arg2, arg3): @@ -345,13 +345,19 @@ Flask-SocketIO also dispatches connection and disconnection events. The following example shows how to register handlers for them:: @socketio.on('connect') - def test_connect(): + def test_connect(auth): emit('my response', {'data': 'Connected'}) @socketio.on('disconnect') def test_disconnect(): print('Client disconnected') +The ``auth`` argument in the connection handler is optional. The client can +use it to pass authentication data such as tokens in dictionary format. If the +client does not provide authentication details, then this argument is set to +``None``. If the server defines a connection event handler without this +argument, then any authentication data passed by the cient is discarded. + The connection event handler can return ``False`` to reject the connection, or it can also raise `ConectionRefusedError`. This is so that the client can be authenticated at this point. When using the exception, any arguments passed to @@ -517,6 +523,13 @@ user's identity can then be recorded in the user session or in a cookie, and later when the SocketIO connection is established that information will be accessible to SocketIO event handlers. +Recent revisions of the Socket.IO protocol include the ability to pass a +dictionary with authentication information during the connection. This is an +ideal place for the client to include a token or other authentication details. +If the client uses this capability, the server will provide this dictionary as +an argument to the ``connect`` event handler, as shown above. + + Using Flask-Login with Flask-SocketIO ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/flask_socketio/__init__.py b/flask_socketio/__init__.py index 565b627f..507329db 100644 --- a/flask_socketio/__init__.py +++ b/flask_socketio/__init__.py @@ -708,7 +708,7 @@ def sleep(self, seconds=0): return self.server.sleep(seconds) def test_client(self, app, namespace=None, query_string=None, - headers=None, flask_test_client=None): + headers=None, auth=None, flask_test_client=None): """The Socket.IO test client is useful for testing a Flask-SocketIO server. It works in a similar way to the Flask Test Client, but adapted to the Socket.IO server. @@ -719,6 +719,7 @@ def test_client(self, app, namespace=None, query_string=None, namespace. :param query_string: A string with custom query string arguments. :param headers: A dictionary with custom HTTP headers. + :param auth: Optional authentication data, given as a dictionary. :param flask_test_client: The instance of the Flask test client currently in use. Passing the Flask test client is optional, but is necessary if you @@ -728,6 +729,7 @@ def test_client(self, app, namespace=None, query_string=None, """ return SocketIOTestClient(app, self, namespace=namespace, query_string=query_string, headers=headers, + auth=auth, flask_test_client=flask_test_client) def _handle_event(self, handler, message, namespace, sid, *args): @@ -756,7 +758,11 @@ def _handle_event(self, handler, message, namespace, sid, *args): flask.request.event = {'message': message, 'args': args} try: if message == 'connect': - ret = handler() + auth = args[1] if len(args) > 1 else None + try: + ret = handler(auth) + except TypeError: + ret = handler() else: ret = handler(*args) except: diff --git a/flask_socketio/test_client.py b/flask_socketio/test_client.py index 84d3f564..f0f5b8f8 100644 --- a/flask_socketio/test_client.py +++ b/flask_socketio/test_client.py @@ -16,6 +16,7 @@ class SocketIOTestClient(object): connects to the server on the global namespace. :param query_string: A string with custom query string arguments. :param headers: A dictionary with custom HTTP headers. + :param auth: Optional authentication data, given as a dictionary. :param flask_test_client: The instance of the Flask test client currently in use. Passing the Flask test client is optional, but is necessary if you @@ -27,7 +28,7 @@ class SocketIOTestClient(object): acks = {} def __init__(self, app, socketio, namespace=None, query_string=None, - headers=None, flask_test_client=None): + headers=None, auth=None, flask_test_client=None): def _mock_send_packet(eio_sid, pkt): # make sure the packet can be encoded and decoded epkt = pkt.encode() @@ -76,7 +77,7 @@ def _mock_send_packet(eio_sid, pkt): 'configuration.') socketio.server.manager.initialize() self.connect(namespace=namespace, query_string=query_string, - headers=headers) + headers=headers, auth=auth) def is_connected(self, namespace=None): """Check if a namespace is connected. @@ -86,7 +87,8 @@ def is_connected(self, namespace=None): """ return self.connected.get(namespace or '/', False) - def connect(self, namespace=None, query_string=None, headers=None): + def connect(self, namespace=None, query_string=None, headers=None, + auth=None): """Connect the client. :param namespace: The namespace for the client. If not provided, the @@ -94,6 +96,7 @@ def connect(self, namespace=None, query_string=None, headers=None): namespace. :param query_string: A string with custom query string arguments. :param headers: A dictionary with custom HTTP headers. + :param auth: Optional authentication data, given as a dictionary. Note that it is usually not necessary to explicitly call this method, since a connection is automatically established when an instance of @@ -112,7 +115,7 @@ def connect(self, namespace=None, query_string=None, headers=None): # inject cookies from Flask self.flask_test_client.cookie_jar.inject_wsgi(environ) self.socketio.server._handle_eio_connect(self.eio_sid, environ) - pkt = packet.Packet(packet.CONNECT, namespace=namespace) + pkt = packet.Packet(packet.CONNECT, auth, namespace=namespace) with self.app.app_context(): self.socketio.server._handle_eio_message(self.eio_sid, pkt.encode()) diff --git a/test_socketio.py b/test_socketio.py index 8d982e23..d58c59ac 100755 --- a/test_socketio.py +++ b/test_socketio.py @@ -16,7 +16,9 @@ @socketio.on('connect') -def on_connect(): +def on_connect(auth): + if auth != {'foo': 'bar'}: # pragma: no cover + return False if request.args.get('fail'): return False send('connected') @@ -278,8 +280,8 @@ def tearDown(self): pass def test_connect(self): - client = socketio.test_client(app) - client2 = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) + client2 = socketio.test_client(app, auth={'foo': 'bar'}) self.assertTrue(client.is_connected()) self.assertTrue(client2.is_connected()) self.assertNotEqual(client.eio_sid, client2.eio_sid) @@ -297,7 +299,8 @@ def test_connect(self): def test_connect_query_string_and_headers(self): client = socketio.test_client( app, query_string='?foo=bar&foo=baz', - headers={'Authorization': 'Bearer foobar'}) + headers={'Authorization': 'Bearer foobar'}, + auth={'foo': 'bar'}) received = client.get_received() self.assertEqual(len(received), 3) self.assertEqual(received[0]['args'], 'connected') @@ -329,13 +332,14 @@ def test_connect_namespace_query_string_and_headers(self): client.disconnect(namespace='/test') def test_connect_rejected(self): - client = socketio.test_client(app, query_string='fail=1') + client = socketio.test_client(app, query_string='fail=1', + auth={'foo': 'bar'}) self.assertFalse(client.is_connected()) def test_disconnect(self): global disconnected disconnected = None - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.disconnect() self.assertEqual(disconnected, '/') @@ -347,7 +351,7 @@ def test_disconnect_namespace(self): self.assertEqual(disconnected, '/test') def test_send(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() client.send('echo this message back') received = client.get_received() @@ -355,8 +359,8 @@ def test_send(self): self.assertEqual(received[0]['args'], 'echo this message back') def test_send_json(self): - client1 = socketio.test_client(app) - client2 = socketio.test_client(app) + client1 = socketio.test_client(app, auth={'foo': 'bar'}) + client2 = socketio.test_client(app, auth={'foo': 'bar'}) client1.get_received() client2.get_received() client1.send({'a': 'b'}, json=True) @@ -384,7 +388,7 @@ def test_send_json_namespace(self): self.assertEqual(received[0]['args']['a'], 'b') def test_emit(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() client.emit('my custom event', {'a': 'b'}) received = client.get_received() @@ -394,7 +398,7 @@ def test_emit(self): self.assertEqual(received[0]['args'][0]['a'], 'b') def test_emit_binary(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() client.emit('my custom event', {u'a': b'\x01\x02\x03'}) received = client.get_received() @@ -404,7 +408,7 @@ def test_emit_binary(self): self.assertEqual(received[0]['args'][0]['a'], b'\x01\x02\x03') def test_request_event_data(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() global request_event_data request_event_data = None @@ -427,8 +431,8 @@ def test_emit_namespace(self): self.assertEqual(received[0]['args'][0]['a'], 'b') def test_broadcast(self): - client1 = socketio.test_client(app) - client2 = socketio.test_client(app) + client1 = socketio.test_client(app, auth={'foo': 'bar'}) + client2 = socketio.test_client(app, auth={'foo': 'bar'}) client3 = socketio.test_client(app, namespace='/test') client2.get_received() client3.get_received('/test') @@ -443,7 +447,7 @@ def test_broadcast(self): def test_broadcast_namespace(self): client1 = socketio.test_client(app, namespace='/test') client2 = socketio.test_client(app, namespace='/test') - client3 = socketio.test_client(app) + client3 = socketio.test_client(app, auth={'foo': 'bar'}) client2.get_received('/test') client3.get_received() client1.emit('my custom broadcast namespace event', {'a': 'b'}, @@ -458,7 +462,8 @@ def test_broadcast_namespace(self): def test_session(self): flask_client = app.test_client() flask_client.get('/session') - client = socketio.test_client(app, flask_test_client=flask_client) + client = socketio.test_client(app, flask_test_client=flask_client, + auth={'foo': 'bar'}) client.get_received() client.send('echo this message back') self.assertEqual( @@ -470,8 +475,8 @@ def test_session(self): {'a': 'b', 'foo': 'bar'}) def test_room(self): - client1 = socketio.test_client(app) - client2 = socketio.test_client(app) + client1 = socketio.test_client(app, auth={'foo': 'bar'}) + client2 = socketio.test_client(app, auth={'foo': 'bar'}) client3 = socketio.test_client(app, namespace='/test') client1.get_received() client2.get_received() @@ -516,7 +521,7 @@ def test_room(self): self.assertEqual(len(received), 0) def test_error_handling(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() global error_testing error_testing = False @@ -540,9 +545,9 @@ def test_error_handling_default(self): self.assertTrue(error_testing_default) def test_ack(self): - client1 = socketio.test_client(app) - client2 = socketio.test_client(app) - client3 = socketio.test_client(app) + client1 = socketio.test_client(app, auth={'foo': 'bar'}) + client2 = socketio.test_client(app, auth={'foo': 'bar'}) + client3 = socketio.test_client(app, auth={'foo': 'bar'}) ack = client1.send('echo this message back', callback=True) self.assertEqual(ack, 'echo this message back') ack = client1.send('test noackargs', callback=True) @@ -556,9 +561,9 @@ def test_ack(self): self.assertEqual(ack3, {'a': 'b'}) def test_noack(self): - client1 = socketio.test_client(app) - client2 = socketio.test_client(app) - client3 = socketio.test_client(app) + client1 = socketio.test_client(app, auth={'foo': 'bar'}) + client2 = socketio.test_client(app, auth={'foo': 'bar'}) + client3 = socketio.test_client(app, auth={'foo': 'bar'}) no_ack_dict = {'noackargs': True} noack = client1.send("test noackargs", callback=False) self.assertIsNone(noack) @@ -568,7 +573,7 @@ def test_noack(self): self.assertIsNone(noack3) def test_error_handling_ack(self): - client1 = socketio.test_client(app) + client1 = socketio.test_client(app, auth={'foo': 'bar'}) client2 = socketio.test_client(app, namespace='/test') client3 = socketio.test_client(app, namespace='/unused_namespace') errorack = client1.emit("error testing", "", callback=True) @@ -582,7 +587,7 @@ def test_error_handling_ack(self): self.assertEqual(errorack_default, 'error/default') def test_on_event(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() global request_event_data request_event_data = None @@ -684,13 +689,13 @@ def on_connect(): self.assertFalse(socketio.server.eio.allow_upgrades) self.assertEqual(socketio.server.eio.cookie, 'foo') - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) received = client.get_received() self.assertEqual(len(received), 1) self.assertEqual(received[0]['args'], {'connected': 'foo'}) def test_encode_decode(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) client.get_received() data = {'foo': 'bar', 'invalid': socketio} self.assertRaises(TypeError, client.emit, 'my custom event', data, @@ -704,7 +709,7 @@ def test_encode_decode(self): self.assertEqual(received[0]['args'][0], {'foo': 'bar'}) def test_encode_decode_2(self): - client = socketio.test_client(app) + client = socketio.test_client(app, auth={'foo': 'bar'}) self.assertRaises(TypeError, client.emit, 'bad response') self.assertRaises(TypeError, client.emit, 'bad callback', callback=True)