Skip to content

Commit

Permalink
Move refresh_id_token to validator function
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas authored and achraf-mer committed Oct 21, 2021
1 parent a82a907 commit 30df86d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
6 changes: 2 additions & 4 deletions oauthlib/openid/connect/core/grant_types/refresh_token.py
Expand Up @@ -15,8 +15,7 @@

class RefreshTokenGrant(GrantTypeBase):

def __init__(self, refresh_id_token=True, request_validator=None, **kwargs):
self.refresh_id_token = refresh_id_token
def __init__(self, request_validator=None, **kwargs):
self.proxy_target = OAuth2RefreshTokenGrant(
request_validator=request_validator, **kwargs)
self.register_token_modifier(self.add_id_token)
Expand All @@ -29,8 +28,7 @@ def add_id_token(self, token, token_handler, request):
The authorization_code version of this method is used to
retrieve the nonce accordingly to the code storage.
"""
# Treat it as normal OAuth 2 auth code request if openid is not present
if not self.refresh_id_token:
if not self.request_validator.refresh_id_token(request):
return token

return super().add_id_token(token, token_handler, request)
12 changes: 12 additions & 0 deletions oauthlib/openid/connect/core/request_validator.py
Expand Up @@ -306,3 +306,15 @@ def get_userinfo_claims(self, request):
Method is used by:
UserInfoEndpoint
"""

def refresh_id_token(self, request):
"""Whether the id token should be refreshed. Default, True
:param request: OAuthlib request.
:type request: oauthlib.common.Request
:rtype: True or False
Method is used by:
RefreshTokenGrant
"""
return True
8 changes: 7 additions & 1 deletion tests/openid/connect/core/grant_types/test_refresh_token.py
Expand Up @@ -60,9 +60,12 @@ def test_refresh_id_token(self):
self.assertIn('token_type', token)
self.assertIn('expires_in', token)
self.assertEqual(token['scope'], 'hello openid')
self.mock_validator.refresh_id_token.assert_called_once_with(
self.request
)

def test_refresh_id_token_false(self):
self.auth.refresh_id_token = False
self.mock_validator.refresh_id_token.return_value = False
self.mock_validator.get_original_scopes.return_value = [
'hello', 'openid'
]
Expand All @@ -80,6 +83,9 @@ def test_refresh_id_token_false(self):
self.assertIn('expires_in', token)
self.assertEqual(token['scope'], 'hello openid')
self.assertNotIn('id_token', token)
self.mock_validator.refresh_id_token.assert_called_once_with(
self.request
)

def test_refresh_token_without_openid_scope(self):
self.request.scope = "hello"
Expand Down

0 comments on commit 30df86d

Please sign in to comment.