Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address PKI to properly support managed keys #15256

Merged
merged 3 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions builtin/logical/pki/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ func Backend(conf *logical.BackendConfig) *backend {
b.tidyCASGuard = new(uint32)
b.tidyStatus = &tidyStatus{state: tidyStatusInactive}
b.storage = conf.StorageView
b.backendUuid = conf.BackendUUID

b.pkiStorageVersion.Store(0)

Expand All @@ -175,6 +176,7 @@ func Backend(conf *logical.BackendConfig) *backend {
type backend struct {
*framework.Backend

backendUuid string
storage logical.Storage
crlLifetime time.Duration
revokeStorageLock sync.RWMutex
Expand Down
131 changes: 72 additions & 59 deletions builtin/logical/pki/ca_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -38,8 +37,8 @@ func (b *backend) getGenerationParams(ctx context.Context, storage logical.Stora
`the "format" path parameter must be "pem", "der", or "pem_bundle"`)
return
}

keyType, keyBits, err := getKeyTypeAndBitsForRole(ctx, b, storage, data, mountPoint)
mkc := newManagedKeyContext(ctx, b, mountPoint)
keyType, keyBits, err := getKeyTypeAndBitsForRole(mkc, storage, data)
if err != nil {
errorResp = logical.ErrorResponse(err.Error())
return
Expand Down Expand Up @@ -77,31 +76,68 @@ func (b *backend) getGenerationParams(ctx context.Context, storage logical.Stora

func generateCABundle(ctx context.Context, b *backend, input *inputBundle, data *certutil.CreationBundle, randomSource io.Reader) (*certutil.ParsedCertBundle, error) {
if kmsRequested(input) {
return generateManagedKeyCABundle(ctx, b, input, data, randomSource)
keyId, err := getManagedKeyId(input.apiData)
if err != nil {
return nil, err
}
return generateManagedKeyCABundle(ctx, b, input, keyId, data, randomSource)
}

if existingKeyRequested(input) {
keyRef, err := getKeyRefWithErr(input.apiData)
if err != nil {
return nil, err
}
return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef))

keyEntry, err := getExistingKeyFromRef(ctx, input.req.Storage, keyRef)
if err != nil {
return nil, err
}

if keyEntry.isManagedPrivateKey() {
keyId, err := keyEntry.getManagedKeyUUID()
if err != nil {
return nil, err
}
return generateManagedKeyCABundle(ctx, b, input, keyId, data, randomSource)
}

return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingKeyGeneratorFromBytes(keyEntry))
}

return certutil.CreateCertificateWithRandomSource(data, randomSource)
}

func generateCSRBundle(ctx context.Context, b *backend, input *inputBundle, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (*certutil.ParsedCSRBundle, error) {
if kmsRequested(input) {
return generateManagedKeyCSRBundle(ctx, b, input, data, addBasicConstraints, randomSource)
keyId, err := getManagedKeyId(input.apiData)
if err != nil {
return nil, err
}

return generateManagedKeyCSRBundle(ctx, b, input, keyId, data, addBasicConstraints, randomSource)
}

if existingKeyRequested(input) {
keyRef, err := getKeyRefWithErr(input.apiData)
if err != nil {
return nil, err
}
return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingGeneratePrivateKey(ctx, input.req.Storage, keyRef))

key, err := getExistingKeyFromRef(ctx, input.req.Storage, keyRef)
if err != nil {
return nil, err
}

if key.isManagedPrivateKey() {
keyId, err := key.getManagedKeyUUID()
if err != nil {
return nil, err
}
return generateManagedKeyCSRBundle(ctx, b, input, keyId, data, addBasicConstraints, randomSource)
}

return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingKeyGeneratorFromBytes(key))
}

return certutil.CreateCSRWithRandomSource(data, addBasicConstraints, randomSource)
Expand All @@ -114,7 +150,7 @@ func parseCABundle(ctx context.Context, b *backend, req *logical.Request, bundle
return bundle.ToParsedCertBundle()
}

func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.Storage, data *framework.FieldData, mountPoint string) (string, int, error) {
func getKeyTypeAndBitsForRole(mkc managedKeyContext, storage logical.Storage, data *framework.FieldData) (string, int, error) {
exportedStr := data.Get("exported").(string)
var keyType string
var keyBits int
Expand All @@ -138,103 +174,80 @@ func getKeyTypeAndBitsForRole(ctx context.Context, b *backend, storage logical.S

var pubKey crypto.PublicKey
if kmsRequestedFromFieldData(data) {
pubKeyManagedKey, err := getManagedKeyPublicKey(ctx, b, data, mountPoint)
keyId, err := getManagedKeyId(data)
if err != nil {
return "", 0, errors.New("unable to determine managed key id" + err.Error())
}

pubKeyManagedKey, err := getManagedKeyPublicKey(mkc, keyId)
if err != nil {
return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error())
}
pubKey = pubKeyManagedKey
}

if existingKeyRequestedFromFieldData(data) {
existingPubKey, err := getExistingPublicKey(ctx, storage, data)
existingPubKey, err := getExistingPublicKey(mkc, storage, data)
if err != nil {
return "", 0, errors.New("failed to lookup public key from existing key: " + err.Error())
}
pubKey = existingPubKey
}

return getKeyTypeAndBitsFromPublicKeyForRole(pubKey)
privateKeyType, keyBits, err := getKeyTypeAndBitsFromPublicKeyForRole(pubKey)
return string(privateKeyType), keyBits, err
}

func getExistingPublicKey(ctx context.Context, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) {
func getExistingPublicKey(mkc managedKeyContext, s logical.Storage, data *framework.FieldData) (crypto.PublicKey, error) {
keyRef, err := getKeyRefWithErr(data)
if err != nil {
return nil, err
}
id, err := resolveKeyReference(ctx, s, keyRef)
if err != nil {
return nil, err
}
key, err := fetchKeyById(ctx, s, id)
id, err := resolveKeyReference(mkc.ctx, s, keyRef)
if err != nil {
return nil, err
}
signer, err := key.GetSigner()
key, err := fetchKeyById(mkc.ctx, s, id)
if err != nil {
return nil, err
}
return signer.Public(), nil
return getPublicKey(mkc, key)
}

func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (string, int, error) {
var keyType string
func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.PrivateKeyType, int, error) {
var keyType certutil.PrivateKeyType
var keyBits int

switch pubKey.(type) {
case *rsa.PublicKey:
keyType = "rsa"
keyType = certutil.RSAPrivateKey
keyBits = certutil.GetPublicKeySize(pubKey)
case *ecdsa.PublicKey:
keyType = "ec"
keyType = certutil.ECPrivateKey
case *ed25519.PublicKey:
keyType = "ed25519"
keyType = certutil.Ed25519PrivateKey
default:
return "", 0, fmt.Errorf("unsupported public key: %#v", pubKey)
return certutil.UnknownPrivateKey, 0, fmt.Errorf("unsupported public key: %#v", pubKey)
}
return keyType, keyBits, nil
}

func getManagedKeyPublicKey(ctx context.Context, b *backend, data *framework.FieldData, mountPoint string) (crypto.PublicKey, error) {
keyId, err := getManagedKeyId(data)
func getExistingKeyFromRef(ctx context.Context, s logical.Storage, keyRef string) (*keyEntry, error) {
keyId, err := resolveKeyReference(ctx, s, keyRef)
if err != nil {
return nil, errors.New("unable to determine managed key id")
}
// Determine key type and key bits from the managed public key
var pubKey crypto.PublicKey
err = withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
pubKey, err = key.GetPublicKey(ctx)
if err != nil {
return err
}

return nil
})
if err != nil {
return nil, errors.New("failed to lookup public key from managed key: " + err.Error())
return nil, err
}
return pubKey, nil
return fetchKeyById(ctx, s, keyId)
}

func existingGeneratePrivateKey(ctx context.Context, s logical.Storage, keyRef string) certutil.KeyGenerator {
return func(keyType string, keyBits int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error {
keyId, err := resolveKeyReference(ctx, s, keyRef)
if err != nil {
return err
}
key, err := fetchKeyById(ctx, s, keyId)
func existingKeyGeneratorFromBytes(key *keyEntry) certutil.KeyGenerator {
return func(_ string, _ int, container certutil.ParsedPrivateKeyContainer, _ io.Reader) error {
signer, _, pemBytes, err := getSignerFromKeyEntryBytes(key)
if err != nil {
return err
}
signer, err := key.GetSigner()
if err != nil {
return err
}
privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer)
if privateKeyType == certutil.UnknownPrivateKey {
return errors.New("unknown private key type loaded from key id: " + keyId.String())
}
blk, _ := pem.Decode([]byte(key.PrivateKey))
container.SetParsedPrivateKey(signer, privateKeyType, blk.Bytes)

container.SetParsedPrivateKey(signer, key.PrivateKeyType, pemBytes.Bytes)
return nil
}
}
3 changes: 2 additions & 1 deletion builtin/logical/pki/crl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ func TestBackend_CRL_EnableDisable(t *testing.T) {
func TestBackend_Secondary_CRL_Rebuilding(t *testing.T) {
ctx := context.Background()
b, s := createBackendWithStorage(t)
mkc := newManagedKeyContext(ctx, b, "test")

// Write out the issuer/key to storage without going through the api call as replication would.
bundle := genCertBundle(t, b, s)
issuer, _, err := writeCaBundle(ctx, s, bundle, "", "")
issuer, _, err := writeCaBundle(mkc, s, bundle, "", "")
require.NoError(t, err)

// Just to validate, before we call the invalidate function, make sure our CRL has not been generated
Expand Down
1 change: 1 addition & 0 deletions builtin/logical/pki/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const (
keyRefParam = "key_ref"
keyIdParam = "key_id"
keyTypeParam = "key_type"
keyBitsParam = "key_bits"
)

// addIssueAndSignCommonFields adds fields common to both CA and non-CA issuing
Expand Down
116 changes: 116 additions & 0 deletions builtin/logical/pki/key_util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package pki

import (
"context"
"crypto"
"encoding/pem"
"errors"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/logical"
)

type managedKeyContext struct {
ctx context.Context
b *backend
mountPoint string
}

func newManagedKeyContext(ctx context.Context, b *backend, mountPoint string) managedKeyContext {
return managedKeyContext{
ctx: ctx,
b: b,
mountPoint: mountPoint,
}
}

func comparePublicKey(ctx managedKeyContext, key *keyEntry, publicKey crypto.PublicKey) (bool, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it makes sense to add a "fast path" and compare on bytes? Are there any cases where two bytes are equal but their public keys would be unequal?

That would avoid having to (unnecessarily) parse string->PEM->ASN.1->GoStruct blobs repeatedly. Just a thought.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah as you pointed out below not sure if it's worth it at least for this round to add in a fast path.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah that's fair.

publicKeyForKeyEntry, err := getPublicKey(ctx, key)
if err != nil {
return false, err
}

return certutil.ComparePublicKeysAndType(publicKeyForKeyEntry, publicKey)
}

func getPublicKey(mkc managedKeyContext, key *keyEntry) (crypto.PublicKey, error) {
if key.PrivateKeyType == certutil.ManagedPrivateKey {
keyId, err := extractManagedKeyId([]byte(key.PrivateKey))
if err != nil {
return nil, err
}
return getManagedKeyPublicKey(mkc, keyId)
}

signer, _, _, err := getSignerFromKeyEntryBytes(key)
if err != nil {
return nil, err
}
return signer.Public(), nil
}

func getSignerFromKeyEntryBytes(key *keyEntry) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
if key.PrivateKeyType == certutil.UnknownPrivateKey {
return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("unsupported unknown private key type for key: %s (%s)", key.ID, key.Name)}
}

if key.PrivateKeyType == certutil.ManagedPrivateKey {
return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("can not get a signer from a managed key: %s (%s)", key.ID, key.Name)}
}

bytes, blockType, blk, err := getSignerFromBytes([]byte(key.PrivateKey))
if err != nil {
return nil, certutil.UnknownBlock, nil, errutil.InternalError{Err: fmt.Sprintf("failed parsing key entry bytes for key id: %s (%s): %s", key.ID, key.Name, err.Error())}
}

return bytes, blockType, blk, nil
}

func getSignerFromBytes(keyBytes []byte) (crypto.Signer, certutil.BlockType, *pem.Block, error) {
pemBlock, _ := pem.Decode(keyBytes)
if pemBlock == nil {
return nil, certutil.UnknownBlock, pemBlock, errutil.InternalError{Err: "no data found in PEM block"}
}

signer, blk, err := certutil.ParseDERKey(pemBlock.Bytes)
if err != nil {
return nil, certutil.UnknownBlock, pemBlock, errutil.InternalError{Err: fmt.Sprintf("failed to parse PEM block: %s", err.Error())}
}
return signer, blk, pemBlock, nil
}

func getManagedKeyPublicKey(mkc managedKeyContext, keyId managedKeyId) (crypto.PublicKey, error) {
// Determine key type and key bits from the managed public key
var pubKey crypto.PublicKey
err := withManagedPKIKey(mkc.ctx, mkc.b, keyId, mkc.mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
var myErr error
pubKey, myErr = key.GetPublicKey(ctx)
if myErr != nil {
return myErr
}

return nil
})
if err != nil {
return nil, errors.New("failed to lookup public key from managed key: " + err.Error())
}
return pubKey, nil
}

func importKeyFromBytes(mkc managedKeyContext, s logical.Storage, keyValue string, keyName string) (*keyEntry, bool, error) {
signer, _, _, err := getSignerFromBytes([]byte(keyValue))
if err != nil {
return nil, false, err
}
privateKeyType := certutil.GetPrivateKeyTypeFromSigner(signer)
if privateKeyType == certutil.UnknownPrivateKey {
return nil, false, errors.New("unsupported private key type within pem bundle")
}

key, existed, err := importKey(mkc, s, keyValue, keyName, privateKeyType)
if err != nil {
return nil, false, err
}
return key, existed, nil
}