Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth: Refactor OAuth parameters handling to support obtaining refresh tokens for Google OAuth #58782

Merged
merged 6 commits into from Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/api/frontendsettings_test.go
Expand Up @@ -58,7 +58,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg, features *featuremgmt.
grafanaUpdateChecker: &updatechecker.GrafanaService{},
AccessControl: accesscontrolmock.New().WithDisabled(),
PluginSettings: pluginSettings.ProvideService(sqlStore, secretsService),
SocialService: social.ProvideService(cfg),
SocialService: social.ProvideService(cfg, features),
}

m := web.New()
Expand Down
4 changes: 1 addition & 3 deletions pkg/api/login_oauth.go
Expand Up @@ -97,9 +97,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {

code := ctx.Query("code")
if code == "" {
// FIXME: access_type is a Google OAuth2 specific thing, consider refactoring this and moving to google_oauth.go
opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}

var opts []oauth2.AuthCodeOption
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With these last changes, the extra AuthCodeOptions that were added to gitlab,github and azure will be removed regardless of the feature flag.

From our discussion I think that's fine since it's not used by these providers but it's a small increase in risk compared to the previous version that we should be aware of

if provider.UsePKCE {
ascii, pkce, err := genPKCECode()
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions pkg/api/login_oauth_test.go
Expand Up @@ -9,15 +9,15 @@ import (
"path/filepath"
"testing"

"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/secrets/fakes"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
Expand All @@ -36,7 +36,7 @@ func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
Cfg: cfg,
License: &licensing.OSSLicensingService{Cfg: cfg},
SQLStore: sqlStore,
SocialService: social.ProvideService(cfg),
SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()),
HooksService: hooks.ProvideService(),
SecretsService: fakes.NewFakeSecretsService(),
}
Expand Down
28 changes: 15 additions & 13 deletions pkg/login/social/azuread_oauth_test.go
Expand Up @@ -13,6 +13,8 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

func trueBoolPtr() *bool {
Expand Down Expand Up @@ -54,7 +56,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Expand Down Expand Up @@ -93,7 +95,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Expand Down Expand Up @@ -143,7 +145,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Only other roles",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand Down Expand Up @@ -171,7 +173,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Expand Down Expand Up @@ -220,7 +222,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "Grafana Admin but setting is disabled",
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false)},
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Expand All @@ -242,7 +244,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
name: "Editor roles in claim and GrafanaAdminAssignment enabled",
fields: fields{
SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)},
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Expand All @@ -263,7 +265,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Grafana Admin and Editor roles in claim",
fields: fields{SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false)},
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Expand Down Expand Up @@ -302,7 +304,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
fields: fields{
allowedGroups: []string{"foo", "bar"},
SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false),
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand All @@ -324,7 +326,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch groups when ClaimsNames and ClaimsSources is set",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
ID: "1",
Expand All @@ -349,7 +351,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch groups when forceUseGraphAPI is set",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
forceUseGraphAPI: true,
},
claims: &azureClaims{
Expand All @@ -376,7 +378,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch empty role when strict attribute role is true and no match",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand All @@ -392,7 +394,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch empty role when strict attribute role is true and no role claims returned",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false),
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
Expand All @@ -416,7 +418,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
}

if tt.fields.SocialBase == nil {
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false)
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
}

