diff --git a/command.go b/command.go index cdfd360f8..a8b32c6d2 100644 --- a/command.go +++ b/command.go @@ -30,6 +30,8 @@ import ( flag "github.com/spf13/pflag" ) +const FlagSetByCobraAnnotation = "cobra_annotation_flag_set_by_cobra" + // FParseErrWhitelist configures Flag parse errors to be ignored type FParseErrWhitelist flag.ParseErrorsWhitelist @@ -1055,6 +1057,7 @@ func (c *Command) InitDefaultHelpFlag() { usage += c.Name() } c.Flags().BoolP("help", "h", false, usage) + c.Flags().SetAnnotation("help", FlagSetByCobraAnnotation, []string{"true"}) } } @@ -1080,6 +1083,7 @@ func (c *Command) InitDefaultVersionFlag() { } else { c.Flags().Bool("version", false, usage) } + c.Flags().SetAnnotation("version", FlagSetByCobraAnnotation, []string{"true"}) } } diff --git a/completions.go b/completions.go index f89e1700c..1ed6fbe93 100644 --- a/completions.go +++ b/completions.go @@ -407,6 +407,12 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed { foundLocalNonPersistentFlag = true + if (flag.Name == "help" || flag.Name == "version") && + len(flag.Annotations[FlagSetByCobraAnnotation]) > 0 { + // We have the 'help' or 'version' flag and it was set + // by Cobra. We know neither should be followed by anything. + directive = ShellCompDirectiveNoFileComp + } } }) } diff --git a/completions_test.go b/completions_test.go index aa4657e94..b157867b0 100644 --- a/completions_test.go +++ b/completions_test.go @@ -2923,3 +2923,93 @@ func TestCompletionForMutuallyExclusiveFlags(t *testing.T) { }) } } + +func TestCompletionCobraFlags(t *testing.T) { + getCmd := func() *Command { + rootCmd := &Command{ + Use: "root", + Version: "1.1.1", + Run: emptyRun, + } + childCmd := &Command{ + Use: "child", + Run: emptyRun, + } + rootCmd.AddCommand(childCmd) + + return rootCmd + } + + // Each test case uses a unique command from the function above. + testcases := []struct { + desc string + args []string + expectedOutput string + }{ + { + desc: "completion of --help flag", + args: []string{"-"}, + expectedOutput: strings.Join([]string{ + "--help", + "-h", + "--version", + "-v", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "completion of --help flag value", + args: []string{"--help", ""}, + expectedOutput: strings.Join([]string{ + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "completion of -h flag value", + args: []string{"-h", ""}, + expectedOutput: strings.Join([]string{ + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "completion of --version flag", + args: []string{"-"}, + expectedOutput: strings.Join([]string{ + "--help", + "-h", + "--version", + "-v", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "completion of --version flag value", + args: []string{"--version", ""}, + expectedOutput: strings.Join([]string{ + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "completion of -v flag value", + args: []string{"-v", ""}, + expectedOutput: strings.Join([]string{ + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + c := getCmd() + args := []string{ShellCompNoDescRequestCmd} + args = append(args, tc.args...) + output, err := executeCommand(c, args...) + switch { + case err == nil && output != tc.expectedOutput: + t.Errorf("expected: %q, got: %q", tc.expectedOutput, output) + case err != nil: + t.Errorf("Unexpected error %q", err) + } + }) + } +}