Skip to content

Commit

Permalink
OAuth: Refactor OAuth parameters handling to support obtaining refres…
Browse files Browse the repository at this point in the history
…h tokens for Google OAuth (#58782)

* Add ApprovalForce to AuthCodeOptions

* Extract access token validity check to a function

* Refactor

* Oauth: set options internally instead of exposing new function

* Align tests

* Remove unused function

Co-authored-by: Karl Persson <kalle.persson@grafana.com>
  • Loading branch information
mgyongyosi and kalleep committed Nov 18, 2022
1 parent d46e391 commit 9c98314
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 40 deletions.
2 changes: 1 addition & 1 deletion pkg/api/frontendsettings_test.go
Expand Up @@ -58,7 +58,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg, features *featuremgmt.
grafanaUpdateChecker: &updatechecker.GrafanaService{},
AccessControl: accesscontrolmock.New().WithDisabled(),
PluginSettings: pluginSettings.ProvideService(sqlStore, secretsService),
SocialService: social.ProvideService(cfg),
SocialService: social.ProvideService(cfg, features),
}

m := web.New()
Expand Down
4 changes: 1 addition & 3 deletions pkg/api/login_oauth.go
Expand Up @@ -97,9 +97,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {

code := ctx.Query("code")
if code == "" {
// FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}

var opts []oauth2.AuthCodeOption
if provider.UsePKCE {
ascii, pkce, err := genPKCECode()
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions pkg/api/login_oauth_test.go
Expand Up @@ -9,15 +9,15 @@ import (
"path/filepath"
"testing"

"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/secrets/fakes"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
Expand All @@ -36,7 +36,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
Cfg: cfg,
License: &licensing.OSSLicensingService{Cfg: cfg},
SQLStore: sqlStore,
SocialService: social.ProvideService(cfg),
SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()),
HooksService: hooks.ProvideService(),
SecretsService: fakes.NewFakeSecretsService(),
}
Expand Down
28 changes: 15 additions & 13 deletions pkg/login/social/azuread_oauth_test.go
Expand Up @@ -13,6 +13,8 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

func trueBoolPtr() *bool {
Expand Down Expand Up @@ -54,7 +56,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Expand Down Expand Up @@ -93,7 +95,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Expand Down Expand Up @@ -143,7 +145,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Only other roles",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand Down Expand Up @@ -171,7 +173,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Expand Down Expand Up @@ -220,7 +222,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "Grafana Admin but setting is disabled",
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false)},
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Expand All @@ -242,7 +244,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
name: "Editor roles in claim and GrafanaAdminAssignment enabled",
fields: fields{
SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)},
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Expand All @@ -263,7 +265,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Grafana Admin and Editor roles in claim",
fields: fields{SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)},
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Expand Down Expand Up @@ -302,7 +304,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
fields: fields{
allowedGroups: []string{"foo", "bar"},
SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false),
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand All @@ -324,7 +326,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch groups when ClaimsNames and ClaimsSources is set",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
ID: "1",
Expand All @@ -349,7 +351,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch groups when forceUseGraphAPI is set",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
forceUseGraphAPI: true,
},
claims: &azureClaims{
Expand All @@ -376,7 +378,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch empty role when strict attribute role is true and no match",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand All @@ -392,7 +394,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch empty role when strict attribute role is true and no role claims returned",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand All @@ -416,7 +418,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
}

if tt.fields.SocialBase == nil {
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false)
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
}

