Skip to content

Commit

Permalink
Ignore keys of unknown alg when verifying JWTs with JWKS (#4725)
Browse files Browse the repository at this point in the history
Additionally, improve the JWT verification process when a JWKS
is provided:

* If `kid` is present in JWT header, and exists in JWKS — verify using
  that key only.
* If `kid` not present in JWT header, try verification only using keys
  matching the `alg` provided in the JWT header (mandatory claim).

Fixes #4699

Signed-off-by: Anders Eknert <anders@eknert.com>
  • Loading branch information
anderseknert committed Jun 2, 2022
1 parent 1889f24 commit cb6a4c0
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 29 deletions.
4 changes: 3 additions & 1 deletion internal/jwx/jwa/signature.go
Expand Up @@ -27,6 +27,7 @@ const (
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
NoValue SignatureAlgorithm = "" // No value is different from none
Unsupported SignatureAlgorithm = "unsupported"
)

// Accept is used when conversion from values given by
Expand Down Expand Up @@ -69,7 +70,8 @@ func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
}
_, ok := signatureAlg[quoted]
if !ok {
return errors.New("unknown signature algorithm")
*signature = Unsupported
return nil
}
*signature = SignatureAlgorithm(quoted)
return nil
Expand Down
3 changes: 3 additions & 0 deletions internal/jwx/jwk/jwk.go
Expand Up @@ -99,6 +99,9 @@ func parse(jwkSrc string) (*Set, error) {
} else {
for i := range rawKeySetJSON.Keys {
rawKeyJSON := rawKeySetJSON.Keys[i]
if rawKeyJSON.Algorithm != nil && *rawKeyJSON.Algorithm == jwa.Unsupported {
continue
}
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
Expand Down
8 changes: 5 additions & 3 deletions internal/jwx/jws/headers_test.go
Expand Up @@ -112,9 +112,11 @@ func TestHeader(t *testing.T) {
headers := `{"typ":"JWT",` + "\r\n" + ` "alg":"dummy"}`
var standardHeaders jws.StandardHeaders
err := json.Unmarshal([]byte(headers), &standardHeaders)
if err == nil {
t.Fatal("Unmarshal should have failed")
if err != nil {
t.Fatal(err)
}
if standardHeaders.Algorithm != jwa.Unsupported {
t.Errorf("expected unsupported algorithm")
}

})
}
7 changes: 5 additions & 2 deletions internal/jwx/jws/jws_test.go
Expand Up @@ -82,8 +82,11 @@ func TestAlgError(t *testing.T) {
const hdr = `{"typ":"JWT",` + "\r\n" + ` "alg":"unknown"}`
var standardHeaders jws.StandardHeaders
err := json.Unmarshal([]byte(hdr), &standardHeaders)
if err == nil {
t.Fatal("header parsing should have failed")
if err != nil {
t.Fatal(err)
}
if standardHeaders.Algorithm != jwa.Unsupported {
t.Errorf("expected unsupported algorithm")
}
})
}
Expand Down
117 changes: 97 additions & 20 deletions topdown/tokens.go
Expand Up @@ -22,6 +22,7 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/open-policy-agent/opa/topdown/builtins"
Expand Down Expand Up @@ -268,9 +269,16 @@ func verifyES(publicKey interface{}, digest []byte, signature []byte) error {
return fmt.Errorf("ECDSA signature verification error")
}

// getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
type verificationKey struct {
alg string
kid string
key interface{}
}

// getKeysFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
// A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
// When provided a JWKS, each key additionally likely contains a key ID and the key algorithm.
func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) {
if block, rest := pem.Decode([]byte(certificate)); block != nil {
if len(rest) > 0 {
return nil, fmt.Errorf("extra data after a PEM certificate block")
Expand All @@ -281,8 +289,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse a PEM certificate: %w", err)
}

return []interface{}{cert.PublicKey}, nil
return []verificationKey{{key: cert.PublicKey}}, nil
}

if block.Type == "PUBLIC KEY" {
Expand All @@ -291,7 +298,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
return nil, fmt.Errorf("failed to parse a PEM public key: %w", err)
}

return []interface{}{key}, nil
return []verificationKey{{key: key}}, nil
}

