From 3fa3070673ee97ffc6a55fe7fa47e07f079fb542 Mon Sep 17 00:00:00 2001 From: Paul Holzinger Date: Fri, 2 Jul 2021 11:04:49 +0200 Subject: [PATCH] Fix flag completion The flag completion functions should not be stored in the root cmd. There is no requirement that the root cmd should be the same when `RegisterFlagCompletionFunc` was called. Storing the flags there does not work when you add the the flags to your cmd struct before you add the cmd to the parent/root cmd. The flags can no longer be found in the rigth place when the completion command is called and thus the flag completion does not work. Also #1423 claims that this would be thread safe but we still have a map which will fail when accessed concurrently. To truly fix this issue use a RWMutex. Fixes #1437 Fixes #1320 Signed-off-by: Paul Holzinger --- bash_completions.go | 4 +++- command.go | 3 --- completions.go | 22 ++++++++++++++++------ completions_test.go | 19 ++++++++++++++++++- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/bash_completions.go b/bash_completions.go index 925e6e787..733f4d121 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -512,7 +512,9 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) { // Setup annotations for go completions for registered flags func prepareCustomAnnotationsForFlags(cmd *Command) { - for flag := range cmd.Root().flagCompletionFunctions { + flagCompletionMutex.RLock() + defer flagCompletionMutex.RUnlock() + for flag := range flagCompletionFunctions { // Make sure the completion script calls the __*_go_custom_completion function for // every registered flag. We need to do this here (and not when the flag was registered // for completion) so that we can know the root command name for the prefix diff --git a/command.go b/command.go index 9f33a461f..2cc18891d 100644 --- a/command.go +++ b/command.go @@ -142,9 +142,6 @@ type Command struct { // that we can use on every pflag set and children commands globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName - //flagCompletionFunctions is map of flag completion functions. - flagCompletionFunctions map[*flag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) - // usageFunc is usage func defined by user. usageFunc func(*Command) error // usageTemplate is usage template defined by user. diff --git a/completions.go b/completions.go index 4687674aa..a1c12c01e 100644 --- a/completions.go +++ b/completions.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/spf13/pflag" ) @@ -17,6 +18,12 @@ const ( ShellCompNoDescRequestCmd = "__completeNoDesc" ) +// Global map of flag completion functions. Make sure to use flagCompletionMutex before you try to read and write from it. +var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} + +// lock for reading and writing from flagCompletionFunctions +var flagCompletionMutex = &sync.RWMutex{} + // ShellCompDirective is a bit map representing the different behaviors the shell // can be instructed to have once completions have been provided. type ShellCompDirective int @@ -100,15 +107,16 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman if flag == nil { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName) } + flagCompletionMutex.Lock() + defer flagCompletionMutex.Unlock() - root := c.Root() - if _, exists := root.flagCompletionFunctions[flag]; exists { + if _, exists := flagCompletionFunctions[flag]; exists { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName) } - if root.flagCompletionFunctions == nil { - root.flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} + if flagCompletionFunctions == nil { + flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} } - root.flagCompletionFunctions[flag] = f + flagCompletionFunctions[flag] = f return nil } @@ -402,7 +410,9 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // Find the completion function for the flag or command var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) if flag != nil && flagCompletion { - completionFn = c.Root().flagCompletionFunctions[flag] + flagCompletionMutex.RLock() + completionFn = flagCompletionFunctions[flag] + flagCompletionMutex.RUnlock() } else { completionFn = finalCmd.ValidArgsFunction } diff --git a/completions_test.go b/completions_test.go index aea06a241..9d8b073b5 100644 --- a/completions_test.go +++ b/completions_test.go @@ -1763,13 +1763,15 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) { Run: emptyRun, ValidArgs: []string{"arg1", "arg2"}, } - rootCmd.AddCommand(childCmd, childCmd2) childCmd.Flags().Bool("bool", false, "test bool flag") childCmd.Flags().String("string", "", "test string flag") _ = childCmd.RegisterFlagCompletionFunc("string", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { return []string{"myval"}, ShellCompDirectiveDefault }) + // Important: only add the commands after RegisterFlagCompletionFunc was called + rootCmd.AddCommand(childCmd, childCmd2) + // Test flag completion with no argument output, err := executeCommand(rootCmd, ShellCompRequestCmd, "child", "--") if err != nil { @@ -1969,6 +1971,21 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) { if output != expected { t.Errorf("expected: %q, got: %q", expected, output) } + + // Test that no flag completion works on a subcmd + output, err = executeCommand(rootCmd, ShellCompRequestCmd, "child", "--string", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected = strings.Join([]string{ + "myval", + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } } func TestFlagCompletionInGoWithDesc(t *testing.T) {