key := []byte("secret")
Expand Down
9 changes: 9 additions & 0 deletions pkg/login/social/generic_oauth.go
Expand Up @@ -14,6 +14,8 @@ import (
"strconv"

"golang.org/x/oauth2"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

type SocialGenericOAuth struct {
Expand Down Expand Up @@ -504,3 +506,10 @@ func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string,

return logins, true
}

func (s *SocialGenericOAuth) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
opts = append(opts, oauth2.AccessTypeOffline)
}
return s.SocialBase.AuthCodeURL(state, opts...)
}
4 changes: 3 additions & 1 deletion pkg/login/social/github_oauth_test.go
Expand Up @@ -9,6 +9,8 @@ import (

"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

const testGHUserTeamsJSON = `[
Expand Down Expand Up @@ -202,7 +204,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {

s := &SocialGithub{
SocialBase: newSocialBase("github", &oauth2.Config{},
&OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false),
&OAuthInfo{RoleAttributePath: tt.roleAttributePath}, tt.autoAssignOrgRole, false, *featuremgmt.WithFeatures()),
allowedOrganizations: []string{},
apiUrl: server.URL + "/user",
teamIds: []int{},
Expand Down
9 changes: 9 additions & 0 deletions pkg/login/social/google_oauth.go
Expand Up @@ -6,6 +6,8 @@ import (
"net/http"

"golang.org/x/oauth2"

"github.com/grafana/grafana/pkg/services/featuremgmt"
)

type SocialGoogle struct {
Expand Down Expand Up @@ -38,3 +40,10 @@ func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
Login: data.Email,
}, nil
}

func (s *SocialGoogle) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
if s.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
}
return s.SocialBase.AuthCodeURL(state, opts...)
}
21 changes: 12 additions & 9 deletions pkg/login/social/social.go
Expand Up @@ -16,6 +16,7 @@ import (
"golang.org/x/text/language"

"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
Expand Down Expand Up @@ -58,7 +59,7 @@ type OAuthInfo struct {
UsePKCE bool
}

func ProvideService(cfg *setting.Cfg) *SocialService {
func ProvideService(cfg *setting.Cfg, features *featuremgmt.FeatureManager) *SocialService {
ss := SocialService{
cfg: cfg,
oAuthProvider: make(map[string]*OAuthInfo),
Expand Down Expand Up @@ -139,7 +140,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// GitHub.
if name == "github" {
ss.socialMap["github"] = &SocialGithub{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
teamIds: sec.Key("team_ids").Ints(","),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
Expand All @@ -149,7 +150,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// GitLab.
if name == "gitlab" {
ss.socialMap["gitlab"] = &SocialGitlab{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
}
Expand All @@ -158,7 +159,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// Google.
if name == "google" {
ss.socialMap["google"] = &SocialGoogle{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
hostedDomain: info.HostedDomain,
apiUrl: info.ApiUrl,
}
Expand All @@ -167,7 +168,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// AzureAD.
if name == "azuread" {
ss.socialMap["azuread"] = &SocialAzureAD{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
}
Expand All @@ -176,7 +177,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// Okta
if name == "okta" {
ss.socialMap["okta"] = &SocialOkta{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
}
Expand All @@ -185,7 +186,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
// Generic - Uses the same scheme as GitHub.
if name == "generic_oauth" {
ss.socialMap["generic_oauth"] = &SocialGenericOAuth{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
apiUrl: info.ApiUrl,
teamsUrl: info.TeamsUrl,
emailAttributeName: info.EmailAttributeName,
Expand Down Expand Up @@ -214,8 +215,7 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
}

ss.socialMap[grafanaCom] = &SocialGrafanaCom{
SocialBase: newSocialBase(name, &config, info,
cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync),
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
url: cfg.GrafanaComURL,
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
}
Expand Down Expand Up @@ -261,6 +261,7 @@ type SocialBase struct {
roleAttributeStrict bool
autoAssignOrgRole string
skipOrgRoleSync bool
features featuremgmt.FeatureManager
}

type Error struct {
Expand Down Expand Up @@ -295,6 +296,7 @@ func newSocialBase(name string,
info *OAuthInfo,
autoAssignOrgRole string,
skipOrgRoleSync bool,
features featuremgmt.FeatureManager,
) *SocialBase {
logger := log.New("oauth." + name)

Expand All @@ -308,6 +310,7 @@ func newSocialBase(name string,
roleAttributePath: info.RoleAttributePath,
roleAttributeStrict: info.RoleAttributeStrict,
skipOrgRoleSync: skipOrgRoleSync,
features: features,
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server.go
Expand Up @@ -127,7 +127,7 @@ func (s *Server) init() error {
}

login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.loginAttemptService, s.userService)
social.ProvideService(s.cfg)
social.ProvideService(s.cfg, s.HTTPServer.Features)

if err := s.roleRegistry.RegisterFixedRoles(s.context); err != nil {
return err
Expand Down
23 changes: 15 additions & 8 deletions pkg/services/contexthandler/contexthandler.go
Expand Up @@ -449,20 +449,14 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
return false
}

getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}

if h.features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
if exists {
// Skip where the OAuthExpiry is default/zero/unset
if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) {
if h.hasAccessTokenExpired(oauthToken) {
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))

// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and invalidate the OAuth tokens
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
Expand Down Expand Up @@ -732,3 +726,16 @@ func AuthHTTPHeaderListFromContext(c context.Context) *AuthHTTPHeaderList {
}
return nil
}

func (h *ContextHandler) hasAccessTokenExpired(token *models.UserAuth) bool {
if token.OAuthExpiry.IsZero() {
return false
}

getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}

return token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime())
}