Skip to content

Commit

Permalink
Merge pull request #752 from nsklikas/oidc-refresh
Browse files Browse the repository at this point in the history
Oidc refresh
  • Loading branch information
JonathanHuot committed Aug 12, 2021
2 parents 4fddf07 + f6b6258 commit 555e3b0
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -7,6 +7,9 @@ OAuth2.0 Provider - Bugfixes

* #753: Fix acceptance of valid IPv6 addresses in URI validation

OAuth2.0 Provider - Features
* #751: OIDC add support of refreshing ID Tokens

OAuth2.0 Client - Bugfixes

* #730: Base OAuth2 Client now has a consistent way of managing the `scope`: it consistently
Expand All @@ -25,6 +28,8 @@ OAuth2.0 Provider - Bugfixes
* #746: OpenID Connect Hybrid - fix nonce not passed to add_id_token
* #756: Different prompt values are now handled according to spec (e.g. prompt=none)
* #759: OpenID Connect - fix Authorization: Basic parsing
* #751: The RefreshTokenGrant modifiers now take the same arguments as the
AuthorizationCodeGrant modifiers (`token`, `token_handler`, `request`).

General
* #716: improved skeleton validator for public vs private client
Expand Down
6 changes: 6 additions & 0 deletions docs/oauth2/oidc/refresh_token.rst
@@ -0,0 +1,6 @@
OpenID Authorization Code
-------------------------

.. autoclass:: oauthlib.openid.connect.core.grant_types.RefreshTokenGrant
:members:
:inherited-members:
2 changes: 1 addition & 1 deletion oauthlib/oauth2/rfc6749/grant_types/refresh_token.py
Expand Up @@ -63,7 +63,7 @@ def create_token_response(self, request, token_handler):
refresh_token=self.issue_new_refresh_tokens)

for modifier in self._token_modifiers:
token = modifier(token)
token = modifier(token, token_handler, request)

self.request_validator.save_token(token, request)

Expand Down
1 change: 1 addition & 0 deletions oauthlib/openid/connect/core/grant_types/__init__.py
Expand Up @@ -10,3 +10,4 @@
)
from .hybrid import HybridGrant
from .implicit import ImplicitGrant
from .refresh_token import RefreshTokenGrant
34 changes: 34 additions & 0 deletions oauthlib/openid/connect/core/grant_types/refresh_token.py
@@ -0,0 +1,34 @@
"""
oauthlib.openid.connect.core.grant_types
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
import logging

from oauthlib.oauth2.rfc6749.grant_types.refresh_token import (
RefreshTokenGrant as OAuth2RefreshTokenGrant,
)

from .base import GrantTypeBase

log = logging.getLogger(__name__)


class RefreshTokenGrant(GrantTypeBase):

def __init__(self, request_validator=None, **kwargs):
self.proxy_target = OAuth2RefreshTokenGrant(
request_validator=request_validator, **kwargs)
self.register_token_modifier(self.add_id_token)

def add_id_token(self, token, token_handler, request):
"""
Construct an initial version of id_token, and let the
request_validator sign or encrypt it.
The authorization_code version of this method is used to
retrieve the nonce accordingly to the code storage.
"""
if not self.request_validator.refresh_id_token(request):
return token

return super().add_id_token(token, token_handler, request)
12 changes: 12 additions & 0 deletions oauthlib/openid/connect/core/request_validator.py
Expand Up @@ -306,3 +306,15 @@ def get_userinfo_claims(self, request):
Method is used by:
UserInfoEndpoint
"""

def refresh_id_token(self, request):
"""Whether the id token should be refreshed. Default, True
:param request: OAuthlib request.
:type request: oauthlib.common.Request
:rtype: True or False
Method is used by:
RefreshTokenGrant
"""
return True
105 changes: 105 additions & 0 deletions tests/openid/connect/core/grant_types/test_refresh_token.py
@@ -0,0 +1,105 @@
import json
from unittest import mock

from oauthlib.common import Request
from oauthlib.oauth2.rfc6749.tokens import BearerToken
from oauthlib.openid.connect.core.grant_types import RefreshTokenGrant

from tests.oauth2.rfc6749.grant_types.test_refresh_token import (
RefreshTokenGrantTest,
)
from tests.unittest import TestCase


def get_id_token_mock(token, token_handler, request):
return "MOCKED_TOKEN"


class OpenIDRefreshTokenInterferenceTest(RefreshTokenGrantTest):
"""Test that OpenID don't interfere with normal OAuth 2 flows."""

def setUp(self):
super().setUp()
self.auth = RefreshTokenGrant(request_validator=self.mock_validator)


class OpenIDRefreshTokenTest(TestCase):

def setUp(self):
self.request = Request('http://a.b/path')
self.request.grant_type = 'refresh_token'
self.request.refresh_token = 'lsdkfhj230'
self.request.scope = ('hello', 'openid')
self.mock_validator = mock.MagicMock()

self.mock_validator = mock.MagicMock()
self.mock_validator.authenticate_client.side_effect = self.set_client
self.mock_validator.get_id_token.side_effect = get_id_token_mock
self.auth = RefreshTokenGrant(request_validator=self.mock_validator)

def set_client(self, request):
request.client = mock.MagicMock()
request.client.client_id = 'mocked'
return True

def test_refresh_id_token(self):
self.mock_validator.get_original_scopes.return_value = [
'hello', 'openid'
]
bearer = BearerToken(self.mock_validator)

headers, body, status_code = self.auth.create_token_response(
self.request, bearer
)

token = json.loads(body)
self.assertEqual(self.mock_validator.save_token.call_count, 1)
self.assertIn('access_token', token)
self.assertIn('refresh_token', token)
self.assertIn('id_token', token)
self.assertIn('token_type', token)
self.assertIn('expires_in', token)
self.assertEqual(token['scope'], 'hello openid')
self.mock_validator.refresh_id_token.assert_called_once_with(
self.request
)

def test_refresh_id_token_false(self):
self.mock_validator.refresh_id_token.return_value = False
self.mock_validator.get_original_scopes.return_value = [
'hello', 'openid'
]
bearer = BearerToken(self.mock_validator)

headers, body, status_code = self.auth.create_token_response(
self.request, bearer
)

token = json.loads(body)
self.assertEqual(self.mock_validator.save_token.call_count, 1)
self.assertIn('access_token', token)
self.assertIn('refresh_token', token)
self.assertIn('token_type', token)
self.assertIn('expires_in', token)
self.assertEqual(token['scope'], 'hello openid')
self.assertNotIn('id_token', token)
self.mock_validator.refresh_id_token.assert_called_once_with(
self.request
)

def test_refresh_token_without_openid_scope(self):
self.request.scope = "hello"
bearer = BearerToken(self.mock_validator)

headers, body, status_code = self.auth.create_token_response(
self.request, bearer
)

token = json.loads(body)
self.assertEqual(self.mock_validator.save_token.call_count, 1)
self.assertIn('access_token', token)
self.assertIn('refresh_token', token)
self.assertIn('token_type', token)
self.assertIn('expires_in', token)
self.assertNotIn('id_token', token)
self.assertEqual(token['scope'], 'hello')

0 comments on commit 555e3b0

Please sign in to comment.