Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiple JWK Sets #78

Merged
merged 6 commits into from Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -172,6 +172,7 @@ These features can be configured by populating fields in the
parsing JWTs. For an example, see the `examples/custom` directory.
* The remote JWKS resource can be refreshed manually using the `.Refresh` method. This can bypass the rate limit, if the
option is set.
* There is support for creating one `jwt.Keyfunc` from multiple JWK Sets through the use of the `keyfunc.GetMultiple`.

## Notes
Trailing padding is required to be removed from base64url encoded keys inside a JWKS. This is because RFC 7517 defines
Expand Down
26 changes: 18 additions & 8 deletions keyfunc.go
Expand Up @@ -16,23 +16,33 @@ var (

// Keyfunc matches the signature of github.com/golang-jwt/jwt/v4's jwt.Keyfunc function.
func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) {
kid, alg, err := kidAlg(token)
if err != nil {
return nil, err
}
return j.getKey(alg, kid)
}

func (m *MultipleJWKS) Keyfunc(token *jwt.Token) (interface{}, error) {
return m.keySelector(m, token)
}

func kidAlg(token *jwt.Token) (kid, alg string, err error) {
kidInter, ok := token.Header["kid"]
if !ok {
return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKID)
return "", "", fmt.Errorf("%w: could not find kid in JWT header", ErrKID)
}
kid, ok := kidInter.(string)
kid, ok = kidInter.(string)
if !ok {
return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID)
return "", "", fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID)
}

alg, ok := token.Header["alg"].(string)
alg, ok = token.Header["alg"].(string)
if !ok {
// For test coverage purposes, this should be impossible to reach because the JWT package rejects a token
// without an alg parameter in the header before calling jwt.Keyfunc.
return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch)
return "", "", fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch)
}

return j.getKey(alg, kid)
return kid, alg, nil
}

// base64urlTrailingPadding removes trailing padding before decoding a string from base64url. Some non-RFC compliant
Expand Down
69 changes: 69 additions & 0 deletions multiple.go
@@ -0,0 +1,69 @@
package keyfunc

import (
"errors"
"fmt"

"github.com/golang-jwt/jwt/v4"
)

// ErrMultipleJWKSSize is returned when the number of JWKS given are not enough to make a MultipleJWKS.
var ErrMultipleJWKSSize = errors.New("multiple JWKS must have two or more remote JWK Set resources")

// MultipleJWKS manages multiple JWKS and has a field for jwt.Keyfunc.
type MultipleJWKS struct {
keySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error)
sets map[string]*JWKS // No lock is required because this map is read-only after initialization.
}

// GetMultiple creates a new MultipleJWKS. A map of length two or more JWKS URLs to Options is required.
//
// Be careful when choosing Options for each JWKS in the map. If RefreshUnknownKID is set to true for all JWKS in the
// map then many refresh requests would take place each time a JWT is processed, this should be rate limited by
// RefreshRateLimit.
func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWKS *MultipleJWKS, err error) {
if multiple == nil || len(multiple) < 2 {
return nil, fmt.Errorf("multiple JWKS must have two or more remote JWK Set resources: %w", ErrMultipleJWKSSize)
}

if options.KeySelector == nil {
options.KeySelector = KeySelectorFirst
}

multiJWKS = &MultipleJWKS{
sets: make(map[string]*JWKS, len(multiple)),
keySelector: options.KeySelector,
}

for u, opts := range multiple {
jwks, err := Get(u, opts)
if err != nil {
return nil, fmt.Errorf("failed to get JWKS from %q: %w", u, err)
}
multiJWKS.sets[u] = jwks
}

return multiJWKS, nil
}

func (m *MultipleJWKS) JWKSets() map[string]*JWKS {
sets := make(map[string]*JWKS, len(m.sets))
for u, jwks := range m.sets {
sets[u] = jwks
}
return sets
}

