Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth: support AWS ALB JWT #45191

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions conf/defaults.ini
Expand Up @@ -564,6 +564,7 @@ jwk_set_file =
cache_ttl = 60m
expected_claims = {}
key_file =
key_url =
auto_sign_up = false

#################################### Auth LDAP ###########################
Expand Down
48 changes: 48 additions & 0 deletions pkg/services/auth/jwt/auth_test.go
Expand Up @@ -2,6 +2,7 @@ package jwt

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
Expand Down Expand Up @@ -177,6 +178,29 @@ func TestCachingJWKHTTPResponse(t *testing.T) {
})
}

func TestVerifyUsingKeyURL(t *testing.T) {
t.Run("should refuse to start with non-https URL", func(t *testing.T) {
var err error

_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
cfg.JWTAuthKeyURL = "https://example.com/{{.kid}}"
})
require.NoError(t, err)

_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
cfg.JWTAuthKeyURL = "http://example.com/{{.kid}}"
})
require.Error(t, err)
})

keyURLScenario(t, "verifies a token signed with a key from url", func(t *testing.T, sc scenarioContext) {
token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject})
verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
require.NoError(t, err)
assert.Equal(t, verifiedClaims["sub"], subject)
})
}

func TestSignatureWithNoneAlgorithm(t *testing.T) {
scenario(t, "rejects a token signed with \"none\" algorithm", func(t *testing.T, sc scenarioContext) {
token := signNone(t, jwt.Claims{Subject: "foo"})
Expand Down Expand Up @@ -345,6 +369,30 @@ func jwkCachingScenario(t *testing.T, desc string, fn cachingScenarioFunc, cbs .
})
}

func keyURLScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...configureFunc) {
t.Helper()
t.Run(desc, func(t *testing.T) {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rsaPublicKey := rsaKeys[0].Public()
derRsaPublicKey := x509.MarshalPKCS1PublicKey(rsaPublicKey.(*rsa.PublicKey))
if err := pem.Encode(w, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: derRsaPublicKey}); err != nil {
panic(err)
}
}))
t.Cleanup(ts.Close)

configure := func(t *testing.T, cfg *setting.Cfg) {
cfg.JWTAuthKeyURL = ts.URL
}
runner := scenarioRunner(func(t *testing.T, sc scenarioContext) {
keySet := sc.authJWTSvc.keySet.(*keySetHTTPKey)
keySet.client = ts.Client()
fn(t, sc)
}, append([]configureFunc{configure}, cbs...)...)
runner(t)
})
}

func TestBase64Paddings(t *testing.T) {
key := rsaKeys[0]

Expand Down
130 changes: 130 additions & 0 deletions pkg/services/auth/jwt/key_sets.go
Expand Up @@ -8,11 +8,13 @@ import (
"encoding/pem"
"errors"
"fmt"
"html/template"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/grafana/grafana/pkg/infra/log"
Expand Down Expand Up @@ -42,11 +44,23 @@ type keySetHTTP struct {
cacheExpiration time.Duration
}

type keySetHTTPKey struct {
url string
log log.Logger
client *http.Client
cache *remotecache.RemoteCache
cacheKey string
cacheExpiration time.Duration
}

func (s *AuthService) checkKeySetConfiguration() error {
var count int
if s.Cfg.JWTAuthKeyFile != "" {
count++
}
if s.Cfg.JWTAuthKeyURL != "" {
count++
}
if s.Cfg.JWTAuthJWKSetFile != "" {
count++
}
Expand Down Expand Up @@ -123,6 +137,22 @@ func (s *AuthService) initKeySet() error {
Keys: []jose.JSONWebKey{{Key: key}},
},
}
} else if urlStr := s.Cfg.JWTAuthKeyURL; urlStr != "" {
urlParsed, err := url.Parse(urlStr)
if err != nil {
return err
}
if urlParsed.Scheme != "https" {
return ErrJWTSetURLMustHaveHTTPSScheme
}
s.keySet = &keySetHTTPKey{
url: urlStr,
log: s.log,
client: &http.Client{},
cacheKey: fmt.Sprintf("auth-jwt:key-%s", urlStr),
cacheExpiration: s.Cfg.JWTAuthCacheTTL,
cache: s.RemoteCache,
}
} else if keyFilePath := s.Cfg.JWTAuthJWKSetFile; keyFilePath != "" {
// nolint:gosec
// We can ignore the gosec G304 warning on this one because `fileName` comes from grafana configuration file
Expand Down Expand Up @@ -212,3 +242,103 @@ func (ks keySetHTTP) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, er
}
return jwks.Key(ctx, kid)
}

