diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index afa3ebefabb07..3c6948a3e866b 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -458,6 +458,166 @@ func TestBackend_PermittedDNSDomainsIntermediateCA(t *testing.T) { } } +func TestBackend_MetadataBasedACLPolicy(t *testing.T) { + // Start cluster with cert auth method enabled + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: log.NewNullLogger(), + CredentialBackends: map[string]logical.Factory{ + "cert": Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + client := cores[0].Client + + var err error + + // Enable the cert auth method + err = client.Sys().EnableAuthWithOptions("cert", &api.EnableAuthOptions{ + Type: "cert", + }) + if err != nil { + t.Fatal(err) + } + + // Enable metadata in aliases + _, err = client.Logical().Write("auth/cert/config", map[string]interface{}{ + "enable_identity_alias_metadata": true, + }) + if err != nil { + t.Fatal(err) + } + + // Retrieve its accessor id + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + var accessor string + + for _, auth := range auths { + if auth.Type == "cert" { + accessor = auth.Accessor + } + } + + if accessor == "" { + t.Fatal("failed to find cert auth accessor") + } + + // Write ACL policy + err = client.Sys().PutPolicy("metadata-based", fmt.Sprintf(` +path "kv/cn/{{identity.entity.aliases.%s.metadata.common_name}}" { + capabilities = ["read"] +} +path "kv/ext/{{identity.entity.aliases.%s.metadata.2-1-1-1}}" { + capabilities = ["read"] +} +`, accessor, accessor)) + if err != nil { + t.Fatalf("err: %v", err) + } + + ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Set the trusted certificate in the backend + _, err = client.Logical().Write("auth/cert/certs/test", map[string]interface{}{ + "display_name": "test", + "policies": "metadata-based", + "certificate": string(ca), + "allowed_metadata_extensions": "2.1.1.1,1.2.3.45", + }) + if err != nil { + t.Fatal(err) + } + + // This function is a copy-paste from the NewTestCluster, with the + // modification to reconfigure the TLS on the api client with a + // specific client certificate. + getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client { + transport := cleanhttp.DefaultPooledTransport() + transport.TLSClientConfig = tlsConfig.Clone() + if err := http2.ConfigureTransport(transport); err != nil { + t.Fatal(err) + } + client := &http.Client{ + Transport: transport, + CheckRedirect: func(*http.Request, []*http.Request) error { + // This can of course be overridden per-test by using its own client + return fmt.Errorf("redirects not allowed in these tests") + }, + } + config := api.DefaultConfig() + if config.Error != nil { + t.Fatal(config.Error) + } + config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) + config.HttpClient = client + + // Set the client certificates + config.ConfigureTLS(&api.TLSConfig{ + CACertBytes: cluster.CACertPEM, + ClientCert: "test-fixtures/root/rootcawextcert.pem", + ClientKey: "test-fixtures/root/rootcawextkey.pem", + }) + + apiClient, err := api.NewClient(config) + if err != nil { + t.Fatal(err) + } + return apiClient + } + + // Create a new api client with the desired TLS configuration + newClient := getAPIClient(cores[0].Listeners[0].Address.Port, cores[0].TLSConfig) + + var secret *api.Secret + + secret, err = newClient.Logical().Write("auth/cert/login", map[string]interface{}{ + "name": "test", + }) + if err != nil { + t.Fatal(err) + } + if secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatalf("expected a successful authentication") + } + + // Check paths guarded by ACL policy + newClient.SetToken(secret.Auth.ClientToken) + + _, err = newClient.Logical().Read("kv/cn/example.com") + if err != nil { + t.Fatal(err) + } + + _, err = newClient.Logical().Read("kv/cn/not.example.com") + if err == nil { + t.Fatal("expected access denied") + } + + _, err = newClient.Logical().Read("kv/ext/A UTF8String Extension") + if err != nil { + t.Fatal(err) + } + + _, err = newClient.Logical().Read("kv/ext/bar") + if err == nil { + t.Fatal("expected access denied") + } +} + func TestBackend_NonCAExpiry(t *testing.T) { var resp *logical.Response var err error