From 81c7ee27994cea901355c5eb8d270cc46c459f88 Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Sun, 18 Dec 2022 12:10:05 -0500 Subject: [PATCH 1/5] Add support for multiple JWKS --- keyfunc.go | 27 ++++++++++++++------- multiple.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++++ options.go | 12 ++++++++++ 3 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 multiple.go diff --git a/keyfunc.go b/keyfunc.go index 967c5a9..51f5a6e 100644 --- a/keyfunc.go +++ b/keyfunc.go @@ -16,23 +16,34 @@ var ( // Keyfunc matches the signature of github.com/golang-jwt/jwt/v4's jwt.Keyfunc function. func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) { + kid, alg, err := kidAlg(token) + if err != nil { + return nil, err + } + + return j.getKey(alg, kid) +} + +func (m *MultipleJWKS) Keyfunc(token *jwt.Token) (interface{}, error) { + return m.keySelector(m, token) +} + +func kidAlg(token *jwt.Token) (kid, alg string, err error) { kidInter, ok := token.Header["kid"] if !ok { - return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKID) + return "", "", fmt.Errorf("%w: could not find kid in JWT header", ErrKID) } - kid, ok := kidInter.(string) + kid, ok = kidInter.(string) if !ok { - return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID) + return "", "", fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID) } - - alg, ok := token.Header["alg"].(string) + alg, ok = token.Header["alg"].(string) if !ok { // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token // without an alg parameter in the header before calling jwt.Keyfunc. - return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch) + return "", "", fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch) } - - return j.getKey(alg, kid) + return kid, alg, nil } // base64urlTrailingPadding removes trailing padding before decoding a string from base64url. Some non-RFC compliant diff --git a/multiple.go b/multiple.go new file mode 100644 index 0000000..e055ffe --- /dev/null +++ b/multiple.go @@ -0,0 +1,68 @@ +package keyfunc + +import ( + "errors" + "fmt" + + "github.com/golang-jwt/jwt/v4" +) + +// ErrMultipleJWKSSize is returned when the number of JWKS given are not enough to make a MultipleJWKS. +var ErrMultipleJWKSSize = errors.New("multiple JWKS must have two or more remote JWK Set resources") + +// MultipleJWKS manages multiple JWKS and has a field for jwt.Keyfunc. +type MultipleJWKS struct { + keySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) + sets map[string]*JWKS // No lock is required because this map is read-only after initialization. +} + +// GetMultiple creates a new MultipleJWKS. A map of length two or more JWKS URLs to Options is required. +// +// Be careful when choosing Options for each JWKS in the map. If RefreshUnknownKID is set to true for all JWKS in the +// map then many refresh requests would take place each time a JWT is processed, this should be rate limited by +// RefreshRateLimit. +func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWKS *MultipleJWKS, err error) { + if multiple == nil || len(multiple) < 2 { + return nil, fmt.Errorf("multiple JWKS must have two or more remote JWK Set resources: %w", ErrMultipleJWKSSize) + } + + if options.KeySelector == nil { + options.KeySelector = KeySelectorFirst + } + + multiJWKS = &MultipleJWKS{ + sets: make(map[string]*JWKS, len(multiple)), + } + + for u, opts := range multiple { + jwks, err := Get(u, opts) + if err != nil { + return nil, fmt.Errorf("failed to get JWKS from %q: %w", u, err) + } + multiJWKS.sets[u] = jwks + } + + return multiJWKS, nil +} + +func (m *MultipleJWKS) JWKSets() map[string]*JWKS { + sets := make(map[string]*JWKS, len(m.sets)) + for u, jwks := range m.sets { + sets[u] = jwks + } + return sets +} + +func KeySelectorFirst(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) { + kid, alg, err := kidAlg(token) + if err != nil { + return nil, err + } + for _, jwks := range multiJWKS.sets { + key, err = jwks.getKey(alg, kid) + if err == nil { + return key, nil + } + } + return nil, fmt.Errorf("failed to find key ID in multiple JWKS: %w", ErrKIDNotFound) +} diff --git a/options.go b/options.go index a4b77a7..a716df2 100644 --- a/options.go +++ b/options.go @@ -8,6 +8,8 @@ import ( "io" "net/http" "time" + + "github.com/golang-jwt/jwt/v4" ) // ErrInvalidHTTPStatusCode indicates that the HTTP status code is invalid. @@ -81,6 +83,16 @@ type Options struct { ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error) } +// MultipleOptions is used to configure the behavior when multiple JWKS are used by MultipleJWKS. +type MultipleOptions struct { + // KeySelector is a function that selects the key to use for a given token. It will be used in the implementation + // for jwt.Keyfunc. If implementing this custom selector extract the key ID and algorithm from the token's header. + // Use the key ID to select a token and confirm the key's algorithm before returning it. + // + // This value defaults to KeySelectorFirst. + KeySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) +} + // ResponseExtractorStatusOK is meant to be used as the ResponseExtractor field for Options. It confirms that response // status code is 200 OK and returns the raw JSON from the response body. func ResponseExtractorStatusOK(ctx context.Context, resp *http.Response) (json.RawMessage, error) { From 040769c41b54413c526b0ef60b2c88a65a444b85 Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Sun, 18 Dec 2022 21:25:52 -0500 Subject: [PATCH 2/5] Start on tests for multiple JWKS --- keyfunc.go | 1 - multiple_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 multiple_test.go diff --git a/keyfunc.go b/keyfunc.go index 51f5a6e..1f082bd 100644 --- a/keyfunc.go +++ b/keyfunc.go @@ -20,7 +20,6 @@ func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) { if err != nil { return nil, err } - return j.getKey(alg, kid) } diff --git a/multiple_test.go b/multiple_test.go new file mode 100644 index 0000000..3a1426d --- /dev/null +++ b/multiple_test.go @@ -0,0 +1,40 @@ +package keyfunc_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/MicahParks/keyfunc" +) + +const ( + jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"keyCollision","kty":"OKP","x":"w_a0ZgEjNuD_YrNtexSfVcZkJKzzRmf4Jv7gDmRkTj0"}]}` + jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"keyCollision","kty":"OKP","x":"hDLmETJ2XnYMhVCrXjr0yv76ytPWZN3QiwSvStOYhj0"},{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKey","kty":"OKP","x":"hDLmETJ2XnYMhVCrXjr0yv76ytPWZN3QiwSvStOYhj0"}]}` +) + +func TestMultipleJWKS(t *testing.T) { + server1 := createTestServer([]byte(jwks1)) + defer server1.Close() + + server2 := createTestServer([]byte(jwks2)) + defer server2.Close() + + m := map[string]keyfunc.Options{ + server1.URL: {}, + server2.URL: {}, + } + + multiJWKS, err := keyfunc.GetMultiple(m, keyfunc.MultipleOptions{}) + if err != nil { + t.Fatalf("failed to get multiple JWKS: %v", err) + } + +} + +func createTestServer(body []byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(body) + })) +} From 646644e727050908fc3fa0db9b12df8981a08b30 Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Sun, 18 Dec 2022 21:49:11 -0500 Subject: [PATCH 3/5] Ass tests for multiple JWK Sets --- multiple.go | 3 ++- multiple_test.go | 31 +++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/multiple.go b/multiple.go index e055ffe..61ea30b 100644 --- a/multiple.go +++ b/multiple.go @@ -31,7 +31,8 @@ func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWK } multiJWKS = &MultipleJWKS{ - sets: make(map[string]*JWKS, len(multiple)), + sets: make(map[string]*JWKS, len(multiple)), + keySelector: options.KeySelector, } for u, opts := range multiple { diff --git a/multiple_test.go b/multiple_test.go index 3a1426d..5762c8a 100644 --- a/multiple_test.go +++ b/multiple_test.go @@ -5,12 +5,14 @@ import ( "net/http/httptest" "testing" + "github.com/golang-jwt/jwt/v4" + "github.com/MicahParks/keyfunc" ) const ( - jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"keyCollision","kty":"OKP","x":"w_a0ZgEjNuD_YrNtexSfVcZkJKzzRmf4Jv7gDmRkTj0"}]}` - jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"keyCollision","kty":"OKP","x":"hDLmETJ2XnYMhVCrXjr0yv76ytPWZN3QiwSvStOYhj0"},{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKey","kty":"OKP","x":"hDLmETJ2XnYMhVCrXjr0yv76ytPWZN3QiwSvStOYhj0"}]}` + jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKID","kty":"OKP","x":"1IlXuWBIkjYbAXm5Hk5mvsbPq0skO3-G_hX1Cw7CY-8"},{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}` + jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}` ) func TestMultipleJWKS(t *testing.T) { @@ -20,6 +22,11 @@ func TestMultipleJWKS(t *testing.T) { server2 := createTestServer([]byte(jwks2)) defer server2.Close() + const ( + collisionJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6ImNvbGxpc2lvbktJRCIsInR5cCI6IkpXVCJ9.e30.WXKmhyHjHQFXZ8dXfj07RvwKAgHB3EdGU1jeKUEY-wajgsRsHuhnotX1WqDSlngwGerEitnIcdMGViW_HNUCAA" + uniqueJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6InVuaXF1ZUtJRCIsInR5cCI6IkpXVCJ9.e30.egdT5_vXYKIM7UfsyewYaR63tS9T9JvKwUJs7Srj6wG9JHXMvN9Ftq0rJGem07ESVtN5OtlcJOaMgSbtxnc6Bg" + ) + m := map[string]keyfunc.Options{ server1.URL: {}, server2.URL: {}, @@ -30,6 +37,26 @@ func TestMultipleJWKS(t *testing.T) { t.Fatalf("failed to get multiple JWKS: %v", err) } + token, err := jwt.Parse(collisionJWT, multiJWKS.Keyfunc) + if err != nil { + t.Fatalf("failed to parse collision JWT: %v", err) + } + if !token.Valid { + t.Fatalf("collision JWT is invalid") + } + + token, err = jwt.Parse(uniqueJWT, multiJWKS.Keyfunc) + if err != nil { + t.Fatalf("failed to parse unique JWT: %v", err) + } + if !token.Valid { + t.Fatalf("unique JWT is invalid") + } + + sets := multiJWKS.JWKSets() + if len(sets) != 2 { + t.Fatalf("expected 2 JWKS, got %d", len(sets)) + } } func createTestServer(body []byte) *httptest.Server { From 6739ca513b52607d40b953386bd8f982e2aec2f2 Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Sun, 18 Dec 2022 21:53:05 -0500 Subject: [PATCH 4/5] Add note in README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 4a95ebe..510757b 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,7 @@ These features can be configured by populating fields in the * Custom cryptographic algorithms can be used. Make sure to use [`jwt.RegisterSigningMethod`](https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod) before parsing JWTs. For an example, see the `examples/custom` directory. +* There is support for creating one `jwt.Keyfunc` from multiple JWK Sets through the use of the `keyfunc.GetMultiple`. ## Notes Trailing padding is required to be removed from base64url encoded keys inside a JWKS. This is because RFC 7517 defines From eaceb569573b024fa6eba6e6e07f29704f08b93e Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Tue, 20 Dec 2022 21:49:04 -0500 Subject: [PATCH 5/5] Add comment for RefreshUnknownKID option --- options.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/options.go b/options.go index dac87cd..cc4cf5e 100644 --- a/options.go +++ b/options.go @@ -72,6 +72,9 @@ type Options struct { // This is done through a background goroutine. Without specifying a RefreshInterval a malicious client could // self-sign X JWTs, send them to this service, then cause potentially high network usage proportional to X. Make // sure to call the JWKS.EndBackground method to end this goroutine when it's no longer needed. + // + // It is recommended this option is not used when in MultipleJWKS. This is because KID collisions SHOULD be uncommon + // meaning nearly any JWT SHOULD trigger a refresh for the number of JWKS in the MultipleJWKS minus one. RefreshUnknownKID bool // RequestFactory creates HTTP requests for the remote JWKS resource located at the given url. For example, an