diff --git a/README.md b/README.md index d9a2c52..ed07eaa 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,7 @@ These features can be configured by populating fields in the parsing JWTs. For an example, see the `examples/custom` directory. * The remote JWKS resource can be refreshed manually using the `.Refresh` method. This can bypass the rate limit, if the option is set. +* There is support for creating one `jwt.Keyfunc` from multiple JWK Sets through the use of the `keyfunc.GetMultiple`. ## Notes Trailing padding is required to be removed from base64url encoded keys inside a JWKS. This is because RFC 7517 defines diff --git a/keyfunc.go b/keyfunc.go index 967c5a9..1f082bd 100644 --- a/keyfunc.go +++ b/keyfunc.go @@ -16,23 +16,33 @@ var ( // Keyfunc matches the signature of github.com/golang-jwt/jwt/v4's jwt.Keyfunc function. func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) { + kid, alg, err := kidAlg(token) + if err != nil { + return nil, err + } + return j.getKey(alg, kid) +} + +func (m *MultipleJWKS) Keyfunc(token *jwt.Token) (interface{}, error) { + return m.keySelector(m, token) +} + +func kidAlg(token *jwt.Token) (kid, alg string, err error) { kidInter, ok := token.Header["kid"] if !ok { - return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKID) + return "", "", fmt.Errorf("%w: could not find kid in JWT header", ErrKID) } - kid, ok := kidInter.(string) + kid, ok = kidInter.(string) if !ok { - return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID) + return "", "", fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID) } - - alg, ok := token.Header["alg"].(string) + alg, ok = token.Header["alg"].(string) if !ok { // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token // without an alg parameter in the header before calling jwt.Keyfunc. - return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch) + return "", "", fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch) } - - return j.getKey(alg, kid) + return kid, alg, nil } // base64urlTrailingPadding removes trailing padding before decoding a string from base64url. Some non-RFC compliant diff --git a/multiple.go b/multiple.go new file mode 100644 index 0000000..61ea30b --- /dev/null +++ b/multiple.go @@ -0,0 +1,69 @@ +package keyfunc + +import ( + "errors" + "fmt" + + "github.com/golang-jwt/jwt/v4" +) + +// ErrMultipleJWKSSize is returned when the number of JWKS given are not enough to make a MultipleJWKS. +var ErrMultipleJWKSSize = errors.New("multiple JWKS must have two or more remote JWK Set resources") + +// MultipleJWKS manages multiple JWKS and has a field for jwt.Keyfunc. +type MultipleJWKS struct { + keySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) + sets map[string]*JWKS // No lock is required because this map is read-only after initialization. +} + +// GetMultiple creates a new MultipleJWKS. A map of length two or more JWKS URLs to Options is required. +// +// Be careful when choosing Options for each JWKS in the map. If RefreshUnknownKID is set to true for all JWKS in the +// map then many refresh requests would take place each time a JWT is processed, this should be rate limited by +// RefreshRateLimit. +func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWKS *MultipleJWKS, err error) { + if multiple == nil || len(multiple) < 2 { + return nil, fmt.Errorf("multiple JWKS must have two or more remote JWK Set resources: %w", ErrMultipleJWKSSize) + } + + if options.KeySelector == nil { + options.KeySelector = KeySelectorFirst + } + + multiJWKS = &MultipleJWKS{ + sets: make(map[string]*JWKS, len(multiple)), + keySelector: options.KeySelector, + } + + for u, opts := range multiple { + jwks, err := Get(u, opts) + if err != nil { + return nil, fmt.Errorf("failed to get JWKS from %q: %w", u, err) + } + multiJWKS.sets[u] = jwks + } + + return multiJWKS, nil +} + +func (m *MultipleJWKS) JWKSets() map[string]*JWKS { + sets := make(map[string]*JWKS, len(m.sets)) + for u, jwks := range m.sets { + sets[u] = jwks + } + return sets +} + +func KeySelectorFirst(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) { + kid, alg, err := kidAlg(token) + if err != nil { + return nil, err + } + for _, jwks := range multiJWKS.sets { + key, err = jwks.getKey(alg, kid) + if err == nil { + return key, nil + } + } + return nil, fmt.Errorf("failed to find key ID in multiple JWKS: %w", ErrKIDNotFound) +} diff --git a/multiple_test.go b/multiple_test.go new file mode 100644 index 0000000..5762c8a --- /dev/null +++ b/multiple_test.go @@ -0,0 +1,67 @@ +package keyfunc_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang-jwt/jwt/v4" + + "github.com/MicahParks/keyfunc" +) + +const ( + jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKID","kty":"OKP","x":"1IlXuWBIkjYbAXm5Hk5mvsbPq0skO3-G_hX1Cw7CY-8"},{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}` + jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}` +) + +func TestMultipleJWKS(t *testing.T) { + server1 := createTestServer([]byte(jwks1)) + defer server1.Close() + + server2 := createTestServer([]byte(jwks2)) + defer server2.Close() + + const ( + collisionJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6ImNvbGxpc2lvbktJRCIsInR5cCI6IkpXVCJ9.e30.WXKmhyHjHQFXZ8dXfj07RvwKAgHB3EdGU1jeKUEY-wajgsRsHuhnotX1WqDSlngwGerEitnIcdMGViW_HNUCAA" + uniqueJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6InVuaXF1ZUtJRCIsInR5cCI6IkpXVCJ9.e30.egdT5_vXYKIM7UfsyewYaR63tS9T9JvKwUJs7Srj6wG9JHXMvN9Ftq0rJGem07ESVtN5OtlcJOaMgSbtxnc6Bg" + ) + + m := map[string]keyfunc.Options{ + server1.URL: {}, + server2.URL: {}, + } + + multiJWKS, err := keyfunc.GetMultiple(m, keyfunc.MultipleOptions{}) + if err != nil { + t.Fatalf("failed to get multiple JWKS: %v", err) + } + + token, err := jwt.Parse(collisionJWT, multiJWKS.Keyfunc) + if err != nil { + t.Fatalf("failed to parse collision JWT: %v", err) + } + if !token.Valid { + t.Fatalf("collision JWT is invalid") + } + + token, err = jwt.Parse(uniqueJWT, multiJWKS.Keyfunc) + if err != nil { + t.Fatalf("failed to parse unique JWT: %v", err) + } + if !token.Valid { + t.Fatalf("unique JWT is invalid") + } + + sets := multiJWKS.JWKSets() + if len(sets) != 2 { + t.Fatalf("expected 2 JWKS, got %d", len(sets)) + } +} + +func createTestServer(body []byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(body) + })) +} diff --git a/options.go b/options.go index d65752f..cc4cf5e 100644 --- a/options.go +++ b/options.go @@ -8,6 +8,8 @@ import ( "io" "net/http" "time" + + "github.com/golang-jwt/jwt/v4" ) // ErrInvalidHTTPStatusCode indicates that the HTTP status code is invalid. @@ -70,6 +72,9 @@ type Options struct { // This is done through a background goroutine. Without specifying a RefreshInterval a malicious client could // self-sign X JWTs, send them to this service, then cause potentially high network usage proportional to X. Make // sure to call the JWKS.EndBackground method to end this goroutine when it's no longer needed. + // + // It is recommended this option is not used when in MultipleJWKS. This is because KID collisions SHOULD be uncommon + // meaning nearly any JWT SHOULD trigger a refresh for the number of JWKS in the MultipleJWKS minus one. RefreshUnknownKID bool // RequestFactory creates HTTP requests for the remote JWKS resource located at the given url. For example, an @@ -81,6 +86,16 @@ type Options struct { ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error) } +// MultipleOptions is used to configure the behavior when multiple JWKS are used by MultipleJWKS. +type MultipleOptions struct { + // KeySelector is a function that selects the key to use for a given token. It will be used in the implementation + // for jwt.Keyfunc. If implementing this custom selector extract the key ID and algorithm from the token's header. + // Use the key ID to select a token and confirm the key's algorithm before returning it. + // + // This value defaults to KeySelectorFirst. + KeySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) +} + // RefreshOptions are used to specify manual refresh behavior. type RefreshOptions struct { IgnoreRateLimit bool