Skip to content

Commit

Permalink
Disable completion after the 'help'/'version' flag
Browse files Browse the repository at this point in the history
If the 'help' and/or 'version' flag are the ones added by Cobra, we
know they should not be followed by anything, we therefore turn off
file completion.

If a program sets its own 'help' or 'version' flag, it can disable
file completion following those flags by using ValidArgsFunction on
the command on which those flags apply.

Signed-off-by: Marc Khouzam <marc.khouzam@gmail.com>
  • Loading branch information
marckhouzam committed Sep 21, 2022
1 parent bfa0766 commit 718c5bc
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
4 changes: 4 additions & 0 deletions command.go
Expand Up @@ -30,6 +30,8 @@ import (
flag "github.com/spf13/pflag"
)

const FlagSetByCobraAnnotation = "cobra_annotation_flag_set_by_cobra"

// FParseErrWhitelist configures Flag parse errors to be ignored
type FParseErrWhitelist flag.ParseErrorsWhitelist

Expand Down Expand Up @@ -1055,6 +1057,7 @@ func (c *Command) InitDefaultHelpFlag() {
usage += c.Name()
}
c.Flags().BoolP("help", "h", false, usage)
c.Flags().SetAnnotation("help", FlagSetByCobraAnnotation, []string{"true"})
}
}

Expand All @@ -1080,6 +1083,7 @@ func (c *Command) InitDefaultVersionFlag() {
} else {
c.Flags().Bool("version", false, usage)
}
c.Flags().SetAnnotation("version", FlagSetByCobraAnnotation, []string{"true"})
}
}

Expand Down
6 changes: 6 additions & 0 deletions completions.go
Expand Up @@ -407,6 +407,12 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed {
foundLocalNonPersistentFlag = true
if (flag.Name == "help" || flag.Name == "version") &&
len(flag.Annotations[FlagSetByCobraAnnotation]) > 0 {
// We have the 'help' or 'version' flag and it was set
// by Cobra. We know neither should be followed by anything.
directive = ShellCompDirectiveNoFileComp
}
}
})
}
Expand Down
90 changes: 90 additions & 0 deletions completions_test.go
Expand Up @@ -2923,3 +2923,93 @@ func TestCompletionForMutuallyExclusiveFlags(t *testing.T) {
})
}
}

func TestCompletionCobraFlags(t *testing.T) {
getCmd := func() *Command {
rootCmd := &Command{
Use: "root",
Version: "1.1.1",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
Run: emptyRun,
}
rootCmd.AddCommand(childCmd)

return rootCmd
}

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
args []string
expectedOutput string
}{
{
desc: "completion of --help flag",
args: []string{"-"},
expectedOutput: strings.Join([]string{
"--help",
"-h",
"--version",
"-v",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
},
{
desc: "completion of --help flag value",
args: []string{"--help", ""},
expectedOutput: strings.Join([]string{
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
},
{
desc: "completion of -h flag value",
args: []string{"-h", ""},
expectedOutput: strings.Join([]string{
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
},
{
desc: "completion of --version flag",
args: []string{"-"},
expectedOutput: strings.Join([]string{
"--help",
"-h",
"--version",
"-v",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
},
{
desc: "completion of --version flag value",
args: []string{"--version", ""},
expectedOutput: strings.Join([]string{
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
},
{
desc: "completion of -v flag value",
args: []string{"-v", ""},
expectedOutput: strings.Join([]string{
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
},
}

for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
c := getCmd()
args := []string{ShellCompNoDescRequestCmd}
args = append(args, tc.args...)
output, err := executeCommand(c, args...)
switch {
case err == nil && output != tc.expectedOutput:
t.Errorf("expected: %q, got: %q", tc.expectedOutput, output)
case err != nil:
t.Errorf("Unexpected error %q", err)
}
})
}
}

0 comments on commit 718c5bc

Please sign in to comment.