Skip to content

Commit

Permalink
Address potential websocket cross-origin attacks (Fixes #128)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jul 28, 2019
1 parent f23a405 commit b316510
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 24 deletions.
22 changes: 19 additions & 3 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class AsyncServer(server.Server):
is greater than this value.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
:param cors_allowed_origins: List of origins that are allowed to connect
to this server. All origins are allowed by
default.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same server
is allowed by default. Set this argument to
``'*'`` to allow all origins.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param logger: To enable logging set to ``True`` or pass a logger object to
Expand Down Expand Up @@ -181,6 +182,21 @@ async def handle_request(self, *args, **kwargs):
environ = await translate_request(*args, **kwargs)
else:
environ = translate_request(*args, **kwargs)

# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since browsers
# only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in allowed_origins:
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
make_response = self._async['make_response']
response = make_response(r['status'], r['headers'],
r['response'], environ)
return response

method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))

Expand Down
53 changes: 36 additions & 17 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ class Server(object):
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. All origins are
allowed by default, which is equivalent to
setting this argument to ``'*'``.
connect to this server. Only the same server
is allowed by default. Set this argument to
``'*'`` to allow all origins.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default
is ``True``.
Expand Down Expand Up @@ -309,6 +309,18 @@ def handle_request(self, environ, start_response):
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since browsers
# only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in allowed_origins:
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
start_response(r['status'], r['headers'])
return [r['response']]

method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))

Expand Down Expand Up @@ -572,27 +584,34 @@ def _unauthorized(self):
'headers': [('Content-Type', 'text/plain')],
'response': b'Unauthorized'}

def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if isinstance(self.cors_allowed_origins, six.string_types):
if self.cors_allowed_origins == '*':
allowed_origins = None
else:
allowed_origins = [self.cors_allowed_origins]
def _cors_allowed_origins(self, environ):
default_origin = None
if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ:
default_origin = '{scheme}://{host}'.format(
scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST'])
if self.cors_allowed_origins is None:
allowed_origins = [default_origin] \
if default_origin is not None else[]

This comment has been minimized.

Copy link
@Ambro17

Ambro17 Jul 29, 2019

Missing whitespace?

This comment has been minimized.

Copy link
@miguelgrinberg

miguelgrinberg Jul 29, 2019

Author Owner

Oh wow. This is actually covered in unit tests, and it does work fine without that space! It also was missed by flake8. Interesting, I'll fix it.

elif self.cors_allowed_origins == '*':
allowed_origins = None
elif isinstance(self.cors_allowed_origins, six.string_types):
allowed_origins = [self.cors_allowed_origins]
else:
allowed_origins = self.cors_allowed_origins
if allowed_origins is not None and \
environ.get('HTTP_ORIGIN', '') not in allowed_origins:
return []
if 'HTTP_ORIGIN' in environ:
return allowed_origins

def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
headers = []
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is None or ('HTTP_ORIGIN' in environ and \
environ['HTTP_ORIGIN'] in allowed_origins):
headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
else:
headers = [('Access-Control-Allow-Origin', '*')]
if environ['REQUEST_METHOD'] == 'OPTIONS':
headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')]
if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ:
headers += [('Access-Control-Allow-Headers',
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
if self.cors_credentials:
headers += [('Access-Control-Allow-Credentials', 'true')]
return headers
Expand Down
47 changes: 46 additions & 1 deletion tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ def test_connect_cors_headers(self, import_module):
s = asyncio_server.AsyncServer()
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

@mock.patch('importlib.import_module')
Expand All @@ -423,6 +422,52 @@ def test_connect_cors_not_allowed_origin(self, import_module):
self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_all_origins(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'foo'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins='*')
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'foo'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_one_origin(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'a'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins='a')
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'a'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_one_origin_not_allowed(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'b'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins='a')
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertNotIn(('Access-Control-Allow-Origin', 'b'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_headers_default_origin(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'wsgi.url_scheme': 'http',
'HTTP_HOST': 'foo',
'HTTP_ORIGIN': 'http://foo'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'http://foo'),
headers)

@mock.patch('importlib.import_module')
def test_connect_cors_no_credentials(self, import_module):
a = self.get_async_mock()
Expand Down
18 changes: 15 additions & 3 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def test_connect_cors_headers(self):
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

def test_connect_cors_allowed_origin(self):
Expand All @@ -577,11 +576,12 @@ def test_connect_cors_not_allowed_origin(self):

def test_connect_cors_headers_all_origins(self):
s = server.Server(cors_allowed_origins='*')
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
self.assertIn(('Access-Control-Allow-Origin', 'foo'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

def test_connect_cors_headers_one_origin(self):
Expand All @@ -604,6 +604,18 @@ def test_connect_cors_headers_one_origin_not_allowed(self):
self.assertNotIn(('Access-Control-Allow-Origin', 'b'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

def test_connect_cors_headers_default_origin(self):
s = server.Server()
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'wsgi.url_scheme': 'http', 'HTTP_HOST': 'foo',
'HTTP_ORIGIN': 'http://foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'http://foo'),
headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

def test_connect_cors_no_credentials(self):
s = server.Server(cors_credentials=False)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
Expand Down

0 comments on commit b316510

Please sign in to comment.