Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CORS support to token endpoint. #791

Merged
merged 1 commit into from Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
Expand Up @@ -10,6 +10,7 @@
from oauthlib import common

from .. import errors
from ..utils import is_secure_transport
from .base import GrantTypeBase

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -312,6 +313,7 @@ def create_token_response(self, request, token_handler):
self.request_validator.save_token(token, request)
self.request_validator.invalidate_authorization_code(
request.client_id, request.code, request)
headers.update(self._create_cors_headers(request))
return headers, json.dumps(token), 200

def validate_authorization_request(self, request):
Expand Down Expand Up @@ -545,3 +547,20 @@ def validate_code_challenge(self, challenge, challenge_method, verifier):
if challenge_method in self._code_challenge_methods:
return self._code_challenge_methods[challenge_method](verifier, challenge)
raise NotImplementedError('Unknown challenge_method %s' % challenge_method)

def _create_cors_headers(self, request):
"""If CORS is allowed, create the appropriate headers."""
if 'origin' not in request.headers:
return {}

origin = request.headers['origin']
if not is_secure_transport(origin):
log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
return {}
elif not self.request_validator.is_origin_allowed(
request.client_id, origin, request):
log.debug('Invalid origin "%s", CORS not allowed.', origin)
return {}
else:
log.debug('Valid origin "%s", injecting CORS headers.', origin)
return {'Access-Control-Allow-Origin': origin}
25 changes: 25 additions & 0 deletions oauthlib/oauth2/rfc6749/request_validator.py
Expand Up @@ -649,3 +649,28 @@ def get_code_challenge_method(self, code, request):

"""
raise NotImplementedError('Subclasses must implement this method.')

def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
"""Indicate if the given origin is allowed to access the token endpoint
via Cross-Origin Resource Sharing (CORS). CORS is used by browser-based
clients, such as Single-Page Applications, to perform the Authorization
Code Grant.

(Note: If performing Authorization Code Grant via a public client such
as a browser, you should use PKCE as well.)

If this method returns true, the appropriate CORS headers will be added
to the response. By default this method always returns False, meaning
CORS is disabled.

:param client_id: Unicode client identifier.
:param redirect_uri: Unicode origin.
:param request: OAuthlib request.
:type request: oauthlib.common.Request
:rtype: bool

Method is used by:
- Authorization Code Grant

"""
return False
41 changes: 41 additions & 0 deletions tests/oauth2/rfc6749/grant_types/test_authorization_code.py
Expand Up @@ -28,6 +28,7 @@ def setUp(self):
self.mock_validator = mock.MagicMock()
self.mock_validator.is_pkce_required.return_value = False
self.mock_validator.get_code_challenge.return_value = None
self.mock_validator.is_origin_allowed.return_value = False
self.mock_validator.authenticate_client.side_effect = self.set_client
self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator)

Expand Down Expand Up @@ -339,3 +340,43 @@ def test_hybrid_token_save(self):
)
self.auth.create_authorization_response(self.request, bearer)
self.mock_validator.save_token.assert_called_once()

# CORS

def test_create_cors_headers(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = True

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertEqual(
headers['Access-Control-Allow-Origin'], 'https://foo.bar'
)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)

def test_create_cors_headers_no_origin(self):
bearer = BearerToken(self.mock_validator)
headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_insecure_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'http://foo.bar'

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_invalid_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = False

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)
3 changes: 3 additions & 0 deletions tests/oauth2/rfc6749/test_request_validator.py
Expand Up @@ -46,3 +46,6 @@ def test_method_contracts(self):
self.assertRaises(NotImplementedError, v.validate_user,
'username', 'password', 'client', 'request')
self.assertTrue(v.client_authentication_required('r'))
self.assertFalse(
v.is_origin_allowed('client_id', 'https://foo.bar', 'r')
)