func (ks *keySetHTTPKey) getKey(ctx context.Context, kid string) (jose.JSONWebKey, error) {
var key jose.JSONWebKey

url := ks.url
cacheKey := ks.cacheKey
if strings.Index(url, "{{.kid}}") >= 0 {
tmpl, err := template.New("").Parse(url)
if err != nil {
return key, err
}
m := map[string]interface{}{
"kid": kid,
}
w := new(strings.Builder)
tmpl.Execute(w, m)

url = w.String()
cacheKey += "-" + kid
}

if ks.cacheExpiration > 0 {
if val, err := ks.cache.Get(ctx, cacheKey); err == nil {
return ks.decode(ctx, kid, val.([]byte))
}
}

ks.log.Debug("Getting key from endpoint", "url", url)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return key, err
}

resp, err := ks.client.Do(req)
if err != nil {
return key, err
}
defer func() {
if err := resp.Body.Close(); err != nil {
ks.log.Warn("Failed to close response body", "err", err)
}
}()

data, err := io.ReadAll(resp.Body)
if err != nil {
return key, err
}

key, err = ks.decode(ctx, kid, data)
if err != nil {
return key, err
}

if ks.cacheExpiration > 0 {
err = ks.cache.Set(ctx, cacheKey, data, ks.cacheExpiration)
}
return key, err
}

func (ks keySetHTTPKey) decode(ctx context.Context, kid string, data []byte) (jose.JSONWebKey, error) {
var key jose.JSONWebKey

block, _ := pem.Decode(data)
if block == nil {
return key, ErrFailedToParsePemFile
}

var pubkey interface{}
var err error
switch block.Type {
case "PUBLIC KEY":
if pubkey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
return key, err
}
case "RSA PUBLIC KEY":
if pubkey, err = x509.ParsePKCS1PublicKey(block.Bytes); err != nil {
return key, err
}
default:
return key, fmt.Errorf("unknown pem block type %q", block.Type)
}

key = jose.JSONWebKey{
Key: pubkey,
}

return key, nil
}

func (ks keySetHTTPKey) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error) {
var keys []jose.JSONWebKey

key, err := ks.getKey(ctx, kid)
if err != nil {
return keys, err
}
keys = append(keys, key)
return keys, nil
}
2 changes: 2 additions & 0 deletions pkg/setting/setting.go
Expand Up @@ -318,6 +318,7 @@ type Cfg struct {
JWTAuthJWKSetURL string
JWTAuthCacheTTL time.Duration
JWTAuthKeyFile string
JWTAuthKeyURL string
JWTAuthJWKSetFile string
JWTAuthAutoSignUp bool

Expand Down Expand Up @@ -1275,6 +1276,7 @@ func readAuthSettings(iniFile *ini.File, cfg *Cfg) (err error) {
cfg.JWTAuthJWKSetURL = valueAsString(authJWT, "jwk_set_url", "")
cfg.JWTAuthCacheTTL = authJWT.Key("cache_ttl").MustDuration(time.Minute * 60)
cfg.JWTAuthKeyFile = valueAsString(authJWT, "key_file", "")
cfg.JWTAuthKeyURL = valueAsString(authJWT, "key_url", "")
cfg.JWTAuthJWKSetFile = valueAsString(authJWT, "jwk_set_file", "")
cfg.JWTAuthAutoSignUp = authJWT.Key("auto_sign_up").MustBool(false)

Expand Down