From cb6a4c0b725c3569b6a6adf4f549c86cafec3bea Mon Sep 17 00:00:00 2001 From: Anders Eknert Date: Thu, 2 Jun 2022 12:07:57 +0200 Subject: [PATCH] Ignore keys of unknown alg when verifying JWTs with JWKS (#4725) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Additionally, improve the JWT verification process when a JWKS is provided: * If `kid` is present in JWT header, and exists in JWKS — verify using that key only. * If `kid` not present in JWT header, try verification only using keys matching the `alg` provided in the JWT header (mandatory claim). Fixes #4699 Signed-off-by: Anders Eknert --- internal/jwx/jwa/signature.go | 4 +- internal/jwx/jwk/jwk.go | 3 + internal/jwx/jws/headers_test.go | 8 +- internal/jwx/jws/jws_test.go | 7 +- topdown/tokens.go | 117 ++++++++++++++--- topdown/tokens_test.go | 209 ++++++++++++++++++++++++++++++- 6 files changed, 319 insertions(+), 29 deletions(-) diff --git a/internal/jwx/jwa/signature.go b/internal/jwx/jwa/signature.go index 4329408db4..45e400176d 100644 --- a/internal/jwx/jwa/signature.go +++ b/internal/jwx/jwa/signature.go @@ -27,6 +27,7 @@ const ( RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384 RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512 NoValue SignatureAlgorithm = "" // No value is different from none + Unsupported SignatureAlgorithm = "unsupported" ) // Accept is used when conversion from values given by @@ -69,7 +70,8 @@ func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error { } _, ok := signatureAlg[quoted] if !ok { - return errors.New("unknown signature algorithm") + *signature = Unsupported + return nil } *signature = SignatureAlgorithm(quoted) return nil diff --git a/internal/jwx/jwk/jwk.go b/internal/jwx/jwk/jwk.go index d68c814c65..aa22a3830f 100644 --- a/internal/jwx/jwk/jwk.go +++ b/internal/jwx/jwk/jwk.go @@ -99,6 +99,9 @@ func parse(jwkSrc string) (*Set, error) { } else { for i := range rawKeySetJSON.Keys { rawKeyJSON := rawKeySetJSON.Keys[i] + if rawKeyJSON.Algorithm != nil && *rawKeyJSON.Algorithm == jwa.Unsupported { + continue + } jwkKey, err = rawKeyJSON.GenerateKey() if err != nil { return nil, fmt.Errorf("failed to generate key: %w", err) diff --git a/internal/jwx/jws/headers_test.go b/internal/jwx/jws/headers_test.go index cde814fcbe..7327972f03 100644 --- a/internal/jwx/jws/headers_test.go +++ b/internal/jwx/jws/headers_test.go @@ -112,9 +112,11 @@ func TestHeader(t *testing.T) { headers := `{"typ":"JWT",` + "\r\n" + ` "alg":"dummy"}` var standardHeaders jws.StandardHeaders err := json.Unmarshal([]byte(headers), &standardHeaders) - if err == nil { - t.Fatal("Unmarshal should have failed") + if err != nil { + t.Fatal(err) + } + if standardHeaders.Algorithm != jwa.Unsupported { + t.Errorf("expected unsupported algorithm") } - }) } diff --git a/internal/jwx/jws/jws_test.go b/internal/jwx/jws/jws_test.go index 815210db86..3a694d4814 100644 --- a/internal/jwx/jws/jws_test.go +++ b/internal/jwx/jws/jws_test.go @@ -82,8 +82,11 @@ func TestAlgError(t *testing.T) { const hdr = `{"typ":"JWT",` + "\r\n" + ` "alg":"unknown"}` var standardHeaders jws.StandardHeaders err := json.Unmarshal([]byte(hdr), &standardHeaders) - if err == nil { - t.Fatal("header parsing should have failed") + if err != nil { + t.Fatal(err) + } + if standardHeaders.Algorithm != jwa.Unsupported { + t.Errorf("expected unsupported algorithm") } }) } diff --git a/topdown/tokens.go b/topdown/tokens.go index 394dcfbada..c23fa1b254 100644 --- a/topdown/tokens.go +++ b/topdown/tokens.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/internal/jwx/jwa" "github.com/open-policy-agent/opa/internal/jwx/jwk" "github.com/open-policy-agent/opa/internal/jwx/jws" "github.com/open-policy-agent/opa/topdown/builtins" @@ -268,9 +269,16 @@ func verifyES(publicKey interface{}, digest []byte, signature []byte) error { return fmt.Errorf("ECDSA signature verification error") } -// getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s). +type verificationKey struct { + alg string + kid string + key interface{} +} + +// getKeysFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s). // A valid PEM block is never valid JSON (and vice versa), hence can try parsing both. -func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) { +// When provided a JWKS, each key additionally likely contains a key ID and the key algorithm. +func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) { if block, rest := pem.Decode([]byte(certificate)); block != nil { if len(rest) > 0 { return nil, fmt.Errorf("extra data after a PEM certificate block") @@ -281,8 +289,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) { if err != nil { return nil, fmt.Errorf("failed to parse a PEM certificate: %w", err) } - - return []interface{}{cert.PublicKey}, nil + return []verificationKey{{key: cert.PublicKey}}, nil } if block.Type == "PUBLIC KEY" { @@ -291,7 +298,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) { return nil, fmt.Errorf("failed to parse a PEM public key: %w", err) } - return []interface{}{key}, nil + return []verificationKey{{key: key}}, nil } return nil, fmt.Errorf("failed to extract a Key from the PEM certificate") @@ -302,18 +309,31 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) { return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err) } - var keys []interface{} + var keys []verificationKey for _, k := range jwks.Keys { key, err := k.Materialize() if err != nil { return nil, err } - keys = append(keys, key) + keys = append(keys, verificationKey{ + alg: k.GetAlgorithm().String(), + kid: k.GetKeyID(), + key: key, + }) } return keys, nil } +func getKeyByKid(kid string, keys []verificationKey) *verificationKey { + for _, key := range keys { + if key.kid == kid { + return &key + } + } + return nil +} + // Implements JWT signature verification. func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) { token, err := decodeJWT(a) @@ -326,7 +346,7 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify return nil, err } - keys, err := getKeyFromCertOrJWK(string(s)) + keys, err := getKeysFromCertOrJWK(string(s)) if err != nil { return nil, err } @@ -336,14 +356,45 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify return nil, err } + err = token.decodeHeader() + if err != nil { + return nil, err + } + header, err := parseTokenHeader(token) + if err != nil { + return nil, err + } + // Validate the JWT signature - for _, key := range keys { - err = verify(key, - getInputSHA([]byte(token.header+"."+token.payload), hasher), - []byte(signature)) - if err == nil { - return ast.Boolean(true), nil + // First, check if there's a matching key ID (`kid`) in both token header and key(s). + // If a match is found, verify using only that key. Only applicable when a JWKS was provided. + if header.kid != "" { + if key := getKeyByKid(header.kid, keys); key != nil { + err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature)) + + return ast.Boolean(err == nil), nil + } + } + + // If no key ID matched, try to verify using any key in the set + // If an alg is present in both the JWT header and the key, skip verification unless they match + for _, key := range keys { + if key.alg == "" { + // No algorithm provided for the key - this is likely a certificate and not a JWKS, so + // we'll need to verify to find out + err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature)) + if err == nil { + return ast.Boolean(true), nil + } + } else { + if header.alg != key.alg { + continue + } + err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature)) + if err == nil { + return ast.Boolean(true), nil + } } } @@ -445,7 +496,7 @@ func builtinJWTVerifyHS512(bctx BuiltinContext, args []*ast.Term, iter func(*ast // tokenConstraints holds decoded JWT verification constraints. type tokenConstraints struct { // The set of asymmetric keys we can verify with. - keys []interface{} + keys []verificationKey // The single symmetric key we will verify with. secret string @@ -495,10 +546,11 @@ func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) error { return fmt.Errorf("cert constraint: must be a string") } - keys, err := getKeyFromCertOrJWK(string(s)) + keys, err := getKeysFromCertOrJWK(string(s)) if err != nil { return err } + constraints.keys = keys return nil } @@ -595,14 +647,36 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature } // If we're configured with asymmetric key(s) then only trust that if constraints.keys != nil { + if kid != "" { + if key := getKeyByKid(kid, constraints.keys); key != nil { + err := a.verify(key.key, a.hash, plaintext, []byte(signature)) + if err != nil { + return errSignatureNotVerified + } + return nil + } + } + verified := false for _, key := range constraints.keys { - err := a.verify(key, a.hash, plaintext, []byte(signature)) - if err == nil { - verified = true - break + if key.alg == "" { + err := a.verify(key.key, a.hash, plaintext, []byte(signature)) + if err == nil { + verified = true + break + } + } else { + if alg != key.alg { + continue + } + err := a.verify(key.key, a.hash, plaintext, []byte(signature)) + if err == nil { + verified = true + break + } } } + if !verified { return errSignatureNotVerified } @@ -843,6 +917,9 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j return err } alg := standardHeaders.GetAlgorithm() + if alg == jwa.Unsupported { + return fmt.Errorf("unknown signature algorithm") + } if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) { return fmt.Errorf("type is JWT but payload is not JSON") diff --git a/topdown/tokens_test.go b/topdown/tokens_test.go index ac60c182dd..eeeeba3365 100644 --- a/topdown/tokens_test.go +++ b/topdown/tokens_test.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" + "crypto/sha256" "encoding/base64" "encoding/json" "fmt" @@ -63,7 +64,7 @@ OHoCIHmNX37JOqTcTzGn2u9+c8NlnvZ0uDvsd1BmKPaUmjmm if err != nil { t.Fatalf("parseTokenConstraints: %v", err) } - pubKey := constraints.keys[0].(*ecdsa.PublicKey) + pubKey := constraints.keys[0].key.(*ecdsa.PublicKey) if pubKey.Curve != elliptic.P256() { t.Errorf("curve: %v", pubKey.Curve) } @@ -102,12 +103,12 @@ OHoCIHmNX37JOqTcTzGn2u9+c8NlnvZ0uDvsd1BmKPaUmjmm if err != nil { t.Fatalf("parseTokenConstraints: %v", err) } - elPubKey := constraints.keys[0].(*ecdsa.PublicKey) + elPubKey := constraints.keys[0].key.(*ecdsa.PublicKey) if elPubKey.Curve != elliptic.P256() { t.Errorf("curve: %v", elPubKey.Curve) } - rsaPubKey := constraints.keys[1].(*rsa.PublicKey) + rsaPubKey := constraints.keys[1].key.(*rsa.PublicKey) if rsaPubKey.Size() != 256 { t.Errorf("expected size 256 found %d", rsaPubKey.Size()) } @@ -496,3 +497,205 @@ func TestTopdownJWTEncodeSignECWithSeedReturnsSameSignature(t *testing.T) { } } } + +func TestTopdownJWTUnknownAlgTypesDiscardedFromJWKS(t *testing.T) { + cert := `{ + "keys": [ + { + "kty": "RSA", + "e": "AQAB", + "use": "enc", + "kid": "k3", + "alg": "RS256", + "n": "sGu-fYVE2nq2dPxJlqAMI0Z8G3FD0XcWDnD8mkfO1ddKRGuUQZmfj4gWeZGyIk3cnuoy7KJCEqa3daXc08QHuFZyfn0rH33t8_AFsvb0q0i7R2FK-Gdqs_E0-sGpYMsRJdZWfCioLkYjIHEuVnRbi3DEsWqe484rEGbKF60jNRgGC4b-8pz-E538ZkssWxcqHrYIj5bjGEU36onjS3M_yrTuNvzv_8wRioK4fbcwmGne9bDxu8LcoSReWpPn0CnUkWnfqroRcMJnC87ZuJagDW1ZWCmU3psdsVanmFFh0DP6z0fsA4h8G2n9-qp-LEKFaWwo3IWlOsIzU3MHdcEiGw" + }, + { + "kid": "encryption algorithm", + "kty": "RSA", + "alg": "RSA-OAEP", + "use": "enc", + "n": "onlqv4UZx5ZabJ3TCq-IO0s0xaOwo6fWl9o4SzLXPbGtvxonQhoYOeMlS0XkdEdLzB-eqh_hkQ", + "e": "AQAB", + "x5c": [ + "MIICnTCCAYUCBgGAmcG0xjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQ2YVaQn47Eew==" + ], + "x5t": "WKfdwdQkg", + "x5t#S256": "2_FidAwjlCQl20" + } + ] +} +` + keys, err := getKeysFromCertOrJWK(cert) + if err != nil { + t.Fatal(err) + } + + if len(keys) != 1 { + t.Errorf("expected only one key as inavlid one should have been discarded") + } + + if keys[0].alg != "RS256" { + t.Errorf("expected key with RS256 alg") + } +} + +func TestTopdownJWTVerifyOnlyVerifiesUsingApplicableKeys(t *testing.T) { + cert := ast.MustInterfaceToValue(`{ + "keys": [ + { + "kty": "EC", + "use": "sig", + "crv": "P-256", + "kid": "k1", + "x": "9Qq5S5VqMQoH-FOI4atcH6V3bua03C-5ZMZMG1rszwA", + "y": "LLbFxWkGBEBrTm1GMYZJy1OXCH1KLweJMCgIEPIsibU", + "alg": "ES256" + }, + { + "kty": "RSA", + "e": "AQAB", + "use": "enc", + "kid": "k2", + "alg": "RS256", + "n": "sGu-fYVE2nq2dPxJlqAMI0Z8G3FD0XcWDnD8mkfO1ddKRGuUQZmfj4gWeZGyIk3cnuoy7KJCEqa3daXc08QHuFZyfn0rH33t8_AFsvb0q0i7R2FK-Gdqs_E0-sGpYMsRJdZWfCioLkYjIHEuVnRbi3DEsWqe484rEGbKF60jNRgGC4b-8pz-E538ZkssWxcqHrYIj5bjGEU36onjS3M_yrTuNvzv_8wRioK4fbcwmGne9bDxu8LcoSReWpPn0CnUkWnfqroRcMJnC87ZuJagDW1ZWCmU3psdsVanmFFh0DP6z0fsA4h8G2n9-qp-LEKFaWwo3IWlOsIzU3MHdcEiGw" + }, + { + "kty": "RSA", + "e": "AQAB", + "use": "enc", + "kid": "k3", + "alg": "RS256", + "n": "sGu-fYVE2nq2dPxJlqAMI0Z8G3FD0XcWDnD8mkfO1ddKRGuUQZmfj4gWeZGyIk3cnuoy7KJCEqa3daXc08QHuFZyfn0rH33t8_AFsvb0q0i7R2FK-Gdqs_E0-sGpYMsRJdZWfCioLkYjIHEuVnRbi3DEsWqe484rEGbKF60jNRgGC4b-8pz-E538ZkssWxcqHrYIj5bjGEU36onjS3M_yrTuNvzv_8wRioK4fbcwmGne9bDxu8LcoSReWpPn0CnUkWnfqroRcMJnC87ZuJagDW1ZWCmU3psdsVanmFFh0DP6z0fsA4h8G2n9-qp-LEKFaWwo3IWlOsIzU3MHdcEiGw" + }, + { + "kid": "unknown algorithm", + "kty": "RSA", + "alg": "RSA-OAEP", + "use": "enc", + "n": "onlqv4UZx5ZabJ3TCq-IO0s0xaOwo6fWl9o4SzLXPbGtvxonQhoYOeMlS0XkdEdLzB-eqh_hkQ", + "e": "AQAB", + "x5c": [ + "MIICnTCCAYUCBgGAmcG0xjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQ2YVaQn47Eew==" + ], + "x5t": "WKfdwdQkg", + "x5t#S256": "2_FidAw.....jlCQl20" + } + ] +} +`) + + cases := []struct { + note string + header string + expectVerifyCalls int + }{ + { + note: "verification considers only key with matching kid, if present", + header: `{"alg":"RS256", "kid": "k2"}`, + expectVerifyCalls: 1, + }, + { + note: "verification considers any key with matching alg, if no kid matches", + header: `{"alg":"RS256", "kid": "not-in-jwks"}`, + expectVerifyCalls: 2, + }, + { + note: "verification without kid considers only keys with alg matched from header", + header: `{"alg":"RS256"}`, + expectVerifyCalls: 2, + }, + { + note: "verification is is skipped if alg unknown", + header: `{"alg":"xyz"}`, + expectVerifyCalls: 0, + }, + } + + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(tc.header)) + payload := base64.RawURLEncoding.EncodeToString([]byte("{}")) + signature := base64.RawURLEncoding.EncodeToString([]byte("ignored")) + + token := ast.MustInterfaceToValue(fmt.Sprintf("%s.%s.%s", header, payload, signature)) + + verifyCalls := 0 + verifier := func(publicKey interface{}, digest []byte, signature []byte) error { + verifyCalls++ + return fmt.Errorf("fail") + } + + _, err := builtinJWTVerify(token, cert, sha256.New, verifier) + if err != nil { + t.Fatal(err) + } + + if verifyCalls != tc.expectVerifyCalls { + t.Errorf("expected %d calls to verify token, got %d", tc.expectVerifyCalls, verifyCalls) + } + }) + } +} + +func TestTopdownJWTDecodeVerifyIgnoresKeysOfUnknownAlgInJWKS(t *testing.T) { + c := ast.NewObject() + c.Insert(ast.StringTerm("cert"), ast.StringTerm(`{ + "keys": [ + { + "kty": "EC", + "use": "sig", + "crv": "P-256", + "kid": "k1", + "x": "9Qq5S5VqMQoH-FOI4atcH6V3bua03C-5ZMZMG1rszwA", + "y": "LLbFxWkGBEBrTm1GMYZJy1OXCH1KLweJMCgIEPIsibU", + "alg": "ES256" + }, + { + "kty": "RSA", + "e": "AQAB", + "use": "enc", + "kid": "k2", + "alg": "RS256", + "n": "sGu-fYVE2nq2dPxJlqAMI0Z8G3FD0XcWDnD8mkfO1ddKRGuUQZmfj4gWeZGyIk3cnuoy7KJCEqa3daXc08QHuFZyfn0rH33t8_AFsvb0q0i7R2FK-Gdqs_E0-sGpYMsRJdZWfCioLkYjIHEuVnRbi3DEsWqe484rEGbKF60jNRgGC4b-8pz-E538ZkssWxcqHrYIj5bjGEU36onjS3M_yrTuNvzv_8wRioK4fbcwmGne9bDxu8LcoSReWpPn0CnUkWnfqroRcMJnC87ZuJagDW1ZWCmU3psdsVanmFFh0DP6z0fsA4h8G2n9-qp-LEKFaWwo3IWlOsIzU3MHdcEiGw" + }, + { + "kty": "RSA", + "e": "AQAB", + "use": "enc", + "kid": "k3", + "alg": "RS256", + "n": "sGu-fYVE2nq2dPxJlqAMI0Z8G3FD0XcWDnD8mkfO1ddKRGuUQZmfj4gWeZGyIk3cnuoy7KJCEqa3daXc08QHuFZyfn0rH33t8_AFsvb0q0i7R2FK-Gdqs_E0-sGpYMsRJdZWfCioLkYjIHEuVnRbi3DEsWqe484rEGbKF60jNRgGC4b-8pz-E538ZkssWxcqHrYIj5bjGEU36onjS3M_yrTuNvzv_8wRioK4fbcwmGne9bDxu8LcoSReWpPn0CnUkWnfqroRcMJnC87ZuJagDW1ZWCmU3psdsVanmFFh0DP6z0fsA4h8G2n9-qp-LEKFaWwo3IWlOsIzU3MHdcEiGw" + }, + { + "kid": "unknown algorithm", + "kty": "RSA", + "alg": "RSA-OAEP", + "use": "enc", + "n": "onlqv4UZx5ZabJ3TCq-IO0s0xaOwo6fWl9o4SzLXPbGtvxonQhoYOeMlS0XkdEdLzB-eqh_hkQ", + "e": "AQAB", + "x5c": [ + "MIICnTCCAYUCBgGAmcG0xjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQ2YVaQn47Eew==" + ], + "x5t": "WKfdwdQkg", + "x5t#S256": "2_FidAw.....jlCQl20" + } + ] +} +`)) + + wallclock := ast.NumberTerm(int64ToJSONNumber(time.Now().UnixNano())) + constraints, err := parseTokenConstraints(c, wallclock) + if err != nil { + t.Fatal(err) + } + + if len(constraints.keys) != 3 { + t.Errorf("expected 3 keys in JWKS, got %d", len(constraints.keys)) + } + + for _, key := range constraints.keys { + if key.alg == "RSA-OAEP" { + t.Errorf("expected alg: RSA-OAEP to be removed from key set") + } + } +}