From f916aa23db28d582c08466dd406b6b7b606c6588 Mon Sep 17 00:00:00 2001 From: Mihaly Gyongyosi Date: Tue, 15 Nov 2022 18:28:53 +0100 Subject: [PATCH 1/5] Add ApprovalForce to AuthCodeOptions * Extract access token validity check to a function --- pkg/api/login_oauth.go | 4 +++- pkg/services/contexthandler/contexthandler.go | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 88a465d7fdb9..e137f23678e6 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -98,7 +98,9 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { code := ctx.Query("code") if code == "" { // FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go - opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} + // ApprovalForce is required to get the refresh token every time the user logs in with Google OAuth (without this the + // refresh token is only provided when the user first gives consent) + opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce} if provider.UsePKCE { ascii, pkce, err := genPKCECode() diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 0179a8bccf00..179073faa525 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -453,7 +453,7 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult) if exists { // Skip where the OAuthExpiry is default/zero/unset - if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) { + if h.hasAccessTokenExpired(oauthToken) { reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry)) // If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens @@ -726,3 +726,16 @@ func AuthHTTPHeaderListFromContext(c context.Context) *AuthHTTPHeaderList { } return nil } + +func (h *ContextHandler) hasAccessTokenExpired(token *models.UserAuth) bool { + if token.OAuthExpiry.IsZero() { + return false + } + + getTime := h.GetTime + if getTime == nil { + getTime = time.Now + } + + return token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) +} From 70b88a7e250eb264e00beb859421c9fe8c168f9c Mon Sep 17 00:00:00 2001 From: Mihaly Gyongyosi Date: Wed, 16 Nov 2022 18:37:02 +0100 Subject: [PATCH 2/5] Refactor --- pkg/api/login_oauth.go | 12 ++++++++---- pkg/api/login_oauth_test.go | 7 ++++--- pkg/login/social/azuread_oauth.go | 4 ++++ pkg/login/social/generic_oauth.go | 4 ++++ pkg/login/social/github_oauth.go | 4 ++++ pkg/login/social/gitlab_oauth.go | 4 ++++ pkg/login/social/google_oauth.go | 4 ++++ pkg/login/social/grafana_com_oauth.go | 4 ++++ pkg/login/social/okta_oauth.go | 4 ++++ pkg/login/social/social.go | 1 + pkg/services/contexthandler/contexthandler.go | 8 +------- pkg/services/oauthtoken/oauth_token_test.go | 4 ++++ 12 files changed, 46 insertions(+), 14 deletions(-) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index e137f23678e6..a32dbd4ea347 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -19,6 +19,7 @@ import ( "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -97,10 +98,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { code := ctx.Query("code") if code == "" { - // FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go - // ApprovalForce is required to get the refresh token every time the user logs in with Google OAuth (without this the - // refresh token is only provided when the user first gives consent) - opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce} + opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline} + + if hs.Features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { + // Change the defaults to AccessTypeOffline if accessTokenExpirationCheck is enabled + // and get the custom parameters from the specific OAuth connector + opts = connect.GetCustomAuthParams() + } if provider.UsePKCE { ascii, pkce, err := genPKCECode() diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index b51439494433..03e83f81b8a5 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -9,15 +9,15 @@ import ( "path/filepath" "testing" - "github.com/grafana/grafana/pkg/infra/db" - "github.com/grafana/grafana/pkg/services/secrets/fakes" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/licensing" + "github.com/grafana/grafana/pkg/services/secrets/fakes" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) @@ -39,6 +39,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux { SocialService: social.ProvideService(cfg), HooksService: hooks.ProvideService(), SecretsService: fakes.NewFakeSecretsService(), + Features: featuremgmt.WithFeatures(), } m := web.New() diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/azuread_oauth.go index 942530b8410c..e22a3f3c5e9a 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/azuread_oauth.go @@ -50,6 +50,10 @@ func (s *SocialAzureAD) Type() int { return int(models.AZUREAD) } +func (s *SocialAzureAD) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/generic_oauth.go index dd7b8376a693..f2e51b995533 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/generic_oauth.go @@ -36,6 +36,10 @@ func (s *SocialGenericOAuth) Type() int { return int(models.GENERIC) } +func (s *SocialGenericOAuth) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} +} + func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true diff --git a/pkg/login/social/github_oauth.go b/pkg/login/social/github_oauth.go index 1636a5918ffa..9e83ecd7104f 100644 --- a/pkg/login/social/github_oauth.go +++ b/pkg/login/social/github_oauth.go @@ -37,6 +37,10 @@ func (s *SocialGithub) Type() int { return int(models.GITHUB) } +func (s *SocialGithub) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialGithub) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true diff --git a/pkg/login/social/gitlab_oauth.go b/pkg/login/social/gitlab_oauth.go index abb12ed46a05..b5d6010d1bab 100644 --- a/pkg/login/social/gitlab_oauth.go +++ b/pkg/login/social/gitlab_oauth.go @@ -21,6 +21,10 @@ func (s *SocialGitlab) Type() int { return int(models.GITLAB) } +func (s *SocialGitlab) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialGitlab) IsGroupMember(groups []string) bool { if len(s.allowedGroups) == 0 { return true diff --git a/pkg/login/social/google_oauth.go b/pkg/login/social/google_oauth.go index e15834a45fbe..22a89723936f 100644 --- a/pkg/login/social/google_oauth.go +++ b/pkg/login/social/google_oauth.go @@ -20,6 +20,10 @@ func (s *SocialGoogle) Type() int { return int(models.GOOGLE) } +func (s *SocialGoogle) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce} +} + func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id string `json:"id"` diff --git a/pkg/login/social/grafana_com_oauth.go b/pkg/login/social/grafana_com_oauth.go index 95f08f29ed71..f777fa8db483 100644 --- a/pkg/login/social/grafana_com_oauth.go +++ b/pkg/login/social/grafana_com_oauth.go @@ -25,6 +25,10 @@ func (s *SocialGrafanaCom) Type() int { return int(models.GRAFANA_COM) } +func (s *SocialGrafanaCom) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { return true } diff --git a/pkg/login/social/okta_oauth.go b/pkg/login/social/okta_oauth.go index a78635189156..1bff89869d41 100644 --- a/pkg/login/social/okta_oauth.go +++ b/pkg/login/social/okta_oauth.go @@ -48,6 +48,10 @@ func (s *SocialOkta) Type() int { return int(models.OKTA) } +func (s *SocialOkta) GetCustomAuthParams() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{} +} + func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 788593f7fb40..576aca836b38 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -244,6 +244,7 @@ type SocialConnector interface { UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) IsEmailAllowed(email string) bool IsSignupAllowed() bool + GetCustomAuthParams() []oauth2.AuthCodeOption AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 179073faa525..3d608c569c31 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -443,20 +443,14 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org return false } - getTime := h.GetTime - if getTime == nil { - getTime = time.Now - } - if h.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { // Check whether the logged in User has a token (whether the User used an OAuth provider to login) oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult) if exists { - // Skip where the OAuthExpiry is default/zero/unset if h.hasAccessTokenExpired(oauthToken) { reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry)) - // If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens + // If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and invalidate the OAuth tokens if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil { if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) { reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err) diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index ae761f0580f3..aa1bf3d1df67 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -290,6 +290,10 @@ func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeO panic("not implemented") } +func (m *MockSocialConnector) GetCustomAuthParams() []oauth2.AuthCodeOption { + panic("not implemented") +} + func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { panic("not implemented") } From 4ce845bee594ac8fb9baa27617832cc6618033d8 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Thu, 17 Nov 2022 14:49:08 +0100 Subject: [PATCH 3/5] Oauth: set options internally instead of exposing new function --- pkg/api/login_oauth.go | 10 +--------- pkg/login/social/generic_oauth.go | 12 ++++++++---- pkg/login/social/github_oauth.go | 4 ---- pkg/login/social/gitlab_oauth.go | 4 ---- pkg/login/social/google_oauth.go | 12 ++++++++---- pkg/login/social/grafana_com_oauth.go | 4 ---- pkg/login/social/okta_oauth.go | 4 ---- pkg/login/social/social.go | 22 ++++++++++++---------- 8 files changed, 29 insertions(+), 43 deletions(-) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index a32dbd4ea347..603f158611a9 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -19,7 +19,6 @@ import ( "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -98,14 +97,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { code := ctx.Query("code") if code == "" { - opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline} - - if hs.Features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { - // Change the defaults to AccessTypeOffline if accessTokenExpirationCheck is enabled - // and get the custom parameters from the specific OAuth connector - opts = connect.GetCustomAuthParams() - } - + var opts []oauth2.AuthCodeOption if provider.UsePKCE { ascii, pkce, err := genPKCECode() if err != nil { diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/generic_oauth.go index f2e51b995533..e4bcae5b26ce 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/generic_oauth.go @@ -14,6 +14,7 @@ import ( "strconv" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/featuremgmt" "golang.org/x/oauth2" ) @@ -36,10 +37,6 @@ func (s *SocialGenericOAuth) Type() int { return int(models.GENERIC) } -func (s *SocialGenericOAuth) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} -} - func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true @@ -513,3 +510,10 @@ func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, return logins, true } + +func (s *SocialGenericOAuth) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { + opts = append(opts, oauth2.AccessTypeOffline) + } + return s.SocialBase.AuthCodeURL(state, opts...) +} diff --git a/pkg/login/social/github_oauth.go b/pkg/login/social/github_oauth.go index 9e83ecd7104f..1636a5918ffa 100644 --- a/pkg/login/social/github_oauth.go +++ b/pkg/login/social/github_oauth.go @@ -37,10 +37,6 @@ func (s *SocialGithub) Type() int { return int(models.GITHUB) } -func (s *SocialGithub) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{} -} - func (s *SocialGithub) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true diff --git a/pkg/login/social/gitlab_oauth.go b/pkg/login/social/gitlab_oauth.go index b5d6010d1bab..abb12ed46a05 100644 --- a/pkg/login/social/gitlab_oauth.go +++ b/pkg/login/social/gitlab_oauth.go @@ -21,10 +21,6 @@ func (s *SocialGitlab) Type() int { return int(models.GITLAB) } -func (s *SocialGitlab) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{} -} - func (s *SocialGitlab) IsGroupMember(groups []string) bool { if len(s.allowedGroups) == 0 { return true diff --git a/pkg/login/social/google_oauth.go b/pkg/login/social/google_oauth.go index 22a89723936f..a9eda4a006ef 100644 --- a/pkg/login/social/google_oauth.go +++ b/pkg/login/social/google_oauth.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/featuremgmt" "golang.org/x/oauth2" ) @@ -20,10 +21,6 @@ func (s *SocialGoogle) Type() int { return int(models.GOOGLE) } -func (s *SocialGoogle) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce} -} - func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id string `json:"id"` @@ -48,3 +45,10 @@ func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*Basi Login: data.Email, }, nil } + +func (s *SocialGoogle) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) { + opts = append(opts, oauth2.AccessTypeOffline, oauth2.ApprovalForce) + } + return s.SocialBase.AuthCodeURL(state, opts...) +} diff --git a/pkg/login/social/grafana_com_oauth.go b/pkg/login/social/grafana_com_oauth.go index f777fa8db483..95f08f29ed71 100644 --- a/pkg/login/social/grafana_com_oauth.go +++ b/pkg/login/social/grafana_com_oauth.go @@ -25,10 +25,6 @@ func (s *SocialGrafanaCom) Type() int { return int(models.GRAFANA_COM) } -func (s *SocialGrafanaCom) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{} -} - func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool { return true } diff --git a/pkg/login/social/okta_oauth.go b/pkg/login/social/okta_oauth.go index 1bff89869d41..a78635189156 100644 --- a/pkg/login/social/okta_oauth.go +++ b/pkg/login/social/okta_oauth.go @@ -48,10 +48,6 @@ func (s *SocialOkta) Type() int { return int(models.OKTA) } -func (s *SocialOkta) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{} -} - func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 576aca836b38..0ab3d9c2eefc 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -16,6 +16,7 @@ import ( "golang.org/x/text/language" "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" @@ -58,7 +59,7 @@ type OAuthInfo struct { UsePKCE bool } -func ProvideService(cfg *setting.Cfg) *SocialService { +func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *SocialService { ss := SocialService{ cfg: cfg, oAuthProvider: make(map[string]*OAuthInfo), @@ -139,7 +140,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { // GitHub. if name == "github" { ss.socialMap["github"] = &SocialGithub{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), apiUrl: info.ApiUrl, teamIds: sec.Key("team_ids").Ints(","), allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), @@ -149,7 +150,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { // GitLab. if name == "gitlab" { ss.socialMap["gitlab"] = &SocialGitlab{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), apiUrl: info.ApiUrl, allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), } @@ -158,7 +159,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { // Google. if name == "google" { ss.socialMap["google"] = &SocialGoogle{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), hostedDomain: info.HostedDomain, apiUrl: info.ApiUrl, } @@ -167,7 +168,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { // AzureAD. if name == "azuread" { ss.socialMap["azuread"] = &SocialAzureAD{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false), } @@ -176,7 +177,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { // Okta if name == "okta" { ss.socialMap["okta"] = &SocialOkta{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), apiUrl: info.ApiUrl, allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), } @@ -185,7 +186,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { // Generic - Uses the same scheme as GitHub. if name == "generic_oauth" { ss.socialMap["generic_oauth"] = &SocialGenericOAuth{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), apiUrl: info.ApiUrl, teamsUrl: info.TeamsUrl, emailAttributeName: info.EmailAttributeName, @@ -214,8 +215,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService { } ss.socialMap[grafanaCom] = &SocialGrafanaCom{ - SocialBase: newSocialBase(name, &config, info, - cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), url: cfg.GrafanaComURL, allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), } @@ -244,7 +244,6 @@ type SocialConnector interface { UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) IsEmailAllowed(email string) bool IsSignupAllowed() bool - GetCustomAuthParams() []oauth2.AuthCodeOption AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) @@ -263,6 +262,7 @@ type SocialBase struct { roleAttributeStrict bool autoAssignOrgRole string skipOrgRoleSync bool + features featuremgmt.FeatureManager } type Error struct { @@ -297,6 +297,7 @@ func newSocialBase(name string, info *OAuthInfo, autoAssignOrgRole string, skipOrgRoleSync bool, + features featuremgmt.FeatureManager, ) *SocialBase { logger := log.New("oauth." + name) @@ -310,6 +311,7 @@ func newSocialBase(name string, roleAttributePath: info.RoleAttributePath, roleAttributeStrict: info.RoleAttributeStrict, skipOrgRoleSync: skipOrgRoleSync, + features: features, } } From 0d28c166776ac68b353017ddd8427386400756e9 Mon Sep 17 00:00:00 2001 From: Mihaly Gyongyosi Date: Thu, 17 Nov 2022 16:10:16 +0100 Subject: [PATCH 4/5] Align tests --- pkg/api/frontendsettings_test.go | 2 +- pkg/api/login_oauth_test.go | 2 +- pkg/login/social/azuread_oauth.go | 4 ---- pkg/login/social/azuread_oauth_test.go | 28 ++++++++++++++------------ pkg/login/social/github_oauth_test.go | 4 +++- pkg/login/social/social.go | 16 +++++++-------- pkg/server/server.go | 2 +- 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/pkg/api/frontendsettings_test.go b/pkg/api/frontendsettings_test.go index ef37a0765238..dc1d1446a43c 100644 --- a/pkg/api/frontendsettings_test.go +++ b/pkg/api/frontendsettings_test.go @@ -58,7 +58,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg, features *featuremgmt. grafanaUpdateChecker: &updatechecker.GrafanaService{}, AccessControl: accesscontrolmock.New().WithDisabled(), PluginSettings: pluginSettings.ProvideService(sqlStore, secretsService), - SocialService: social.ProvideService(cfg), + SocialService: social.ProvideService(cfg, features), } m := web.New() diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index 03e83f81b8a5..3d77f2b2aa0b 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -36,7 +36,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux { Cfg: cfg, License: &licensing.OSSLicensingService{Cfg: cfg}, SQLStore: sqlStore, - SocialService: social.ProvideService(cfg), + SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()), HooksService: hooks.ProvideService(), SecretsService: fakes.NewFakeSecretsService(), Features: featuremgmt.WithFeatures(), diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/azuread_oauth.go index e22a3f3c5e9a..942530b8410c 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/azuread_oauth.go @@ -50,10 +50,6 @@ func (s *SocialAzureAD) Type() int { return int(models.AZUREAD) } -func (s *SocialAzureAD) GetCustomAuthParams() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{} -} - func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { diff --git a/pkg/login/social/azuread_oauth_test.go b/pkg/login/social/azuread_oauth_test.go index ea632e9457b1..4a03dc12002d 100644 --- a/pkg/login/social/azuread_oauth_test.go +++ b/pkg/login/social/azuread_oauth_test.go @@ -13,6 +13,8 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" + + "github.com/grafana/grafana/pkg/services/featuremgmt" ) func trueBoolPtr() *bool { @@ -54,7 +56,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()), }, want: &BasicUserInfo{ Id: "1234", @@ -93,7 +95,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()), }, want: &BasicUserInfo{ Id: "1234", @@ -143,7 +145,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Only other roles", fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()), }, claims: &azureClaims{ Email: "me@example.com", @@ -171,7 +173,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { ID: "1234", }, fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()), }, want: &BasicUserInfo{ Id: "1234", @@ -220,7 +222,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { }, { name: "Grafana Admin but setting is disabled", - fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false)}, + fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())}, claims: &azureClaims{ Email: "me@example.com", PreferredUsername: "", @@ -242,7 +244,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { name: "Editor roles in claim and GrafanaAdminAssignment enabled", fields: fields{ SocialBase: newSocialBase("azuread", - &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)}, + &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())}, claims: &azureClaims{ Email: "me@example.com", PreferredUsername: "", @@ -263,7 +265,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Grafana Admin and Editor roles in claim", fields: fields{SocialBase: newSocialBase("azuread", - &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)}, + &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())}, claims: &azureClaims{ Email: "me@example.com", PreferredUsername: "", @@ -302,7 +304,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { fields: fields{ allowedGroups: []string{"foo", "bar"}, SocialBase: newSocialBase("azuread", - &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false), + &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()), }, claims: &azureClaims{ Email: "me@example.com", @@ -324,7 +326,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch groups when ClaimsNames and ClaimsSources is set", fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()), }, claims: &azureClaims{ ID: "1", @@ -349,7 +351,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch groups when forceUseGraphAPI is set", fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()), forceUseGraphAPI: true, }, claims: &azureClaims{ @@ -376,7 +378,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch empty role when strict attribute role is true and no match", fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()), }, claims: &azureClaims{ Email: "me@example.com", @@ -392,7 +394,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { { name: "Fetch empty role when strict attribute role is true and no role claims returned", fields: fields{ - SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false), + SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()), }, claims: &azureClaims{ Email: "me@example.com", @@ -416,7 +418,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { } if tt.fields.SocialBase == nil { - s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false) + s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()) } key := []byte("secret") diff --git a/pkg/login/social/github_oauth_test.go b/pkg/login/social/github_oauth_test.go index f610bd5843c7..cdb400b15ea9 100644 --- a/pkg/login/social/github_oauth_test.go +++ b/pkg/login/social/github_oauth_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + + "github.com/grafana/grafana/pkg/services/featuremgmt" ) const testGHUserTeamsJSON = `[ @@ -202,7 +204,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { s := &SocialGithub{ SocialBase: newSocialBase("github", &oauth2.Config{}, - &OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false), + &OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false, *featuremgmt.WithFeatures()), allowedOrganizations: []string{}, apiUrl: server.URL + "/user", teamIds: []int{}, diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 0ab3d9c2eefc..ceb837ac055d 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -59,7 +59,7 @@ type OAuthInfo struct { UsePKCE bool } -func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *SocialService { +func ProvideService(cfg *setting.Cfg, features *featuremgmt.FeatureManager) *SocialService { ss := SocialService{ cfg: cfg, oAuthProvider: make(map[string]*OAuthInfo), @@ -140,7 +140,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci // GitHub. if name == "github" { ss.socialMap["github"] = &SocialGithub{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, teamIds: sec.Key("team_ids").Ints(","), allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), @@ -150,7 +150,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci // GitLab. if name == "gitlab" { ss.socialMap["gitlab"] = &SocialGitlab{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), } @@ -159,7 +159,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci // Google. if name == "google" { ss.socialMap["google"] = &SocialGoogle{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), hostedDomain: info.HostedDomain, apiUrl: info.ApiUrl, } @@ -168,7 +168,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci // AzureAD. if name == "azuread" { ss.socialMap["azuread"] = &SocialAzureAD{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false), } @@ -177,7 +177,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci // Okta if name == "okta" { ss.socialMap["okta"] = &SocialOkta{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, allowedGroups: util.SplitString(sec.Key("allowed_groups").String()), } @@ -186,7 +186,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci // Generic - Uses the same scheme as GitHub. if name == "generic_oauth" { ss.socialMap["generic_oauth"] = &SocialGenericOAuth{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), apiUrl: info.ApiUrl, teamsUrl: info.TeamsUrl, emailAttributeName: info.EmailAttributeName, @@ -215,7 +215,7 @@ func ProvideService(cfg *setting.Cfg, features featuremgmt.FeatureManager) *Soci } ss.socialMap[grafanaCom] = &SocialGrafanaCom{ - SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, features), + SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features), url: cfg.GrafanaComURL, allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()), } diff --git a/pkg/server/server.go b/pkg/server/server.go index aec724df9794..4b80055b009e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -127,7 +127,7 @@ func (s *Server) init() error { } login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.loginAttemptService, s.userService) - social.ProvideService(s.cfg) + social.ProvideService(s.cfg, s.HTTPServer.Features) if err := s.roleRegistry.RegisterFixedRoles(s.context); err != nil { return err From 92540f7cc84327821a4fc7cb2aeb8633d0f8968e Mon Sep 17 00:00:00 2001 From: Mihaly Gyongyosi Date: Thu, 17 Nov 2022 17:03:33 +0100 Subject: [PATCH 5/5] Remove unused function --- pkg/api/login_oauth_test.go | 1 - pkg/services/oauthtoken/oauth_token_test.go | 4 ---- 2 files changed, 5 deletions(-) diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index 3d77f2b2aa0b..f39c5cda96c5 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -39,7 +39,6 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux { SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()), HooksService: hooks.ProvideService(), SecretsService: fakes.NewFakeSecretsService(), - Features: featuremgmt.WithFeatures(), } m := web.New() diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index aa1bf3d1df67..ae761f0580f3 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -290,10 +290,6 @@ func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeO panic("not implemented") } -func (m *MockSocialConnector) GetCustomAuthParams() []oauth2.AuthCodeOption { - panic("not implemented") -} - func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { panic("not implemented") }