return nil, fmt.Errorf("failed to extract a Key from the PEM certificate")
Expand All @@ -302,18 +309,31 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err)
}

var keys []interface{}
var keys []verificationKey
for _, k := range jwks.Keys {
key, err := k.Materialize()
if err != nil {
return nil, err
}
keys = append(keys, key)
keys = append(keys, verificationKey{
alg: k.GetAlgorithm().String(),
kid: k.GetKeyID(),
key: key,
})
}

return keys, nil
}

func getKeyByKid(kid string, keys []verificationKey) *verificationKey {
for _, key := range keys {
if key.kid == kid {
return &key
}
}
return nil
}

// Implements JWT signature verification.
func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) {
token, err := decodeJWT(a)
Expand All @@ -326,7 +346,7 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
return nil, err
}

keys, err := getKeyFromCertOrJWK(string(s))
keys, err := getKeysFromCertOrJWK(string(s))
if err != nil {
return nil, err
}
Expand All @@ -336,14 +356,45 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
return nil, err
}

err = token.decodeHeader()
if err != nil {
return nil, err
}
header, err := parseTokenHeader(token)
if err != nil {
return nil, err
}

// Validate the JWT signature
for _, key := range keys {
err = verify(key,
getInputSHA([]byte(token.header+"."+token.payload), hasher),
[]byte(signature))

if err == nil {
return ast.Boolean(true), nil
// First, check if there's a matching key ID (`kid`) in both token header and key(s).
// If a match is found, verify using only that key. Only applicable when a JWKS was provided.
if header.kid != "" {
if key := getKeyByKid(header.kid, keys); key != nil {
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))

return ast.Boolean(err == nil), nil
}
}

// If no key ID matched, try to verify using any key in the set
// If an alg is present in both the JWT header and the key, skip verification unless they match
for _, key := range keys {
if key.alg == "" {
// No algorithm provided for the key - this is likely a certificate and not a JWKS, so
// we'll need to verify to find out
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
} else {
if header.alg != key.alg {
continue
}
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
}
}

Expand Down Expand Up @@ -445,7 +496,7 @@ func builtinJWTVerifyHS512(bctx BuiltinContext, args []*ast.Term, iter func(*ast
// tokenConstraints holds decoded JWT verification constraints.
type tokenConstraints struct {
// The set of asymmetric keys we can verify with.
keys []interface{}
keys []verificationKey

// The single symmetric key we will verify with.
secret string
Expand Down Expand Up @@ -495,10 +546,11 @@ func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) error {
return fmt.Errorf("cert constraint: must be a string")
}

keys, err := getKeyFromCertOrJWK(string(s))
keys, err := getKeysFromCertOrJWK(string(s))
if err != nil {
return err
}

constraints.keys = keys
return nil
}
Expand Down Expand Up @@ -595,14 +647,36 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
}
// If we're configured with asymmetric key(s) then only trust that
if constraints.keys != nil {
if kid != "" {
if key := getKeyByKid(kid, constraints.keys); key != nil {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err != nil {
return errSignatureNotVerified
}
return nil
}
}

verified := false
for _, key := range constraints.keys {
err := a.verify(key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
if key.alg == "" {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
} else {
if alg != key.alg {
continue
}
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
}
}

if !verified {
return errSignatureNotVerified
}
Expand Down Expand Up @@ -843,6 +917,9 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j
return err
}
alg := standardHeaders.GetAlgorithm()
if alg == jwa.Unsupported {
return fmt.Errorf("unknown signature algorithm")
}

if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
return fmt.Errorf("type is JWT but payload is not JSON")
Expand Down

0 comments on commit cb6a4c0

Please sign in to comment.