Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyongyosi committed Nov 17, 2022
1 parent f916aa2 commit 70b88a7
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 14 deletions.
12 changes: 8 additions & 4 deletions pkg/api/login_oauth.go
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
Expand Down Expand Up @@ -97,10 +98,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {

code := ctx.Query("code")
if code == "" {
// FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go
// ApprovalForce is required to get the refresh token every time the user logs in with Google OAuth (without this the
// refresh token is only provided when the user first gives consent)
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce}
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOnline}

if hs.Features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
// Change the defaults to AccessTypeOffline if accessTokenExpirationCheck is enabled
// and get the custom parameters from the specific OAuth connector
opts = connect.GetCustomAuthParams()
}

if provider.UsePKCE {
ascii, pkce, err := genPKCECode()
Expand Down
7 changes: 4 additions & 3 deletions pkg/api/login_oauth_test.go
Expand Up @@ -9,15 +9,15 @@ import (
"path/filepath"
"testing"

"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/secrets/fakes"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
Expand All @@ -39,6 +39,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
SocialService: social.ProvideService(cfg),
HooksService: hooks.ProvideService(),
SecretsService: fakes.NewFakeSecretsService(),
Features: featuremgmt.WithFeatures(),
}

m := web.New()
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/azuread_oauth.go
Expand Up @@ -50,6 +50,10 @@ func (s *SocialAzureAD) Type() int {
return int(models.AZUREAD)
}

func (s *SocialAzureAD) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token")
if idToken == nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/generic_oauth.go
Expand Up @@ -36,6 +36,10 @@ func (s *SocialGenericOAuth) Type() int {
return int(models.GENERIC)
}

func (s *SocialGenericOAuth) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}
}

func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
if len(s.teamIds) == 0 {
return true
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/github_oauth.go
Expand Up @@ -37,6 +37,10 @@ func (s *SocialGithub) Type() int {
return int(models.GITHUB)
}

func (s *SocialGithub) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
if len(s.teamIds) == 0 {
return true
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/gitlab_oauth.go
Expand Up @@ -21,6 +21,10 @@ func (s *SocialGitlab) Type() int {
return int(models.GITLAB)
}

func (s *SocialGitlab) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (s *SocialGitlab) IsGroupMember(groups []string) bool {
if len(s.allowedGroups) == 0 {
return true
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/google_oauth.go
Expand Up @@ -20,6 +20,10 @@ func (s *SocialGoogle) Type() int {
return int(models.GOOGLE)
}

func (s *SocialGoogle) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.AccessTypeOffline, oauth2.ApprovalForce}
}

func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id string `json:"id"`
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/grafana_com_oauth.go
Expand Up @@ -25,6 +25,10 @@ func (s *SocialGrafanaCom) Type() int {
return int(models.GRAFANA_COM)
}

func (s *SocialGrafanaCom) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (s *SocialGrafanaCom) IsEmailAllowed(email string) bool {
return true
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/login/social/okta_oauth.go
Expand Up @@ -48,6 +48,10 @@ func (s *SocialOkta) Type() int {
return int(models.OKTA)
}

func (s *SocialOkta) GetCustomAuthParams() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token")
if idToken == nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/login/social/social.go
Expand Up @@ -244,6 +244,7 @@ type SocialConnector interface {
UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
IsEmailAllowed(email string) bool
IsSignupAllowed() bool
GetCustomAuthParams() []oauth2.AuthCodeOption

AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error)
Expand Down
8 changes: 1 addition & 7 deletions pkg/services/contexthandler/contexthandler.go
Expand Up @@ -443,20 +443,14 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
return false
}

getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}

if h.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
if exists {
// Skip where the OAuthExpiry is default/zero/unset
if h.hasAccessTokenExpired(oauthToken) {
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))

// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and invalidate the OAuth tokens
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
Expand Down
4 changes: 4 additions & 0 deletions pkg/services/oauthtoken/oauth_token_test.go
Expand Up @@ -290,6 +290,10 @@ func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeO
panic("not implemented")
}

func (m *MockSocialConnector) GetCustomAuthParams() []oauth2.AuthCodeOption {
panic("not implemented")
}

func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
panic("not implemented")
}
Expand Down

0 comments on commit 70b88a7

Please sign in to comment.