diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index e137f23678e6..a32dbd4ea347 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -19,6 +19,7 @@ import ( "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -97,10 +98,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { code := ctx.Query("code") if code == "" { - // FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go - // ApprovalForce is required to get the refresh token every time the user logs in with Google OAuth (without this the - // refresh token is only provided when the user first gives consent) - opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce} + opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline} + + if hs.Features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { + // Change the defaults to AccessTypeOffline if accessTokenExpirationCheck is enabled + // and get the custom parameters from the specific OAuth connector + opts = connect.GetCustomAuthParams() + } if provider.UsePKCE { ascii, pkce, err := genPKCECode() diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index b51439494433..03e83f81b8a5 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -9,15 +9,15 @@ import ( "path/filepath" "testing" - "github.com/grafana/grafana/pkg/infra/db" - "github.com/grafana/grafana/pkg/services/secrets/fakes" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/licensing" + "github.com/grafana/grafana/pkg/services/secrets/fakes" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) @@ -39,6 +39,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux { SocialService: social.ProvideService(cfg), HooksService: hooks.ProvideService(), SecretsService: fakes.NewFakeSecretsService(), + Features: featuremgmt.WithFeatures(), } m := web.New() diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/azuread_oauth.go index 942530b8410c..e22a3f3c5e9a 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/azuread_oauth.go @@ -50,6 +50,10 @@ func (s *SocialAzureAD) Type() int { return int(models.AZUREAD) } +func (s *SocialAzureAD) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/generic_oauth.go index dd7b8376a693..f2e51b995533 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/generic_oauth.go @@ -36,6 +36,10 @@ func (s *SocialGenericOAuth) Type() int { return int(models.GENERIC) } +func (s *SocialGenericOAuth) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} +} + func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true diff --git a/pkg/login/social/github_oauth.go b/pkg/login/social/github_oauth.go index 1636a5918ffa..9e83ecd7104f 100644 --- a/pkg/login/social/github_oauth.go +++ b/pkg/login/social/github_oauth.go @@ -37,6 +37,10 @@ func (s *SocialGithub) Type() int { return int(models.GITHUB) } +func (s *SocialGithub) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialGithub) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true diff --git a/pkg/login/social/gitlab_oauth.go b/pkg/login/social/gitlab_oauth.go index abb12ed46a05..b5d6010d1bab 100644 --- a/pkg/login/social/gitlab_oauth.go +++ b/pkg/login/social/gitlab_oauth.go @@ -21,6 +21,10 @@ func (s *SocialGitlab) Type() int { return int(models.GITLAB) } +func (s *SocialGitlab) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialGitlab) IsGroupMember(groups []string) bool { if len(s.allowedGroups) == 0 { return true diff --git a/pkg/login/social/google_oauth.go b/pkg/login/social/google_oauth.go index e15834a45fbe..22a89723936f 100644 --- a/pkg/login/social/google_oauth.go +++ b/pkg/login/social/google_oauth.go @@ -20,6 +20,10 @@ func (s *SocialGoogle) Type() int { return int(models.GOOGLE) } +func (s *SocialGoogle) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce} +} + func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id string `json:"id"` diff --git a/pkg/login/social/grafana_com_oauth.go b/pkg/login/social/grafana_com_oauth.go index 95f08f29ed71..f777fa8db483 100644 --- a/pkg/login/social/grafana_com_oauth.go +++ b/pkg/login/social/grafana_com_oauth.go @@ -25,6 +25,10 @@ func (s *SocialGrafanaCom) Type() int { return int(models.GRAFANA_COM) } +func (s *SocialGrafanaCom) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { return true } diff --git a/pkg/login/social/okta_oauth.go b/pkg/login/social/okta_oauth.go index a78635189156..1bff89869d41 100644 --- a/pkg/login/social/okta_oauth.go +++ b/pkg/login/social/okta_oauth.go @@ -48,6 +48,10 @@ func (s *SocialOkta) Type() int { return int(models.OKTA) } +func (s *SocialOkta) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 788593f7fb40..576aca836b38 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -244,6 +244,7 @@ type SocialConnector interface { UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) IsEmailAllowed(email string) bool IsSignupAllowed() bool + GetCustomAuthParams() []oauth2.AuthCodeOption AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 179073faa525..3d608c569c31 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -443,20 +443,14 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org return false } - getTime := h.GetTime - if getTime == nil { - getTime = time.Now - } - if h.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { // Check whether the logged in User has a token (whether the User used an OAuth provider to login) oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult) if exists { - // Skip where the OAuthExpiry is default/zero/unset if h.hasAccessTokenExpired(oauthToken) { reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry)) - // If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens + // If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and invalidate the OAuth tokens if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil { if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) { reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err) diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index ae761f0580f3..aa1bf3d1df67 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -290,6 +290,10 @@ func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeO panic("not implemented") } +func (m *MockSocialConnector) GetCustomAuthParams() []oauth2.AuthCodeOption { + panic("not implemented") +} + func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { panic("not implemented") }