Skip to content

Commit

Permalink
Ignore keys of unknown alg when verifying JWTs with JWKS
Browse files Browse the repository at this point in the history
Additionally, improve the JWT verification process when a JWKS
is provided:

* If `kid` is present in JWT header, and exists in JWKS — verify using
  that key only.
* If `kid` not present in JWT header, try verification only using keys
  matching the `alg` provided in the JWT header (mandatory claim).

Fixes open-policy-agent#4699

Signed-off-by: Anders Eknert <anders@eknert.com>
  • Loading branch information
anderseknert committed Jun 1, 2022
1 parent 1889f24 commit 5a23fb0
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 29 deletions.
4 changes: 3 additions & 1 deletion internal/jwx/jwa/signature.go
Expand Up @@ -27,6 +27,7 @@ const (
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
NoValue SignatureAlgorithm = "" // No value is different from none
Unsupported SignatureAlgorithm = "unsupported"
)

// Accept is used when conversion from values given by
Expand Down Expand Up @@ -69,7 +70,8 @@ func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
}
_, ok := signatureAlg[quoted]
if !ok {
return errors.New("unknown signature algorithm")
*signature = Unsupported
return nil
}
*signature = SignatureAlgorithm(quoted)
return nil
Expand Down
3 changes: 3 additions & 0 deletions internal/jwx/jwk/jwk.go
Expand Up @@ -99,6 +99,9 @@ func parse(jwkSrc string) (*Set, error) {
} else {
for i := range rawKeySetJSON.Keys {
rawKeyJSON := rawKeySetJSON.Keys[i]
if rawKeyJSON.Algorithm != nil && *rawKeyJSON.Algorithm == jwa.Unsupported {
continue
}
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
Expand Down
8 changes: 5 additions & 3 deletions internal/jwx/jws/headers_test.go
Expand Up @@ -112,9 +112,11 @@ func TestHeader(t *testing.T) {
headers := `{"typ":"JWT",` + "\r\n" + ` "alg":"dummy"}`
var standardHeaders jws.StandardHeaders
err := json.Unmarshal([]byte(headers), &standardHeaders)
if err == nil {
t.Fatal("Unmarshal should have failed")
if err != nil {
t.Fatal(err)
}
if standardHeaders.Algorithm != jwa.Unsupported {
t.Errorf("expected unsupported algorithm")
}

})
}
7 changes: 5 additions & 2 deletions internal/jwx/jws/jws_test.go
Expand Up @@ -82,8 +82,11 @@ func TestAlgError(t *testing.T) {
const hdr = `{"typ":"JWT",` + "\r\n" + ` "alg":"unknown"}`
var standardHeaders jws.StandardHeaders
err := json.Unmarshal([]byte(hdr), &standardHeaders)
if err == nil {
t.Fatal("header parsing should have failed")
if err != nil {
t.Fatal(err)
}
if standardHeaders.Algorithm != jwa.Unsupported {
t.Errorf("expected unsupported algorithm")
}
})
}
Expand Down
117 changes: 97 additions & 20 deletions topdown/tokens.go
Expand Up @@ -22,6 +22,7 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/open-policy-agent/opa/topdown/builtins"
Expand Down Expand Up @@ -268,9 +269,16 @@ func verifyES(publicKey interface{}, digest []byte, signature []byte) error {
return fmt.Errorf("ECDSA signature verification error")
}

// getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
type verificationKey struct {
alg string
kid string
key interface{}
}

// getKeysFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
// A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
// When provided a JWKS, each key additionally likely contains a key ID and the key algorithm.
func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) {
if block, rest := pem.Decode([]byte(certificate)); block != nil {
if len(rest) > 0 {
return nil, fmt.Errorf("extra data after a PEM certificate block")
Expand All @@ -281,8 +289,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse a PEM certificate: %w", err)
}

return []interface{}{cert.PublicKey}, nil
return []verificationKey{{key: cert.PublicKey}}, nil
}

if block.Type == "PUBLIC KEY" {
Expand All @@ -291,7 +298,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
return nil, fmt.Errorf("failed to parse a PEM public key: %w", err)
}

return []interface{}{key}, nil
return []verificationKey{{key: key}}, nil
}

