Skip to content

Commit

Permalink
Make key completion work for both kv-v1 and kv-v2
Browse files Browse the repository at this point in the history
Co-authored-by: Kieron Browne <kbrowne@vmware.com>
Co-authored-by: Georgi Sabev <georgethebeatle@gmail.com>
Co-authored-by: Danail Branekov <danailster@gmail.com>
  • Loading branch information
3 people committed Aug 3, 2022
1 parent 637d4bd commit a916d75
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 21 deletions.
3 changes: 3 additions & 0 deletions changelog/16553.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
command: Fix shell completion for KV v2 mounts
```
53 changes: 44 additions & 9 deletions command/base_predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,27 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {

// Trim path with potential mount
var relativePath string
for _, mount := range p.mounts() {
mountInfos, err := p.mountInfos()
if err != nil {
return nil
}

var mountType, mountVersion string
for mount, mountInfo := range mountInfos {
if strings.HasPrefix(path, mount) {
relativePath = strings.TrimPrefix(path, mount+"/")
mountType = mountInfo.Type
if mountInfo.Options != nil {
mountVersion = mountInfo.Options["version"]
}
break
}
}

// Predict path or mount depending on path separator
var predictions []string
if strings.Contains(relativePath, "/") {
predictions = p.paths(path, includeFiles)
predictions = p.paths(mountType, mountVersion, path, includeFiles)
} else {
predictions = p.filter(p.mounts(), path)
}
Expand Down Expand Up @@ -288,7 +298,7 @@ func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
}

// paths predicts all paths which start with the given path.
func (p *Predict) paths(path string, includeFiles bool) []string {
func (p *Predict) paths(mountType, mountVersion, path string, includeFiles bool) []string {
client := p.Client()
if client == nil {
return nil
Expand All @@ -303,7 +313,7 @@ func (p *Predict) paths(path string, includeFiles bool) []string {
root = root[:idx+1]
}

paths := p.listPaths(root)
paths := p.listPaths(buildAPIListPath(root, mountType, mountVersion))

var predictions []string
for _, p := range paths {
Expand All @@ -326,6 +336,22 @@ func (p *Predict) paths(path string, includeFiles bool) []string {
return predictions
}

func buildAPIListPath(path, mountType, mountVersion string) string {
if mountType == "kv" && mountVersion == "2" {
return toKVv2ListPath(path)
}
return path
}

func toKVv2ListPath(path string) string {
firstSlashIdx := strings.Index(path, "/")
if firstSlashIdx < 0 {
return path
}

return path[:firstSlashIdx] + "/metadata" + path[firstSlashIdx:]
}

// audits returns a sorted list of the audit backends for Vault server for
// which the client is configured to communicate with.
func (p *Predict) audits() []string {
Expand Down Expand Up @@ -421,16 +447,25 @@ func (p *Predict) policies() []string {
return policies
}

// mounts returns a sorted list of the mount paths for Vault server for
// which the client is configured to communicate with. This function returns the
// default list of mounts if an error occurs.
func (p *Predict) mounts() []string {
func (p *Predict) mountInfos() (map[string]*api.MountOutput, error) {
client := p.Client()
if client == nil {
return nil
return nil, nil
}

mounts, err := client.Sys().ListMounts()
if err != nil {
return nil, err
}

return mounts, nil
}

// mounts returns a sorted list of the mount paths for Vault server for
// which the client is configured to communicate with. This function returns the
// default list of mounts if an error occurs.
func (p *Predict) mounts() []string {
mounts, err := p.mountInfos()
if err != nil {
return defaultPredictVaultMounts
}
Expand Down
75 changes: 74 additions & 1 deletion command/base_predict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,80 @@ func TestPredict_Paths(t *testing.T) {
p := NewPredict()
p.client = client

act := p.paths(tc.path, tc.includeFiles)
act := p.paths("kv", "1", tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
})
}
})
}

func TestPredict_PathsKVv2(t *testing.T) {
t.Parallel()

client, closer := testVaultServerWithKVVersion(t, "2")
defer closer()

data := map[string]interface{}{"data": map[string]interface{}{"a": "b"}}
if _, err := client.Logical().Write("secret/data/bar", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/data/foo", data); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("secret/data/zip/zap", data); err != nil {
t.Fatal(err)
}

cases := []struct {
name string
path string
includeFiles bool
exp []string
}{
{
"bad_path",
"nope/not/a/real/path/ever",
true,
[]string{"nope/not/a/real/path/ever"},
},
{
"good_path",
"secret/",
true,
[]string{"secret/bar", "secret/foo", "secret/zip/"},
},
{
"good_path_no_files",
"secret/",
false,
[]string{"secret/zip/"},
},
{
"partial_match",
"secret/z",
true,
[]string{"secret/zip/"},
},
{
"partial_match_no_files",
"secret/z",
false,
[]string{"secret/zip/"},
},
}

t.Run("group", func(t *testing.T) {
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

p := NewPredict()
p.client = client

act := p.paths("kv", "2", tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
Expand Down
31 changes: 24 additions & 7 deletions command/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ var (
"transit": transit.Factory,
"kv": kv.Factory,
}

defaultTestClusterOptions = &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
NumCores: 1, // Default is 3, but we don't need that many
}
)

// assertNoTabs asserts the CLI help has no tab characters.
Expand All @@ -67,6 +72,13 @@ func testVaultServer(tb testing.TB) (*api.Client, func()) {
return client, closer
}

func testVaultServerWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, func()) {
tb.Helper()

client, _, closer := testVaultServerUnsealWithKVVersion(tb, kvVersion)
return client, closer
}

func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) {
tb.Helper()

Expand All @@ -78,13 +90,17 @@ func testVaultServerAllBackends(tb testing.TB) (*api.Client, func()) {
AuditBackends: auditBackends,
LogicalBackends: logicalBackends,
BuiltinRegistry: builtinplugins.Registry,
})
}, defaultTestClusterOptions)
return client, closer
}

// testVaultServerUnseal creates a test vault cluster and returns a configured
// API client, list of unseal keys (as strings), and a closer function.
func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) {
return testVaultServerUnsealWithKVVersion(tb, "1")
}

func testVaultServerUnsealWithKVVersion(tb testing.TB, kvVersion string) (*api.Client, []string, func()) {
tb.Helper()
logger := log.NewInterceptLogger(&log.LoggerOptions{
Output: log.DefaultOutput,
Expand All @@ -100,6 +116,10 @@ func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) {
AuditBackends: defaultVaultAuditBackends,
LogicalBackends: defaultVaultLogicalBackends,
BuiltinRegistry: builtinplugins.Registry,
}, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
NumCores: 1,
KVVersion: kvVersion,
})
}

Expand All @@ -118,18 +138,15 @@ func testVaultServerPluginDir(tb testing.TB, pluginDir string) (*api.Client, []s
LogicalBackends: defaultVaultLogicalBackends,
PluginDirectory: pluginDir,
BuiltinRegistry: builtinplugins.Registry,
})
}, defaultTestClusterOptions)
}

// testVaultServerCoreConfig creates a new vault cluster with the given core
// configuration. This is a lower-level test helper.
func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) {
func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig, opts *vault.TestClusterOptions) (*api.Client, []string, func()) {
tb.Helper()

cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
NumCores: 1, // Default is 3, but we don't need that many
})
cluster := vault.NewTestCluster(benchhelpers.TBtoT(tb), coreConfig, opts)
cluster.Start()

// Make it easy to get access to the active
Expand Down
2 changes: 1 addition & 1 deletion command/kv_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *KVGetCommand) Flags() *FlagSets {
}

func (c *KVGetCommand) AutocompleteArgs() complete.Predictor {
return nil
return c.PredictVaultFiles()
}

func (c *KVGetCommand) AutocompleteFlags() complete.Flags {
Expand Down
2 changes: 1 addition & 1 deletion command/kv_patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (c *KVPatchCommand) Flags() *FlagSets {
}

func (c *KVPatchCommand) AutocompleteArgs() complete.Predictor {
return nil
return c.PredictVaultFiles()
}

func (c *KVPatchCommand) AutocompleteFlags() complete.Flags {
Expand Down
2 changes: 1 addition & 1 deletion command/kv_put.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *KVPutCommand) Flags() *FlagSets {
}

func (c *KVPutCommand) AutocompleteArgs() complete.Predictor {
return nil
return c.PredictVaultFolders()
}

func (c *KVPutCommand) AutocompleteFlags() complete.Flags {
Expand Down
3 changes: 2 additions & 1 deletion vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ type TestClusterOptions struct {
// this stores the vault version that should be used for each core config
VersionMap map[int]string
RedundancyZoneMap map[int]string
KVVersion string
}

var DefaultNumCores = 3
Expand Down Expand Up @@ -2041,7 +2042,7 @@ func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAudit
"path": "secret/",
"description": "key/value secret storage",
"options": map[string]string{
"version": "1",
"version": opts.KVVersion,
},
},
}
Expand Down

0 comments on commit a916d75

Please sign in to comment.