Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oidc: check for nil signing key on rotation #13716

Merged
merged 4 commits into from Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/13716.txt
@@ -0,0 +1,3 @@
```release-note:bug
identity/oidc: Check for a nil signing key on rotation to prevent panics.
```
55 changes: 37 additions & 18 deletions vault/identity_store_oidc.go
Expand Up @@ -548,19 +548,11 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica

// generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm)
err = key.generateAndSetKey(ctx, i.Logger(), req.Storage)
if err != nil {
return nil, err
}

key.SigningKey = signingKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})

if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)

err = key.generateAndSetNextKey(ctx, i.Logger(), req.Storage)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1013,6 +1005,24 @@ func mergeJSONTemplates(logger hclog.Logger, output map[string]interface{}, temp
return nil
}

// generateAndSetKey will generate new signing and public key pairs and set
// them as the SigningKey.
func (k *namedKey) generateAndSetKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error {
signingKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}

k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})

if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
return nil
}

// generateAndSetNextKey will generate new signing and public key pairs and set
// them as the NextSigningKey.
func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error {
Expand Down Expand Up @@ -1481,8 +1491,25 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL
if k.SigningKey == nil {
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
logger.Debug("nil signing key detected on rotation")
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
err := k.generateAndSetKey(ctx, logger, s)
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}
}

if k.NextSigningKey == nil {
logger.Debug("nil next signing key detected on rotation")
// keys will not have a NextSigningKey if they were generated before
// vault 1.9
err := k.generateAndSetNextKey(ctx, logger, s)
if err != nil {
return err
}
}

verificationTTL := k.VerificationTTL
if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}
Expand All @@ -1496,14 +1523,6 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St
}
}

if k.NextSigningKey == nil {
// keys will not have a NextSigningKey if they were generated before
// vault 1.9
err := k.generateAndSetNextKey(ctx, logger, s)
if err != nil {
return err
}
}
// do the rotation
k.SigningKey = k.NextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)
Expand Down
60 changes: 49 additions & 11 deletions vault/identity_store_oidc_test.go
Expand Up @@ -937,12 +937,32 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
numPublicKeys int
}{
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
{1, 2, 2},
{2, 3, 3},
{3, 3, 3},
{4, 3, 3},
{5, 3, 3},
{6, 3, 3},
{7, 3, 3},
{2, 2, 2},
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
{3, 2, 2},
{4, 2, 2},
{5, 2, 2},
{6, 2, 2},
{7, 2, 2},
},
},
{
// don't set SigningKey to ensure its non-existence can be handled
&namedKey{
name: "test-key-nil-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: append([]*expireableKey{}, &expireableKey{KeyID: id}),
SigningKey: nil,
NextSigningKey: jwk,
NextRotation: time.Now(),
},
[]struct {
cycle int
numKeys int
numPublicKeys int
}{
{1, 2, 2},
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
},
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
},
}
Expand Down Expand Up @@ -985,15 +1005,33 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
}

// measure collected samples
for i := range testSet.testCases {
for i, tc := range testSet.testCases {
namedKeySamples[i].DecodeJSON(&testSet.namedKey)
if len(testSet.namedKey.KeyRing) != testSet.testCases[i].numKeys {
t.Fatalf("At cycle: %d expected namedKey's KeyRing to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numKeys, len(testSet.namedKey.KeyRing))
actualKeyRingLen := len(testSet.namedKey.KeyRing)
if actualKeyRingLen < tc.numKeys {
t.Fatalf(
"For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d",
testSet.namedKey.name,
tc.cycle,
tc.numKeys,
actualKeyRingLen,
)
}
if len(publicKeysSamples[i]) != testSet.testCases[i].numPublicKeys {
t.Fatalf("At cycle: %d expected public keys to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numPublicKeys, len(publicKeysSamples[i]))
actualPubKeysLen := len(publicKeysSamples[i])
if actualPubKeysLen < tc.numPublicKeys {
t.Fatalf(
"For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d",
testSet.namedKey.name,
tc.cycle,
tc.numPublicKeys,
actualPubKeysLen,
)
}
}

if err := storage.Delete(ctx, namedKeyConfigPath+testSet.namedKey.name); err != nil {
t.Fatalf("deleting from in mem storage failed")
}
}
}

Expand Down