Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CORS support for Refresh Token Grant. #806

Merged
merged 1 commit into from Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 0 additions & 18 deletions oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
Expand Up @@ -10,7 +10,6 @@
from oauthlib import common

from .. import errors
from ..utils import is_secure_transport
from .base import GrantTypeBase

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -547,20 +546,3 @@ def validate_code_challenge(self, challenge, challenge_method, verifier):
if challenge_method in self._code_challenge_methods:
return self._code_challenge_methods[challenge_method](verifier, challenge)
raise NotImplementedError('Unknown challenge_method %s' % challenge_method)

def _create_cors_headers(self, request):
"""If CORS is allowed, create the appropriate headers."""
if 'origin' not in request.headers:
return {}

origin = request.headers['origin']
if not is_secure_transport(origin):
log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
return {}
elif not self.request_validator.is_origin_allowed(
request.client_id, origin, request):
log.debug('Invalid origin "%s", CORS not allowed.', origin)
return {}
else:
log.debug('Valid origin "%s", injecting CORS headers.', origin)
return {'Access-Control-Allow-Origin': origin}
18 changes: 18 additions & 0 deletions oauthlib/oauth2/rfc6749/grant_types/base.py
Expand Up @@ -10,6 +10,7 @@
from oauthlib.uri_validate import is_absolute_uri

from ..request_validator import RequestValidator
from ..utils import is_secure_transport

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -248,3 +249,20 @@ def _handle_redirects(self, request):
raise errors.MissingRedirectURIError(request=request)
if not is_absolute_uri(request.redirect_uri):
raise errors.InvalidRedirectURIError(request=request)

def _create_cors_headers(self, request):
"""If CORS is allowed, create the appropriate headers."""
if 'origin' not in request.headers:
return {}

origin = request.headers['origin']
if not is_secure_transport(origin):
log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
return {}
elif not self.request_validator.is_origin_allowed(
request.client_id, origin, request):
log.debug('Invalid origin "%s", CORS not allowed.', origin)
return {}
else:
log.debug('Valid origin "%s", injecting CORS headers.', origin)
return {'Access-Control-Allow-Origin': origin}
1 change: 1 addition & 0 deletions oauthlib/oauth2/rfc6749/grant_types/refresh_token.py
Expand Up @@ -69,6 +69,7 @@ def create_token_response(self, request, token_handler):

log.debug('Issuing new token to client id %r (%r), %r.',
request.client_id, request.client, token)
headers.update(self._create_cors_headers(request))
return headers, json.dumps(token), 200

def validate_token_request(self, request):
Expand Down
1 change: 1 addition & 0 deletions oauthlib/oauth2/rfc6749/request_validator.py
Expand Up @@ -671,6 +671,7 @@ def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):

Method is used by:
- Authorization Code Grant
- Refresh Token Grant

"""
return False
41 changes: 41 additions & 0 deletions tests/oauth2/rfc6749/grant_types/test_refresh_token.py
Expand Up @@ -18,6 +18,7 @@ def setUp(self):
self.request = Request('http://a.b/path')
self.request.grant_type = 'refresh_token'
self.request.refresh_token = 'lsdkfhj230'
self.request.client_id = 'abcdef'
self.request.client = mock_client
self.request.scope = 'foo'
self.mock_validator = mock.MagicMock()
Expand Down Expand Up @@ -168,3 +169,43 @@ def test_valid_token_request(self):
del self.request.scope
self.auth.validate_token_request(self.request)
self.assertEqual(self.request.scopes, 'foo bar baz'.split())

# CORS

def test_create_cors_headers(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = True

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertEqual(
headers['Access-Control-Allow-Origin'], 'https://foo.bar'
)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)

def test_create_cors_headers_no_origin(self):
bearer = BearerToken(self.mock_validator)
headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_insecure_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'http://foo.bar'

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_invalid_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = False

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)