diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 5f6cbb4a80805..f04d9fad99cf6 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -9,12 +9,15 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "crypto/sha256" + "crypto/sha512" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "encoding/pem" "fmt" + "hash" "io/ioutil" "math" "math/big" @@ -30,6 +33,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/armon/go-metrics" "github.com/fatih/structs" "github.com/go-test/deep" @@ -3755,6 +3760,22 @@ func TestBackend_RevokePlusTidy_Intermediate(t *testing.T) { } func TestBackend_Root_FullCAChain(t *testing.T) { + testCases := []struct { + testName string + keyType string + }{ + {testName: "RSA", keyType: "rsa"}, + {testName: "ED25519", keyType: "ed25519"}, + {testName: "EC", keyType: "ec"}, + } + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + runFullCAChainTest(t, tc.keyType) + }) + } +} + +func runFullCAChainTest(t *testing.T, keyType string) { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "pki": Factory, @@ -3783,6 +3804,7 @@ func TestBackend_Root_FullCAChain(t *testing.T) { resp, err := client.Logical().WriteWithContext(context.Background(), "pki-root/root/generate/exported", map[string]interface{}{ "common_name": "root myvault.com", + "key_type": keyType, }) if err != nil { t.Fatal(err) @@ -3821,6 +3843,7 @@ func TestBackend_Root_FullCAChain(t *testing.T) { resp, err = client.Logical().WriteWithContext(context.Background(), "pki-intermediate/intermediate/generate/exported", map[string]interface{}{ "common_name": "intermediate myvault.com", + "key_type": keyType, }) if err != nil { t.Fatal(err) @@ -3844,6 +3867,10 @@ func TestBackend_Root_FullCAChain(t *testing.T) { intermediateSignedData := resp.Data intermediateCert := intermediateSignedData["certificate"].(string) + rootCaCert := parseCert(t, rootCert) + intermediaryCaCert := parseCert(t, intermediateCert) + requireSignedBy(t, intermediaryCaCert, rootCaCert.PublicKey) + resp, err = client.Logical().WriteWithContext(context.Background(), "pki-intermediate/intermediate/set-signed", map[string]interface{}{ "certificate": intermediateCert + "\n" + rootCert + "\n", }) @@ -3905,6 +3932,26 @@ func TestBackend_Root_FullCAChain(t *testing.T) { if !strings.Contains(fullChain, rootCert) { t.Fatal("expected full chain to contain root certificate") } + + // Now issue a short-lived certificate from our pki-external. + resp, err = client.Logical().Write("pki-external/roles/example", map[string]interface{}{ + "allowed_domains": "example.com", + "allow_subdomains": "true", + "max_ttl": "1h", + }) + require.NoError(t, err, "error setting up pki role: %v", err) + + resp, err = client.Logical().Write("pki-external/issue/example", map[string]interface{}{ + "common_name": "test.example.com", + "ttl": "5m", + }) + require.NoError(t, err, "error issuing certificate: %v", err) + require.NotNil(t, resp, "got nil response from issuing request") + issueCrtAsPem := resp.Data["certificate"].(string) + issuedCrt := parseCert(t, issueCrtAsPem) + + // Verify that the certificates are signed by the intermediary CA key... + requireSignedBy(t, issuedCrt, intermediaryCaCert.PublicKey) } type MultiBool int @@ -4199,3 +4246,90 @@ var ( edCAKey string edCACert string ) + +func mountPKIEndpoint(t *testing.T, client *api.Client, path string) { + var err error + err = client.Sys().Mount(path, &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "32h", + }, + }) + require.NoError(t, err, "failed mounting pki endpoint") +} + +func requireSignedBy(t *testing.T, cert x509.Certificate, key crypto.PublicKey) { + switch key.(type) { + case *rsa.PublicKey: + requireRSASignedBy(t, cert, key.(*rsa.PublicKey)) + case *ecdsa.PublicKey: + requireECDSASignedBy(t, cert, key.(*ecdsa.PublicKey)) + case ed25519.PublicKey: + requireED25519SignedBy(t, cert, key.(ed25519.PublicKey)) + default: + require.Fail(t, "unknown public key type %#v", key) + } +} + +func requireRSASignedBy(t *testing.T, cert x509.Certificate, key *rsa.PublicKey) { + require.Contains(t, []x509.SignatureAlgorithm{x509.SHA256WithRSA, x509.SHA512WithRSA}, + cert.SignatureAlgorithm, "only sha256 signatures supported") + + var hasher hash.Hash + var hashAlgo crypto.Hash + + switch cert.SignatureAlgorithm { + case x509.SHA256WithRSA: + hasher = sha256.New() + hashAlgo = crypto.SHA256 + case x509.SHA512WithRSA: + hasher = sha512.New() + hashAlgo = crypto.SHA512 + } + + hasher.Write(cert.RawTBSCertificate) + hashData := hasher.Sum(nil) + + err := rsa.VerifyPKCS1v15(key, hashAlgo, hashData, cert.Signature) + require.NoError(t, err, "the certificate was not signed by the expected public rsa key.") +} + +func requireECDSASignedBy(t *testing.T, cert x509.Certificate, key *ecdsa.PublicKey) { + require.Contains(t, []x509.SignatureAlgorithm{x509.ECDSAWithSHA256, x509.ECDSAWithSHA512}, + cert.SignatureAlgorithm, "only ecdsa signatures supported") + + var hasher hash.Hash + switch cert.SignatureAlgorithm { + case x509.ECDSAWithSHA256: + hasher = sha256.New() + case x509.ECDSAWithSHA512: + hasher = sha512.New() + } + + hasher.Write(cert.RawTBSCertificate) + hashData := hasher.Sum(nil) + + verify := ecdsa.VerifyASN1(key, hashData, cert.Signature) + require.True(t, verify, "the certificate was not signed by the expected public ecdsa key.") +} + +func requireED25519SignedBy(t *testing.T, cert x509.Certificate, key ed25519.PublicKey) { + require.Equal(t, x509.PureEd25519, cert.SignatureAlgorithm) + ed25519.Verify(key, cert.RawTBSCertificate, cert.Signature) +} + +func parseCert(t *testing.T, pemCert string) x509.Certificate { + block, _ := pem.Decode([]byte(pemCert)) + require.NotNil(t, block, "failed to decode PEM block") + + cert, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + return *cert +} + +func requireMatchingPublicKeys(t *testing.T, cert x509.Certificate, key crypto.PublicKey) { + certPubKey := cert.PublicKey + require.True(t, reflect.DeepEqual(certPubKey, key), + "public keys mismatched: got: %v, expected: %v", certPubKey, key) +}