Skip to content

Commit

Permalink
feat: add markflagsdependenton (spf13#1739)
Browse files Browse the repository at this point in the history
  • Loading branch information
plastikfan committed Jul 27, 2022
1 parent 06b06a9 commit 3d2185c
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 8 deletions.
223 changes: 223 additions & 0 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
dependsOn = "cobra_annotation_depends_on"
dependsOnAny = "cobra_annotation_depends_on_any"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -58,6 +60,76 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

// MarkFlagsDependsOn marks the given flags with annotations so that Cobra errors
// if the command is invoked with 1 or more flags that are dependent on a specified
// other.
func (c *Command) MarkFlagsDependsOn(flagNames ...string) {
const format = "Failed to find flag %q and mark it as being part of depends on group"
c.markAnnotation(dependsOn, format, flagNames...)
}

// MarkFlagDependsOnAny marks the given flags with annotations so that Cobra errors
// if the command is invoked with a flag that is dependent on any 1 of a group of others.
func (c *Command) MarkFlagDependsOnAny(flagNames ...string) {
const format = "Failed to find flag %q and mark it as being part of depends on any group"
c.markAnnotation(dependsOnAny, format, flagNames...)
}

// markAnnotation currently only used by MarkFlagsDependsOn and MarkFlagDependsOnAny,
// but is generic enough and should be used by MarkFlagsRequiredTogether and
// MarkFlagsMutuallyExclusive.
// - format must contain a single place holder
func (c *Command) markAnnotation(annotation, format string, flagNames ...string) {
c.mergePersistentFlags()
for _, name := range flagNames {
c.setFlagAnnotation(name, annotation,
fmt.Sprintf(format, name),
flagNames...,
)
}
}

func (c *Command) setFlagAnnotation(flag string, annotation string, message string, flagNames ...string) {
f := c.Flags().Lookup(flag)
if f == nil {
panic(message)
}
ordered := strings.Join(flagNames, " ")
if err := c.Flags().SetAnnotation(
flag, annotation,
append(f.Annotations[annotation], ordered),
); err != nil {
panic(err)
}
}

// The 'special-ness' of a group means that the first member of the group carries
// special meaning. In contrast to the other group types, where all members are equal.
type specialStatusInfo struct {
isSet bool
isSpecial bool
}
type specialStatusInfoData map[string]*specialStatusInfo

type specialGroupInfo struct {
special string
others []string
// maps the flag name to special status info
data specialStatusInfoData
}
type specialGroupInfoCollection map[string]*specialGroupInfo

func newSpecialGroup(specialName string, others []string) *specialGroupInfo {
size := len(others) + 1
result := specialGroupInfo{
special: specialName,
others: others,
data: make(specialStatusInfoData, size),
}

return &result
}

// validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) validateFlagGroups() error {
Expand All @@ -71,9 +143,13 @@ func (c *Command) validateFlagGroups() error {
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
dependsOnSpecialGroupStatus := specialGroupInfoCollection{}
dependsOnAnySpecialGroupStatus := specialGroupInfoCollection{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForSpecialGroupAnnotation(flags, pflag, dependsOn, dependsOnSpecialGroupStatus)
processFlagForSpecialGroupAnnotation(flags, pflag, dependsOnAny, dependsOnAnySpecialGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
Expand All @@ -82,6 +158,12 @@ func (c *Command) validateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if err := validateDependsOnFlagGroups(dependsOnSpecialGroupStatus); err != nil {
return err
}
if err := validateDependsOnAnyFlagGroups(dependsOnAnySpecialGroupStatus); err != nil {
return err
}
return nil
}

Expand All @@ -95,6 +177,16 @@ func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
return true
}

func hasAnyOfFlags(fs *flag.FlagSet, flagnames ...string) bool {
for _, fname := range flagnames {
f := fs.Lookup(fname)
if f != nil {
return true
}
}
return false
}

func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
groupInfo, found := pflag.Annotations[annotation]
if found {
Expand All @@ -118,6 +210,52 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota
}
}

func processFlagForSpecialGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag,
annotation string, groupStatus specialGroupInfoCollection) {

if groupInfo, found := pflag.Annotations[annotation]; found {
for _, group := range groupInfo {
if groupStatus[group] == nil {

flagnames := strings.Split(group, " ")
// it's important to know that the order of the flags is established
// in setFlagAnnotation, which makes the assumption of the first
// item being sepcial, being valid
special := flagnames[0]
others := flagnames[1:]
isFlagSpecial := pflag.Name == special

// Only consider this flag group at all if the first flag (Special)
// is set and at least 1 of the others is
if isFlagSpecial && flags.Lookup(special) == nil {
continue
}

if !isFlagSpecial && !hasAnyOfFlags(flags, others...) {
continue
}

groupStatus[group] = newSpecialGroup(special, others)
for _, name := range flagnames {
groupStatus[group].data[name] = &specialStatusInfo{}

if name == special {
groupStatus[group].data[special].isSpecial = true
break // short circuit after finding special
}
}
}

// group exists, but we still need to check if the flag exists in the group,
// because the previous loop is short circuited as soon as we find the special.
if _, found := groupStatus[group].data[pflag.Name]; !found {
groupStatus[group].data[pflag.Name] = &specialStatusInfo{}
}
groupStatus[group].data[pflag.Name].isSet = pflag.Changed
}
}
}

func validateRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
Expand Down Expand Up @@ -162,6 +300,66 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func validateDependsOnFlagGroups(data specialGroupInfoCollection) error {
keys := sortedKeysSpecial(data)

for _, flagList := range keys {
flagnameAndStatus := data[flagList]

if flagnameAndStatus.data[flagnameAndStatus.special].isSet {
// rule is satisfied, because the special flag is present, regardless of
// the presence of the other members in the group
return nil
}

// we have a problem if at least one of present is set, because special
// is not set
present := []string{}
for _, o := range flagnameAndStatus.others {
if flagnameAndStatus.data[o].isSet {
present = append(present, o)
}
}
if len(present) == 0 {
continue
}
sort.Strings(present)

return fmt.Errorf(
"if any flags in the group %v are set then [%v] must be present; only %v were set",
flagnameAndStatus.others, flagnameAndStatus.special, present,
)
}
return nil
}

func validateDependsOnAnyFlagGroups(data specialGroupInfoCollection) error {
keys := sortedKeysSpecial(data)

for _, flagList := range keys {
flagnameAndStatus := data[flagList]
if !flagnameAndStatus.data[flagnameAndStatus.special].isSet {
return nil
}

present := []string{}
for _, o := range flagnameAndStatus.others {
if flagnameAndStatus.data[o].isSet {
present = append(present, o)
}
}
if len(present) > 0 {
continue
}

return fmt.Errorf(
"if [%v] is present, then at least one of the flags in %v must be; none were set",
flagnameAndStatus.special, flagnameAndStatus.others,
)
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand All @@ -173,6 +371,18 @@ func sortedKeys(m map[string]map[string]bool) []string {
return keys
}

// implemented as a duplicate of sortedKeys as generics can't be used yet
func sortedKeysSpecial(m specialGroupInfoCollection) []string {
keys := make([]string, len(m))
i := 0
for k := range m {
keys[i] = k
i++
}
sort.Strings(keys)
return keys
}

// enforceFlagGroupsForCompletion will do the following:
// - when a flag in a group is present, other flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
Expand All @@ -185,9 +395,11 @@ func (c *Command) enforceFlagGroupsForCompletion() {
flags := c.Flags()
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
dependsOnSpecialGroupStatus := specialGroupInfoCollection{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForSpecialGroupAnnotation(flags, pflag, dependsOn, dependsOnSpecialGroupStatus)
})

// If a flag that is part of a group is present, we make all the other flags
Expand Down Expand Up @@ -220,4 +432,15 @@ func (c *Command) enforceFlagGroupsForCompletion() {
}
}
}

// if any of others is set, then mark special as required
for _, flagnameAndStatus := range dependsOnSpecialGroupStatus {
for _, o := range flagnameAndStatus.others {
if flagnameAndStatus.data[o].isSet {
c.MarkFlagRequired(flagnameAndStatus.special)
break
}
}
}
// we can't aid the completion process for dependsOnAny
}

0 comments on commit 3d2185c

Please sign in to comment.