Skip to content

Commit

Permalink
Add support for CORS in the token endpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
luhn authored and auvipy committed Dec 13, 2021
1 parent ea5ef62 commit 55ce48b
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 0 deletions.
19 changes: 19 additions & 0 deletions oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
Expand Up @@ -10,6 +10,7 @@
from oauthlib import common

from .. import errors
from ..utils import is_secure_transport
from .base import GrantTypeBase

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -312,6 +313,7 @@ def create_token_response(self, request, token_handler):
self.request_validator.save_token(token, request)
self.request_validator.invalidate_authorization_code(
request.client_id, request.code, request)
headers.update(self._create_cors_headers(request))
return headers, json.dumps(token), 200

def validate_authorization_request(self, request):
Expand Down Expand Up @@ -545,3 +547,20 @@ def validate_code_challenge(self, challenge, challenge_method, verifier):
if challenge_method in self._code_challenge_methods:
return self._code_challenge_methods[challenge_method](verifier, challenge)
raise NotImplementedError('Unknown challenge_method %s' % challenge_method)

def _create_cors_headers(self, request):
"""If CORS is allowed, create the appropriate headers."""
if 'origin' not in request.headers:
return {}

origin = request.headers['origin']
if not is_secure_transport(origin):
log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
return {}
elif not self.request_validator.is_origin_allowed(
request.client_id, origin, request):
log.debug('Invalid origin "%s", CORS not allowed.', origin)
return {}
else:
log.debug('Valid origin "%s", injecting CORS headers.', origin)
return {'Access-Control-Allow-Origin': origin}
25 changes: 25 additions & 0 deletions oauthlib/oauth2/rfc6749/request_validator.py
Expand Up @@ -649,3 +649,28 @@ def get_code_challenge_method(self, code, request):
"""
raise NotImplementedError('Subclasses must implement this method.')

def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
"""Indicate if the given origin is allowed to access the token endpoint
via Cross-Origin Resource Sharing (CORS). CORS is used by browser-based
clients, such as Single-Page Applications, to perform the Authorization
Code Grant.
(Note: If performing Authorization Code Grant via a public client such
as a browser, you should use PKCE as well.)
If this method returns true, the appropriate CORS headers will be added
to the response. By default this method always returns False, meaning
CORS is disabled.
:param client_id: Unicode client identifier.
:param redirect_uri: Unicode origin.
:param request: OAuthlib request.
:type request: oauthlib.common.Request
:rtype: bool
Method is used by:
- Authorization Code Grant
"""
return False
41 changes: 41 additions & 0 deletions tests/oauth2/rfc6749/grant_types/test_authorization_code.py
Expand Up @@ -28,6 +28,7 @@ def setUp(self):
self.mock_validator = mock.MagicMock()
self.mock_validator.is_pkce_required.return_value = False
self.mock_validator.get_code_challenge.return_value = None
self.mock_validator.is_origin_allowed.return_value = False
self.mock_validator.authenticate_client.side_effect = self.set_client
self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator)

Expand Down Expand Up @@ -339,3 +340,43 @@ def test_hybrid_token_save(self):
)
self.auth.create_authorization_response(self.request, bearer)
self.mock_validator.save_token.assert_called_once()

# CORS

def test_create_cors_headers(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = True

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertEqual(
headers['Access-Control-Allow-Origin'], 'https://foo.bar'
)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)

def test_create_cors_headers_no_origin(self):
bearer = BearerToken(self.mock_validator)
headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_insecure_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'http://foo.bar'

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_invalid_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = False

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)
3 changes: 3 additions & 0 deletions tests/oauth2/rfc6749/test_request_validator.py
Expand Up @@ -46,3 +46,6 @@ def test_method_contracts(self):
self.assertRaises(NotImplementedError, v.validate_user,
'username', 'password', 'client', 'request')
self.assertTrue(v.client_authentication_required('r'))
self.assertFalse(
v.is_origin_allowed('client_id', 'https://foo.bar', 'r')
)

0 comments on commit 55ce48b

Please sign in to comment.