Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make key completion work for both kv-v1 and kv-v2 #16553

Merged
merged 2 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/16553.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
command: Fix shell completion for KV v2 mounts
```
56 changes: 47 additions & 9 deletions command/base_predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,27 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {

// Trim path with potential mount
var relativePath string
for _, mount := range p.mounts() {
mountInfos, err := p.mountInfos()
if err != nil {
return nil
}

var mountType, mountVersion string
for mount, mountInfo := range mountInfos {
if strings.HasPrefix(path, mount) {
relativePath = strings.TrimPrefix(path, mount+"/")
mountType = mountInfo.Type
if mountInfo.Options != nil {
mountVersion = mountInfo.Options["version"]
}
break
}
}

// Predict path or mount depending on path separator
var predictions []string
if strings.Contains(relativePath, "/") {
predictions = p.paths(path, includeFiles)
predictions = p.paths(mountType, mountVersion, path, includeFiles)
} else {
predictions = p.filter(p.mounts(), path)
}
Expand Down Expand Up @@ -288,7 +298,7 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
}

// paths predicts all paths which start with the given path.
func (p *Predict) paths(path string, includeFiles bool) []string {
func (p *Predict) paths(mountType, mountVersion, path string, includeFiles bool) []string {
client := p.Client()
if client == nil {
return nil
Expand All @@ -303,7 +313,7 @@ func (p *Predict) paths(path string, includeFiles bool) []string {
root = root[:idx+1]
}

paths := p.listPaths(root)
paths := p.listPaths(buildAPIListPath(root, mountType, mountVersion))

var predictions []string
for _, p := range paths {
Expand All @@ -326,6 +336,22 @@ func (p *Predict) paths(path string, includeFiles bool) []string {
return predictions
}

func buildAPIListPath(path, mountType, mountVersion string) string {
if mountType == "kv" && mountVersion == "2" {
return toKVv2ListPath(path)
}
return path
}

func toKVv2ListPath(path string) string {
firstSlashIdx := strings.Index(path, "/")
if firstSlashIdx < 0 {
return path
}

return path[:firstSlashIdx] + "/metadata" + path[firstSlashIdx:]
}

// audits returns a sorted list of the audit backends for Vault server for
// which the client is configured to communicate with.
func (p *Predict) audits() []string {
Expand Down Expand Up @@ -421,16 +447,28 @@ func (p *Predict) policies() []string {
return policies
}

// mounts returns a sorted list of the mount paths for Vault server for
// which the client is configured to communicate with. This function returns the
// default list of mounts if an error occurs.
func (p *Predict) mounts() []string {
// mountInfos returns a map with mount paths as keys and MountOutputs as values
// for the Vault server which the client is configured to communicate with.
// Returns error if server communication fails.
func (p *Predict) mountInfos() (map[string]*api.MountOutput, error) {
peteski22 marked this conversation as resolved.
Show resolved Hide resolved
client := p.Client()
if client == nil {
return nil
return nil, nil
}

mounts, err := client.Sys().ListMounts()
if err != nil {
return nil, err
}

return mounts, nil
}

// mounts returns a sorted list of the mount paths for Vault server for
// which the client is configured to communicate with. This function returns the
// default list of mounts if an error occurs.
func (p *Predict) mounts() []string {
mounts, err := p.mountInfos()
if err != nil {
return defaultPredictVaultMounts
}
Expand Down
75 changes: 74 additions & 1 deletion command/base_predict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,80 @@ func TestPredict_Paths(t *testing.T) {
p := NewPredict()
p.client = client

act := p.paths(tc.path, tc.includeFiles)
act := p.paths("kv", "1", tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}

func TestPredict_PathsKVv2(t *testing.T) {
t.Parallel()

client, closer := testVaultServerWithKVVersion(t, "2")
defer closer()

data := map[string]interface{}{"data": map[string]interface{}{"a": "b"}}
if _, err := client.Logical().Write("secret/data/bar", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/data/foo", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/data/zip/zap", data); err != nil {
t.Fatal(err)
}

cases := []struct {
name string
path string
includeFiles bool
exp []string
}{
{
"bad_path",
"nope/not/a/real/path/ever",
true,
[]string{"nope/not/a/real/path/ever"},
},
{
"good_path",
"secret/",
true,
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"good_path_no_files",
"secret/",
false,
[]string{"secret/zip/"},
},
{
"partial_match",
"secret/z",
true,
[]string{"secret/zip/"},
},
{
"partial_match_no_files",
"secret/z",
false,
[]string{"secret/zip/"},
},
}

t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

p := NewPredict()
p.client = client

act := p.paths("kv", "2", tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
Expand Down
31 changes: 25 additions & 6 deletions command/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ func testVaultServer(tb testing.TB) (*api.Client, func()) {
return client, closer
}

func testVaultServerWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, func()) {
tb.Helper()

client, _, closer := testVaultServerUnsealWithKVVersion(tb, kvVersion)
return client, closer
}

func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) {
tb.Helper()

Expand All @@ -85,21 +92,29 @@ func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) {
// testVaultServerUnseal creates a test vault cluster and returns a configured
// API client, list of unseal keys (as strings), and a closer function.
func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) {
return testVaultServerUnsealWithKVVersion(tb, "1")
}

func testVaultServerUnsealWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, []string, func()) {
tb.Helper()
logger := log.NewInterceptLogger(&log.LoggerOptions{
Output: log.DefaultOutput,
Level: log.Debug,
JSONFormat: logging.ParseEnvLogFormat() == logging.JSONFormat,
})

return testVaultServerCoreConfig(tb, &vault.CoreConfig{
return testVaultServerCoreConfigWithOpts(tb, &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logger,
CredentialBackends: defaultVaultCredentialBackends,
AuditBackends: defaultVaultAuditBackends,
LogicalBackends: defaultVaultLogicalBackends,
BuiltinRegistry: builtinplugins.Registry,
}, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
NumCores: 1,
KVVersion: kvVersion,
})
}

Expand All @@ -121,15 +136,19 @@ func testVaultServerPluginDir(tb testing.TB, pluginDir string) (*api.Client, []s
})
}

// testVaultServerCoreConfig creates a new vault cluster with the given core
// configuration. This is a lower-level test helper.
func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) {
tb.Helper()

cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, &vault.TestClusterOptions{
return testVaultServerCoreConfigWithOpts(tb, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
NumCores: 1, // Default is 3, but we don't need that many
})
}

// testVaultServerCoreConfig creates a new vault cluster with the given core
// configuration. This is a lower-level test helper.
func testVaultServerCoreConfigWithOpts(tb testing.TB, coreConfig *vault.CoreConfig, opts *vault.TestClusterOptions) (*api.Client, []string, func()) {
tb.Helper()

cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, opts)
cluster.Start()

// Make it easy to get access to the active
Expand Down
2 changes: 1 addition & 1 deletion command/kv_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *KVGetCommand) Flags() *FlagSets {
}

func (c *KVGetCommand) AutocompleteArgs() complete.Predictor {
return nil
return c.PredictVaultFiles()
}

func (c *KVGetCommand) AutocompleteFlags() complete.Flags {
Expand Down
2 changes: 1 addition & 1 deletion command/kv_patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (c *KVPatchCommand) Flags() *FlagSets {
}

func (c *KVPatchCommand) AutocompleteArgs() complete.Predictor {
return nil
return c.PredictVaultFiles()
}

func (c *KVPatchCommand) AutocompleteFlags() complete.Flags {
Expand Down
2 changes: 1 addition & 1 deletion command/kv_put.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *KVPutCommand) Flags() *FlagSets {
}

func (c *KVPutCommand) AutocompleteArgs() complete.Predictor {
return nil
return c.PredictVaultFolders()
}

func (c *KVPutCommand) AutocompleteFlags() complete.Flags {
Expand Down
8 changes: 7 additions & 1 deletion vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ type TestClusterOptions struct {
// this stores the vault version that should be used for each core config
VersionMap map[int]string
RedundancyZoneMap map[int]string
KVVersion string
}

var DefaultNumCores = 3
Expand Down Expand Up @@ -2030,6 +2031,11 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit

TestWaitActive(t, leader.Core)

kvVersion := "1"
if opts != nil {
kvVersion = opts.KVVersion
}

// Existing tests rely on this; we can make a toggle to disable it
// later if we want
kvReq := &logical.Request{
Expand All @@ -2041,7 +2047,7 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit
"path": "secret/",
"description": "key/value secret storage",
"options": map[string]string{
"version": "1",
"version": kvVersion,
},
},
}
Expand Down