diff --git a/oauthlib/openid/connect/core/grant_types/refresh_token.py b/oauthlib/openid/connect/core/grant_types/refresh_token.py index 386b57cd..43e4499c 100644 --- a/oauthlib/openid/connect/core/grant_types/refresh_token.py +++ b/oauthlib/openid/connect/core/grant_types/refresh_token.py @@ -15,8 +15,7 @@ class RefreshTokenGrant(GrantTypeBase): - def __init__(self, refresh_id_token=True, request_validator=None, **kwargs): - self.refresh_id_token = refresh_id_token + def __init__(self, request_validator=None, **kwargs): self.proxy_target = OAuth2RefreshTokenGrant( request_validator=request_validator, **kwargs) self.register_token_modifier(self.add_id_token) @@ -29,8 +28,7 @@ def add_id_token(self, token, token_handler, request): The authorization_code version of this method is used to retrieve the nonce accordingly to the code storage. """ - # Treat it as normal OAuth 2 auth code request if openid is not present - if not self.refresh_id_token: + if not self.request_validator.refresh_id_token(request): return token return super().add_id_token(token, token_handler, request) diff --git a/oauthlib/openid/connect/core/request_validator.py b/oauthlib/openid/connect/core/request_validator.py index e8f334b0..47c4cd94 100644 --- a/oauthlib/openid/connect/core/request_validator.py +++ b/oauthlib/openid/connect/core/request_validator.py @@ -306,3 +306,15 @@ def get_userinfo_claims(self, request): Method is used by: UserInfoEndpoint """ + + def refresh_id_token(self, request): + """Whether the id token should be refreshed. Default, True + + :param request: OAuthlib request. + :type request: oauthlib.common.Request + :rtype: True or False + + Method is used by: + RefreshTokenGrant + """ + return True diff --git a/tests/openid/connect/core/grant_types/test_refresh_token.py b/tests/openid/connect/core/grant_types/test_refresh_token.py index c19de188..8126e1b8 100644 --- a/tests/openid/connect/core/grant_types/test_refresh_token.py +++ b/tests/openid/connect/core/grant_types/test_refresh_token.py @@ -60,9 +60,12 @@ def test_refresh_id_token(self): self.assertIn('token_type', token) self.assertIn('expires_in', token) self.assertEqual(token['scope'], 'hello openid') + self.mock_validator.refresh_id_token.assert_called_once_with( + self.request + ) def test_refresh_id_token_false(self): - self.auth.refresh_id_token = False + self.mock_validator.refresh_id_token.return_value = False self.mock_validator.get_original_scopes.return_value = [ 'hello', 'openid' ] @@ -80,6 +83,9 @@ def test_refresh_id_token_false(self): self.assertIn('expires_in', token) self.assertEqual(token['scope'], 'hello openid') self.assertNotIn('id_token', token) + self.mock_validator.refresh_id_token.assert_called_once_with( + self.request + ) def test_refresh_token_without_openid_scope(self): self.request.scope = "hello"