Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #752 from nsklikas/oidc-refresh
Oidc refresh
- Loading branch information
Showing
7 changed files
with
164 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
OpenID Authorization Code | ||
------------------------- | ||
|
||
.. autoclass:: oauthlib.openid.connect.core.grant_types.RefreshTokenGrant | ||
:members: | ||
:inherited-members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
oauthlib.openid.connect.core.grant_types | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
""" | ||
import logging | ||
|
||
from oauthlib.oauth2.rfc6749.grant_types.refresh_token import ( | ||
RefreshTokenGrant as OAuth2RefreshTokenGrant, | ||
) | ||
|
||
from .base import GrantTypeBase | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class RefreshTokenGrant(GrantTypeBase): | ||
|
||
def __init__(self, request_validator=None, **kwargs): | ||
self.proxy_target = OAuth2RefreshTokenGrant( | ||
request_validator=request_validator, **kwargs) | ||
self.register_token_modifier(self.add_id_token) | ||
|
||
def add_id_token(self, token, token_handler, request): | ||
""" | ||
Construct an initial version of id_token, and let the | ||
request_validator sign or encrypt it. | ||
The authorization_code version of this method is used to | ||
retrieve the nonce accordingly to the code storage. | ||
""" | ||
if not self.request_validator.refresh_id_token(request): | ||
return token | ||
|
||
return super().add_id_token(token, token_handler, request) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
tests/openid/connect/core/grant_types/test_refresh_token.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import json | ||
from unittest import mock | ||
|
||
from oauthlib.common import Request | ||
from oauthlib.oauth2.rfc6749.tokens import BearerToken | ||
from oauthlib.openid.connect.core.grant_types import RefreshTokenGrant | ||
|
||
from tests.oauth2.rfc6749.grant_types.test_refresh_token import ( | ||
RefreshTokenGrantTest, | ||
) | ||
from tests.unittest import TestCase | ||
|
||
|
||
def get_id_token_mock(token, token_handler, request): | ||
return "MOCKED_TOKEN" | ||
|
||
|
||
class OpenIDRefreshTokenInterferenceTest(RefreshTokenGrantTest): | ||
"""Test that OpenID don't interfere with normal OAuth 2 flows.""" | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.auth = RefreshTokenGrant(request_validator=self.mock_validator) | ||
|
||
|
||
class OpenIDRefreshTokenTest(TestCase): | ||
|
||
def setUp(self): | ||
self.request = Request('http://a.b/path') | ||
self.request.grant_type = 'refresh_token' | ||
self.request.refresh_token = 'lsdkfhj230' | ||
self.request.scope = ('hello', 'openid') | ||
self.mock_validator = mock.MagicMock() | ||
|
||
self.mock_validator = mock.MagicMock() | ||
self.mock_validator.authenticate_client.side_effect = self.set_client | ||
self.mock_validator.get_id_token.side_effect = get_id_token_mock | ||
self.auth = RefreshTokenGrant(request_validator=self.mock_validator) | ||
|
||
def set_client(self, request): | ||
request.client = mock.MagicMock() | ||
request.client.client_id = 'mocked' | ||
return True | ||
|
||
def test_refresh_id_token(self): | ||
self.mock_validator.get_original_scopes.return_value = [ | ||
'hello', 'openid' | ||
] | ||
bearer = BearerToken(self.mock_validator) | ||
|
||
headers, body, status_code = self.auth.create_token_response( | ||
self.request, bearer | ||
) | ||
|
||
token = json.loads(body) | ||
self.assertEqual(self.mock_validator.save_token.call_count, 1) | ||
self.assertIn('access_token', token) | ||
self.assertIn('refresh_token', token) | ||
self.assertIn('id_token', token) | ||
self.assertIn('token_type', token) | ||
self.assertIn('expires_in', token) | ||
self.assertEqual(token['scope'], 'hello openid') | ||
self.mock_validator.refresh_id_token.assert_called_once_with( | ||
self.request | ||
) | ||
|
||
def test_refresh_id_token_false(self): | ||
self.mock_validator.refresh_id_token.return_value = False | ||
self.mock_validator.get_original_scopes.return_value = [ | ||
'hello', 'openid' | ||
] | ||
bearer = BearerToken(self.mock_validator) | ||
|
||
headers, body, status_code = self.auth.create_token_response( | ||
self.request, bearer | ||
) | ||
|
||
token = json.loads(body) | ||
self.assertEqual(self.mock_validator.save_token.call_count, 1) | ||
self.assertIn('access_token', token) | ||
self.assertIn('refresh_token', token) | ||
self.assertIn('token_type', token) | ||
self.assertIn('expires_in', token) | ||
self.assertEqual(token['scope'], 'hello openid') | ||
self.assertNotIn('id_token', token) | ||
self.mock_validator.refresh_id_token.assert_called_once_with( | ||
self.request | ||
) | ||
|
||
def test_refresh_token_without_openid_scope(self): | ||
self.request.scope = "hello" | ||
bearer = BearerToken(self.mock_validator) | ||
|
||
headers, body, status_code = self.auth.create_token_response( | ||
self.request, bearer | ||
) | ||
|
||
token = json.loads(body) | ||
self.assertEqual(self.mock_validator.save_token.call_count, 1) | ||
self.assertIn('access_token', token) | ||
self.assertIn('refresh_token', token) | ||
self.assertIn('token_type', token) | ||
self.assertIn('expires_in', token) | ||
self.assertNotIn('id_token', token) | ||
self.assertEqual(token['scope'], 'hello') |