func KeySelectorFirst(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) {
kid, alg, err := kidAlg(token)
if err != nil {
return nil, err
}
for _, jwks := range multiJWKS.sets {
key, err = jwks.getKey(alg, kid)
if err == nil {
return key, nil
}
}
return nil, fmt.Errorf("failed to find key ID in multiple JWKS: %w", ErrKIDNotFound)
}
67 changes: 67 additions & 0 deletions multiple_test.go
@@ -0,0 +1,67 @@
package keyfunc_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v4"

"github.com/MicahParks/keyfunc"
)

const (
jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKID","kty":"OKP","x":"1IlXuWBIkjYbAXm5Hk5mvsbPq0skO3-G_hX1Cw7CY-8"},{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}`
jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}`
)

func TestMultipleJWKS(t *testing.T) {
server1 := createTestServer([]byte(jwks1))
defer server1.Close()

server2 := createTestServer([]byte(jwks2))
defer server2.Close()

const (
collisionJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6ImNvbGxpc2lvbktJRCIsInR5cCI6IkpXVCJ9.e30.WXKmhyHjHQFXZ8dXfj07RvwKAgHB3EdGU1jeKUEY-wajgsRsHuhnotX1WqDSlngwGerEitnIcdMGViW_HNUCAA"
uniqueJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6InVuaXF1ZUtJRCIsInR5cCI6IkpXVCJ9.e30.egdT5_vXYKIM7UfsyewYaR63tS9T9JvKwUJs7Srj6wG9JHXMvN9Ftq0rJGem07ESVtN5OtlcJOaMgSbtxnc6Bg"
)

m := map[string]keyfunc.Options{
server1.URL: {},
server2.URL: {},
}

multiJWKS, err := keyfunc.GetMultiple(m, keyfunc.MultipleOptions{})
if err != nil {
t.Fatalf("failed to get multiple JWKS: %v", err)
}

token, err := jwt.Parse(collisionJWT, multiJWKS.Keyfunc)
if err != nil {
t.Fatalf("failed to parse collision JWT: %v", err)
}
if !token.Valid {
t.Fatalf("collision JWT is invalid")
}

token, err = jwt.Parse(uniqueJWT, multiJWKS.Keyfunc)
if err != nil {
t.Fatalf("failed to parse unique JWT: %v", err)
}
if !token.Valid {
t.Fatalf("unique JWT is invalid")
}

sets := multiJWKS.JWKSets()
if len(sets) != 2 {
t.Fatalf("expected 2 JWKS, got %d", len(sets))
}
}

func createTestServer(body []byte) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(body)
}))
}
15 changes: 15 additions & 0 deletions options.go
Expand Up @@ -8,6 +8,8 @@ import (
"io"
"net/http"
"time"

"github.com/golang-jwt/jwt/v4"
)

// ErrInvalidHTTPStatusCode indicates that the HTTP status code is invalid.
Expand Down Expand Up @@ -70,6 +72,9 @@ type Options struct {
// This is done through a background goroutine. Without specifying a RefreshInterval a malicious client could
// self-sign X JWTs, send them to this service, then cause potentially high network usage proportional to X. Make
// sure to call the JWKS.EndBackground method to end this goroutine when it's no longer needed.
//
// It is recommended this option is not used when in MultipleJWKS. This is because KID collisions SHOULD be uncommon
// meaning nearly any JWT SHOULD trigger a refresh for the number of JWKS in the MultipleJWKS minus one.
RefreshUnknownKID bool

// RequestFactory creates HTTP requests for the remote JWKS resource located at the given url. For example, an
Expand All @@ -81,6 +86,16 @@ type Options struct {
ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}

// MultipleOptions is used to configure the behavior when multiple JWKS are used by MultipleJWKS.
type MultipleOptions struct {
// KeySelector is a function that selects the key to use for a given token. It will be used in the implementation
// for jwt.Keyfunc. If implementing this custom selector extract the key ID and algorithm from the token's header.
// Use the key ID to select a token and confirm the key's algorithm before returning it.
//
// This value defaults to KeySelectorFirst.
KeySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error)
}

// RefreshOptions are used to specify manual refresh behavior.
type RefreshOptions struct {
IgnoreRateLimit bool
Expand Down