key := []byte("secret")
Expand Down
9 changes: 9 additions & 0 deletions pkg/login/social/generic_oauth.go
Expand Up @@ -14,6 +14,8 @@ import (
"strconv"

"golang.org/x/oauth2"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

type SocialGenericOAuth struct {
Expand Down Expand Up @@ -504,3 +506,10 @@ func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string,

return logins, true
}

func (s *SocialGenericOAuth) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
opts = append(opts, oauth2.AccessTypeOffline)
}
return s.SocialBase.AuthCodeURL(state, opts...)
}
4 changes: 3 additions & 1 deletion pkg/login/social/github_oauth_test.go
Expand Up @@ -9,6 +9,8 @@ import (

"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

const testGHUserTeamsJSON = `[
Expand Down Expand Up @@ -202,7 +204,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {

s := &SocialGithub{
SocialBase: newSocialBase("github", &oauth2.Config{},
&OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false),
&OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false, *featuremgmt.WithFeatures()),
allowedOrganizations: []string{},
apiUrl: server.URL + "/user",
teamIds: []int{},
Expand Down
9 changes: 9 additions & 0 deletions pkg/login/social/google_oauth.go
Expand Up @@ -6,6 +6,8 @@ import (
"net/http"

"golang.org/x/oauth2"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

type SocialGoogle struct {
Expand Down Expand Up @@ -38,3 +40,10 @@ func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
Login: data.Email,
}, nil
}

func (s *SocialGoogle) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
}
return s.SocialBase.AuthCodeURL(state, opts...)
}
21 changes: 12 additions & 9 deletions pkg/login/social/social.go
Expand Up @@ -16,6 +16,7 @@ import (
"golang.org/x/text/language"

"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
Expand Down Expand Up @@ -58,7 +59,7 @@ type OAuthInfo struct {
UsePKCE bool
}

func ProvideService(cfg *setting.Cfg) *SocialService {
func ProvideService(cfg *setting.Cfg, features *featuremgmt.FeatureManager) *SocialService {
ss := SocialService{
cfg: cfg,
oAuthProvider: make(map[string]*OAuthInfo),
Expand Down Expand Up @@ -139,7 +140,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// GitHub.
if name == "github" {
ss.socialMap["github"] = &SocialGithub{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
teamIds: sec.Key("team_ids").Ints(","),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
Expand All @@ -149,7 +150,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// GitLab.
if name == "gitlab" {
ss.socialMap["gitlab"] = &SocialGitlab{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
}
Expand All @@ -158,7 +159,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// Google.
if name == "google" {
ss.socialMap["google"] = &SocialGoogle{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
hostedDomain: info.HostedDomain,
apiUrl: info.ApiUrl,
}
Expand All @@ -167,7 +168,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// AzureAD.
if name == "azuread" {
ss.socialMap["azuread"] = &SocialAzureAD{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
}
Expand All @@ -176,7 +177,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// Okta
if name == "okta" {
ss.socialMap["okta"] = &SocialOkta{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
}
Expand All @@ -185,7 +186,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// Generic - Uses the same scheme as GitHub.
if name == "generic_oauth" {
ss.socialMap["generic_oauth"] = &SocialGenericOAuth{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
teamsUrl: info.TeamsUrl,
emailAttributeName: info.EmailAttributeName,
Expand Down Expand Up @@ -214,8 +215,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
}

ss.socialMap[grafanaCom] = &SocialGrafanaCom{
SocialBase: newSocialBase(name, &config, info,
cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
url: cfg.GrafanaComURL,
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
Expand Down Expand Up @@ -261,6 +261,7 @@ type SocialBase struct {
roleAttributeStrict bool
autoAssignOrgRole string
skipOrgRoleSync bool
features featuremgmt.FeatureManager
}

type Error struct {
Expand Down Expand Up @@ -295,6 +296,7 @@ func newSocialBase(name string,
info *OAuthInfo,
autoAssignOrgRole string,
skipOrgRoleSync bool,
features featuremgmt.FeatureManager,
) *SocialBase {
logger := log.New("oauth." + name)

Expand All @@ -308,6 +310,7 @@ func newSocialBase(name string,
roleAttributePath: info.RoleAttributePath,
roleAttributeStrict: info.RoleAttributeStrict,
skipOrgRoleSync: skipOrgRoleSync,
features: features,
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server.go
Expand Up @@ -127,7 +127,7 @@ func (s *Server) init() error {
}

login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.loginAttemptService, s.userService)
social.ProvideService(s.cfg)
social.ProvideService(s.cfg, s.HTTPServer.Features)

if err := s.roleRegistry.RegisterFixedRoles(s.context); err != nil {
return err
Expand Down
23 changes: 15 additions & 8 deletions pkg/services/contexthandler/contexthandler.go
Expand Up @@ -449,20 +449,14 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
return false
}

getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}

if h.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
if exists {
// Skip where the OAuthExpiry is default/zero/unset
if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) {
if h.hasAccessTokenExpired(oauthToken) {
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))

// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and invalidate the OAuth tokens
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
Expand Down Expand Up @@ -732,3 +726,16 @@ func AuthHTTPHeaderListFromContext(c context.Context) *AuthHTTPHeaderList {
}
return nil
}

func (h *ContextHandler) hasAccessTokenExpired(token *models.UserAuth) bool {
if token.OAuthExpiry.IsZero() {
return false
}

getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}

return token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime())
}

0 comments on commit 9c98314

Please sign in to comment.