Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let the commands store flagComp functions internally (and avoid global state) #2012

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
8 changes: 5 additions & 3 deletions bash_completions.go
Expand Up @@ -534,9 +534,10 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) {

// prepareCustomAnnotationsForFlags setup annotations for go completions for registered flags
func prepareCustomAnnotationsForFlags(cmd *Command) {
flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock()
for flag := range flagCompletionFunctions {
cmd.initializeCompletionStorage()
cmd.flagCompletionMutex.RLock()
defer cmd.flagCompletionMutex.RUnlock()
for flag := range cmd.flagCompletionFunctions {
// Make sure the completion script calls the __*_go_custom_completion function for
// every registered flag. We need to do this here (and not when the flag was registered
// for completion) so that we can know the root command name for the prefix
Expand Down Expand Up @@ -644,6 +645,7 @@ func writeCmdAliases(buf io.StringWriter, cmd *Command) {
WriteStringAndCheck(buf, ` fi`)
WriteStringAndCheck(buf, "\n")
}

func writeArgAliases(buf io.StringWriter, cmd *Command) {
WriteStringAndCheck(buf, " noun_aliases=()\n")
sort.Strings(cmd.ArgAliases)
Expand Down
8 changes: 8 additions & 0 deletions command.go
Expand Up @@ -26,6 +26,7 @@ import (
"path/filepath"
"sort"
"strings"
"sync"

flag "github.com/spf13/pflag"
)
Expand Down Expand Up @@ -163,6 +164,13 @@ type Command struct {
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName

// flagsCompletions contrains completions for arbitrary lists of flags.
// Those flags may or may not actually strictly belong to the command in the function,
// but registering completions for them through the command allows for garbage-collecting.
flagCompletionFunctions map[*flag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
maxlandon marked this conversation as resolved.
Show resolved Hide resolved
// lock for reading and writing from flagCompletionFunctions
flagCompletionMutex *sync.RWMutex
maxlandon marked this conversation as resolved.
Show resolved Hide resolved

// usageFunc is usage func defined by user.
usageFunc func(*Command) error
// usageTemplate is usage template defined by user.
Expand Down
86 changes: 63 additions & 23 deletions completions.go
Expand Up @@ -32,12 +32,6 @@ const (
ShellCompNoDescRequestCmd = "__completeNoDesc"
)

// Global map of flag completion functions. Make sure to use flagCompletionMutex before you try to read and write from it.
var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){}

// lock for reading and writing from flagCompletionFunctions
var flagCompletionMutex = &sync.RWMutex{}

// ShellCompDirective is a bit map representing the different behaviors the shell
// can be instructed to have once completions have been provided.
type ShellCompDirective int
Expand Down Expand Up @@ -135,28 +129,77 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
if flag == nil {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName)
}
flagCompletionMutex.Lock()
defer flagCompletionMutex.Unlock()
// Ensure none of our relevant fields are nil.
c.initializeCompletionStorage()
maxlandon marked this conversation as resolved.
Show resolved Hide resolved

c.flagCompletionMutex.Lock()
defer c.flagCompletionMutex.Unlock()

if _, exists := flagCompletionFunctions[flag]; exists {
// And attempt to bind the completion.
if _, exists := c.flagCompletionFunctions[flag]; exists {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName)
}
flagCompletionFunctions[flag] = f
c.flagCompletionFunctions[flag] = f
return nil
}

// GetFlagCompletionFunc returns the completion function for the given flag of the command, if available.
func (c *Command) GetFlagCompletionFunc(flagName string) (func(*Command, []string, string) ([]string, ShellCompDirective), bool) {
flag := c.Flag(flagName)
if flag == nil {
// GetFlagCompletion returns the completion function for the given flag, if available.
func (c *Command) GetFlagCompletionFunc(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
c.initializeCompletionStorage()
maxlandon marked this conversation as resolved.
Show resolved Hide resolved

c.flagCompletionMutex.RLock()
defer c.flagCompletionMutex.RUnlock()

completionFunc, exists := c.flagCompletionFunctions[flag]

// If found it here, return now
if completionFunc != nil && exists {
return completionFunc, exists
}

// If we are already at the root command level, return anyway
if !c.HasParent() {
return nil, false
}

flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock()
// Or walk up the command tree.
return c.Parent().GetFlagCompletionFunc(flag)
}

completionFunc, exists := flagCompletionFunctions[flag]
return completionFunc, exists
// GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available.
// If the flag is not found in the command's local flags, it looks into the persistent flags, which might belong to one of the command's parents.
func (c *Command) GetFlagCompletionFuncByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
// Attempt to find it in the local flags.
if flag := c.Flags().Lookup(flagName); flag != nil {
return c.GetFlagCompletionFunc(flag)
}

// Or try to find it in the "command-specific" persistent flags.
if flag := c.PersistentFlags().Lookup(flagName); flag != nil {
return c.GetFlagCompletionFunc(flag)
}

// Else, check all persistent flags belonging to one of the parents.
// This ensures that we won't return the completion function of a
// parent's LOCAL flag.
if flag := c.InheritedFlags().Lookup(flagName); flag != nil {
return c.GetFlagCompletionFunc(flag)
}

// No flag exists either locally, or as one of the parent persistent flags.
return nil, false
}

// initializeCompletionStorage is (and should be) called in all
// functions that make use of the command's flag completion functions.
func (c *Command) initializeCompletionStorage() {
maxlandon marked this conversation as resolved.
Show resolved Hide resolved
if c.flagCompletionMutex == nil {
c.flagCompletionMutex = new(sync.RWMutex)
}

if c.flagCompletionFunctions == nil {
c.flagCompletionFunctions = make(map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), 0)
}
}

// Returns a string listing the different directive enabled in the specified parameter
Expand Down Expand Up @@ -502,9 +545,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
// Find the completion function for the flag or command
var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
if flag != nil && flagCompletion {
flagCompletionMutex.RLock()
completionFn = flagCompletionFunctions[flag]
flagCompletionMutex.RUnlock()
completionFn, _ = finalCmd.GetFlagCompletionFunc(flag)
} else {
completionFn = finalCmd.ValidArgsFunction
}
Expand Down Expand Up @@ -828,7 +869,6 @@ to your powershell profile.
return cmd.Root().GenPowerShellCompletion(out)
}
return cmd.Root().GenPowerShellCompletionWithDesc(out)

},
}
if haveNoDescFlag {
Expand Down Expand Up @@ -868,7 +908,7 @@ func CompDebug(msg string, printToStdErr bool) {
// variable BASH_COMP_DEBUG_FILE to the path of some file to be used.
if path := os.Getenv("BASH_COMP_DEBUG_FILE"); path != "" {
f, err := os.OpenFile(path,
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err == nil {
defer f.Close()
WriteStringAndCheck(f, msg)
Expand Down