diff --git a/bash_completions.go b/bash_completions.go index ab428ccb8..2ac9d6cc0 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -495,12 +495,15 @@ func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) { func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) { name := flag.Name - format := " local_nonpersistent_flags+=(\"--%s" + format := " local_nonpersistent_flags+=(\"--%[1]s" if len(flag.NoOptDefVal) == 0 { - format += "=" + format += "\")\n local_nonpersistent_flags+=(\"--%[1]s=" } format += "\")\n" buf.WriteString(fmt.Sprintf(format, name)) + if len(flag.Shorthand) > 0 { + buf.WriteString(fmt.Sprintf(" local_nonpersistent_flags+=(\"-%s\")\n", flag.Shorthand)) + } } // Setup annotations for go completions for registered flags @@ -535,7 +538,9 @@ func writeFlags(buf *bytes.Buffer, cmd *Command) { if len(flag.Shorthand) > 0 { writeShortFlag(buf, flag, cmd) } - if localNonPersistentFlags.Lookup(flag.Name) != nil { + // localNonPersistentFlags are used to stop the completion of subcommands when one is set + // if TraverseChildren is true we should allow to complete subcommands + if localNonPersistentFlags.Lookup(flag.Name) != nil && !cmd.Root().TraverseChildren { writeLocalNonPersistentFlag(buf, flag) } }) diff --git a/bash_completions_test.go b/bash_completions_test.go index eefa3de07..2c182ba73 100644 --- a/bash_completions_test.go +++ b/bash_completions_test.go @@ -193,6 +193,13 @@ func TestBashCompletions(t *testing.T) { checkOmit(t, output, `two_word_flags+=("--two-w-default")`) checkOmit(t, output, `two_word_flags+=("-T")`) + // check local nonpersistent flag + check(t, output, `local_nonpersistent_flags+=("--two")`) + check(t, output, `local_nonpersistent_flags+=("--two=")`) + check(t, output, `local_nonpersistent_flags+=("-t")`) + check(t, output, `local_nonpersistent_flags+=("--two-w-default")`) + check(t, output, `local_nonpersistent_flags+=("-T")`) + checkOmit(t, output, deprecatedCmd.Name()) // If available, run shellcheck against the script. @@ -235,3 +242,21 @@ func TestBashCompletionDeprecatedFlag(t *testing.T) { t.Errorf("expected completion to not include %q flag: Got %v", flagName, output) } } + +func TestBashCompletionTraverseChildren(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun, TraverseChildren: true} + + c.Flags().StringP("string-flag", "s", "", "string flag") + c.Flags().BoolP("bool-flag", "b", false, "bool flag") + + buf := new(bytes.Buffer) + c.GenBashCompletion(buf) + output := buf.String() + + // check that local nonpersistent flag are not set since we have TraverseChildren set to true + checkOmit(t, output, `local_nonpersistent_flags+=("--string-flag")`) + checkOmit(t, output, `local_nonpersistent_flags+=("--string-flag=")`) + checkOmit(t, output, `local_nonpersistent_flags+=("-s")`) + checkOmit(t, output, `local_nonpersistent_flags+=("--bool-flag")`) + checkOmit(t, output, `local_nonpersistent_flags+=("-b")`) +}