Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore keys of unknown alg when verifying JWTs with JWKS #4725

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/jwx/jwa/signature.go
Expand Up @@ -27,6 +27,7 @@ const (
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
NoValue SignatureAlgorithm = "" // No value is different from none
Unsupported SignatureAlgorithm = "unsupported"
)

// Accept is used when conversion from values given by
Expand Down Expand Up @@ -69,7 +70,8 @@ func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
}
_, ok := signatureAlg[quoted]
if !ok {
return errors.New("unknown signature algorithm")
*signature = Unsupported
return nil
anderseknert marked this conversation as resolved.
Show resolved Hide resolved
}
*signature = SignatureAlgorithm(quoted)
return nil
Expand Down
3 changes: 3 additions & 0 deletions internal/jwx/jwk/jwk.go
Expand Up @@ -99,6 +99,9 @@ func parse(jwkSrc string) (*Set, error) {
} else {
for i := range rawKeySetJSON.Keys {
rawKeyJSON := rawKeySetJSON.Keys[i]
if rawKeyJSON.Algorithm != nil && *rawKeyJSON.Algorithm == jwa.Unsupported {
continue
}
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
Expand Down
8 changes: 5 additions & 3 deletions internal/jwx/jws/headers_test.go
Expand Up @@ -112,9 +112,11 @@ func TestHeader(t *testing.T) {
headers := `{"typ":"JWT",` + "\r\n" + ` "alg":"dummy"}`
var standardHeaders jws.StandardHeaders
err := json.Unmarshal([]byte(headers), &standardHeaders)
if err == nil {
t.Fatal("Unmarshal should have failed")
if err != nil {
t.Fatal(err)
}
if standardHeaders.Algorithm != jwa.Unsupported {
t.Errorf("expected unsupported algorithm")
}

})
}
7 changes: 5 additions & 2 deletions internal/jwx/jws/jws_test.go
Expand Up @@ -82,8 +82,11 @@ func TestAlgError(t *testing.T) {
const hdr = `{"typ":"JWT",` + "\r\n" + ` "alg":"unknown"}`
var standardHeaders jws.StandardHeaders
err := json.Unmarshal([]byte(hdr), &standardHeaders)
if err == nil {
t.Fatal("header parsing should have failed")
if err != nil {
t.Fatal(err)
}
if standardHeaders.Algorithm != jwa.Unsupported {
t.Errorf("expected unsupported algorithm")
}
})
}
Expand Down
117 changes: 97 additions & 20 deletions topdown/tokens.go
Expand Up @@ -22,6 +22,7 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/open-policy-agent/opa/topdown/builtins"
Expand Down Expand Up @@ -268,9 +269,16 @@ func verifyES(publicKey interface{}, digest []byte, signature []byte) error {
return fmt.Errorf("ECDSA signature verification error")
}

// getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
type verificationKey struct {
alg string
kid string
key interface{}
}

// getKeysFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
// A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
// When provided a JWKS, each key additionally likely contains a key ID and the key algorithm.
func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) {
if block, rest := pem.Decode([]byte(certificate)); block != nil {
if len(rest) > 0 {
return nil, fmt.Errorf("extra data after a PEM certificate block")
Expand All @@ -281,8 +289,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse a PEM certificate: %w", err)
}

return []interface{}{cert.PublicKey}, nil
return []verificationKey{{key: cert.PublicKey}}, nil
}

if block.Type == "PUBLIC KEY" {
Expand All @@ -291,7 +298,7 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
return nil, fmt.Errorf("failed to parse a PEM public key: %w", err)
}

return []interface{}{key}, nil
return []verificationKey{{key: key}}, nil
}

return nil, fmt.Errorf("failed to extract a Key from the PEM certificate")
Expand All @@ -302,18 +309,31 @@ func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err)
}

var keys []interface{}
var keys []verificationKey
for _, k := range jwks.Keys {
key, err := k.Materialize()
if err != nil {
return nil, err
}
keys = append(keys, key)
keys = append(keys, verificationKey{
alg: k.GetAlgorithm().String(),
kid: k.GetKeyID(),
key: key,
})
}

return keys, nil
}

func getKeyByKid(kid string, keys []verificationKey) *verificationKey {
for _, key := range keys {
if key.kid == kid {
return &key
}
}
return nil
}

// Implements JWT signature verification.
func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) {
token, err := decodeJWT(a)
Expand All @@ -326,7 +346,7 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
return nil, err
}

keys, err := getKeyFromCertOrJWK(string(s))
keys, err := getKeysFromCertOrJWK(string(s))
if err != nil {
return nil, err
}
Expand All @@ -336,14 +356,45 @@ func builtinJWTVerify(a ast.Value, b ast.Value, hasher func() hash.Hash, verify
return nil, err
}

err = token.decodeHeader()
if err != nil {
return nil, err
}
header, err := parseTokenHeader(token)
if err != nil {
return nil, err
}

// Validate the JWT signature
for _, key := range keys {
err = verify(key,
getInputSHA([]byte(token.header+"."+token.payload), hasher),
[]byte(signature))

if err == nil {
return ast.Boolean(true), nil
// First, check if there's a matching key ID (`kid`) in both token header and key(s).
// If a match is found, verify using only that key. Only applicable when a JWKS was provided.
if header.kid != "" {
if key := getKeyByKid(header.kid, keys); key != nil {
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))

return ast.Boolean(err == nil), nil
}
}

// If no key ID matched, try to verify using any key in the set
// If an alg is present in both the JWT header and the key, skip verification unless they match
for _, key := range keys {
if key.alg == "" {
// No algorithm provided for the key - this is likely a certificate and not a JWKS, so
// we'll need to verify to find out
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
} else {
if header.alg != key.alg {
continue
}
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
}
}
anderseknert marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -445,7 +496,7 @@ func builtinJWTVerifyHS512(bctx BuiltinContext, args []*ast.Term, iter func(*ast
// tokenConstraints holds decoded JWT verification constraints.
type tokenConstraints struct {
// The set of asymmetric keys we can verify with.
keys []interface{}
keys []verificationKey

// The single symmetric key we will verify with.
secret string
Expand Down Expand Up @@ -495,10 +546,11 @@ func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) error {
return fmt.Errorf("cert constraint: must be a string")
}

keys, err := getKeyFromCertOrJWK(string(s))
keys, err := getKeysFromCertOrJWK(string(s))
if err != nil {
return err
}

constraints.keys = keys
return nil
}
Expand Down Expand Up @@ -595,14 +647,36 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
}
// If we're configured with asymmetric key(s) then only trust that
if constraints.keys != nil {
if kid != "" {
if key := getKeyByKid(kid, constraints.keys); key != nil {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err != nil {
return errSignatureNotVerified
}
return nil
}
}

verified := false
for _, key := range constraints.keys {
err := a.verify(key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
if key.alg == "" {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
} else {
if alg != key.alg {
continue
}
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
}
}

if !verified {
return errSignatureNotVerified
}
Expand Down Expand Up @@ -843,6 +917,9 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j
return err
}
alg := standardHeaders.GetAlgorithm()
if alg == jwa.Unsupported {
return fmt.Errorf("unknown signature algorithm")
}

if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
return fmt.Errorf("type is JWT but payload is not JSON")
Expand Down