return nil, fmt.Errorf("failed to extract a Key from the PEM certificate")
Expand All @@ -302,18 +309,31 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err)
}

var keys []interface{}
var keys []verificationKey
for _, k := range jwks.Keys {
key, err := k.Materialize()
if err != nil {
return nil, err
}
keys = append(keys, key)
keys = append(keys, verificationKey{
alg: k.GetAlgorithm().String(),
kid: k.GetKeyID(),
key: key,
})
}

return keys, nil
}

func getKeyByKid(kid string, keys []verificationKey) *verificationKey {
for _, key := range keys {
if key.kid == kid {
return &key
}
}
return nil
}

// Implements JWT signature verification.
func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) {
token, err := decodeJWT(a)
Expand All @@ -326,7 +346,7 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
return nil, err
}

keys, err := getKeyFromCertOrJWK(string(s))
keys, err := getKeysFromCertOrJWK(string(s))
if err != nil {
return nil, err
}
Expand All @@ -336,14 +356,45 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
return nil, err
}

err = token.decodeHeader()
if err != nil {
return nil, err
}
header, err := parseTokenHeader(token)
if err != nil {
return nil, err
}

// Validate the JWT signature
for _, key := range keys {
err = verify(key,
getInputSHA([]byte(token.header+"."+token.payload), hasher),
[]byte(signature))

if err == nil {
return ast.Boolean(true), nil
// First, check if there's a matching key ID (`kid`) in both token header and key(s).
// If a match is found, verify using only that key. Only applicable when a JWKS was provided.
if header.kid != "" {
if key := getKeyByKid(header.kid, keys); key != nil {
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))

return ast.Boolean(err == nil), nil
}
}

// If no key ID matched, try to verify using any key in the set
// If an alg is present in both the JWT header and the key, skip verification unless they match
for _, key := range keys {
if key.alg == "" {
// No algorithm provided for the key - this is likely a certificate and not a JWKS, so
// we'll need to verify to find out
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
} else {
if header.alg != key.alg {
continue
}
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
}
}

Expand Down Expand Up @@ -445,7 +496,7 @@ func builtinJWTVerifyHS512(bctx BuiltinContext, args []*ast.Term, iter func(*ast
// tokenConstraints holds decoded JWT verification constraints.
type tokenConstraints struct {
// The set of asymmetric keys we can verify with.
keys []interface{}
keys []verificationKey

// The single symmetric key we will verify with.
secret string
Expand Down Expand Up @@ -495,10 +546,11 @@ func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) error {
return fmt.Errorf("cert constraint: must be a string")
}

keys, err := getKeyFromCertOrJWK(string(s))
keys, err := getKeysFromCertOrJWK(string(s))
if err != nil {
return err
}

constraints.keys = keys
return nil
}
Expand Down Expand Up @@ -595,14 +647,36 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
}
// If we're configured with asymmetric key(s) then only trust that
if constraints.keys != nil {
if kid != "" {
if key := getKeyByKid(kid, constraints.keys); key != nil {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err != nil {
return errSignatureNotVerified
}
return nil
}
}

verified := false
for _, key := range constraints.keys {
err := a.verify(key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
if key.alg == "" {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
} else {
if alg != key.alg {
continue
}
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
}
}

if !verified {
return errSignatureNotVerified
}
Expand Down Expand Up @@ -843,6 +917,9 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j
return err
}
alg := standardHeaders.GetAlgorithm()
if alg == jwa.Unsupported {
return fmt.Errorf("unknown signature algorithm")
}

if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
return fmt.Errorf("type is JWT but payload is not JSON")
Expand Down

0 comments on commit 5a23fb0

Please sign in to comment.