Skip to content

Commit

Permalink
Merge pull request #806 from luhn/refresh-grant-cors
Browse files Browse the repository at this point in the history
Add CORS support for Refresh Token Grant.
  • Loading branch information
JonathanHuot committed Feb 18, 2022
2 parents 6b1f5db + 47c229c commit f175204
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 18 deletions.
18 changes: 0 additions & 18 deletions oauthlib/oauth2/rfc6749/grant_types/authorization_code.py
Expand Up @@ -10,7 +10,6 @@
from oauthlib import common

from .. import errors
from ..utils import is_secure_transport
from .base import GrantTypeBase

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -547,20 +546,3 @@ def validate_code_challenge(self, challenge, challenge_method, verifier):
if challenge_method in self._code_challenge_methods:
return self._code_challenge_methods[challenge_method](verifier, challenge)
raise NotImplementedError('Unknown challenge_method %s' % challenge_method)

def _create_cors_headers(self, request):
"""If CORS is allowed, create the appropriate headers."""
if 'origin' not in request.headers:
return {}

origin = request.headers['origin']
if not is_secure_transport(origin):
log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
return {}
elif not self.request_validator.is_origin_allowed(
request.client_id, origin, request):
log.debug('Invalid origin "%s", CORS not allowed.', origin)
return {}
else:
log.debug('Valid origin "%s", injecting CORS headers.', origin)
return {'Access-Control-Allow-Origin': origin}
18 changes: 18 additions & 0 deletions oauthlib/oauth2/rfc6749/grant_types/base.py
Expand Up @@ -10,6 +10,7 @@
from oauthlib.uri_validate import is_absolute_uri

from ..request_validator import RequestValidator
from ..utils import is_secure_transport

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -248,3 +249,20 @@ def _handle_redirects(self, request):
raise errors.MissingRedirectURIError(request=request)
if not is_absolute_uri(request.redirect_uri):
raise errors.InvalidRedirectURIError(request=request)

def _create_cors_headers(self, request):
"""If CORS is allowed, create the appropriate headers."""
if 'origin' not in request.headers:
return {}

origin = request.headers['origin']
if not is_secure_transport(origin):
log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin)
return {}
elif not self.request_validator.is_origin_allowed(
request.client_id, origin, request):
log.debug('Invalid origin "%s", CORS not allowed.', origin)
return {}
else:
log.debug('Valid origin "%s", injecting CORS headers.', origin)
return {'Access-Control-Allow-Origin': origin}
1 change: 1 addition & 0 deletions oauthlib/oauth2/rfc6749/grant_types/refresh_token.py
Expand Up @@ -69,6 +69,7 @@ def create_token_response(self, request, token_handler):

log.debug('Issuing new token to client id %r (%r), %r.',
request.client_id, request.client, token)
headers.update(self._create_cors_headers(request))
return headers, json.dumps(token), 200

def validate_token_request(self, request):
Expand Down
1 change: 1 addition & 0 deletions oauthlib/oauth2/rfc6749/request_validator.py
Expand Up @@ -671,6 +671,7 @@ def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
Method is used by:
- Authorization Code Grant
- Refresh Token Grant
"""
return False
41 changes: 41 additions & 0 deletions tests/oauth2/rfc6749/grant_types/test_refresh_token.py
Expand Up @@ -18,6 +18,7 @@ def setUp(self):
self.request = Request('http://a.b/path')
self.request.grant_type = 'refresh_token'
self.request.refresh_token = 'lsdkfhj230'
self.request.client_id = 'abcdef'
self.request.client = mock_client
self.request.scope = 'foo'
self.mock_validator = mock.MagicMock()
Expand Down Expand Up @@ -168,3 +169,43 @@ def test_valid_token_request(self):
del self.request.scope
self.auth.validate_token_request(self.request)
self.assertEqual(self.request.scopes, 'foo bar baz'.split())

# CORS

def test_create_cors_headers(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = True

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertEqual(
headers['Access-Control-Allow-Origin'], 'https://foo.bar'
)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)

def test_create_cors_headers_no_origin(self):
bearer = BearerToken(self.mock_validator)
headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_insecure_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'http://foo.bar'

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_not_called()

def test_create_cors_headers_invalid_origin(self):
bearer = BearerToken(self.mock_validator)
self.request.headers['origin'] = 'https://foo.bar'
self.mock_validator.is_origin_allowed.return_value = False

headers = self.auth.create_token_response(self.request, bearer)[0]
self.assertNotIn('Access-Control-Allow-Origin', headers)
self.mock_validator.is_origin_allowed.assert_called_once_with(
'abcdef', 'https://foo.bar', self.request
)

0 comments on commit f175204

Please sign in to comment.