From 4ec9d9388ebf5c335bbb4ae7157d3e7244c7c0d3 Mon Sep 17 00:00:00 2001 From: Baruch Odem Date: Tue, 6 Jun 2023 11:41:01 +0300 Subject: [PATCH] feat: add MarkFlagsMutuallyExclusiveAndRequired Fixes #1216 --- flag_groups.go | 51 +++++++++++++++++++++++++++++++++++++++++++-- flag_groups_test.go | 28 ++++++++++++++++++------- 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index b35fde155..ae49a6d06 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,8 +23,9 @@ import ( ) const ( - requiredAsGroup = "cobra_annotation_required_if_others_set" - mutuallyExclusive = "cobra_annotation_mutually_exclusive" + requiredAsGroup = "cobra_annotation_required_if_others_set" + mutuallyExclusive = "cobra_annotation_mutually_exclusive" + mutuallyExclusiveAndRequired = "cobra_annotation_mutually_exclusive_and_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -59,6 +60,22 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { } } +// MarkFlagsMutuallyExclusiveAndRequired marks the given flags with annotations so that Cobra errors +// if the command is invoked without exactly one flag from the given set of flags. +func (c *Command) MarkFlagsMutuallyExclusiveAndRequired(flagNames ...string) { + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAndRequired, append(f.Annotations[mutuallyExclusiveAndRequired], strings.Join(flagNames, " "))); err != nil { + panic(err) + } + } +} + // ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the // first error encountered. func (c *Command) ValidateFlagGroups() error { @@ -72,9 +89,11 @@ func (c *Command) ValidateFlagGroups() error { // then a map of each flag name and whether it is set or not. groupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + mutuallyExclusiveAndRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAndRequired, mutuallyExclusiveAndRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -83,6 +102,9 @@ func (c *Command) ValidateFlagGroups() error { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } + if err := validateExclusiveAndRequiredFlagGroups(mutuallyExclusiveAndRequiredGroupStatus); err != nil { + return err + } return nil } @@ -163,6 +185,31 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { return nil } +func validateExclusiveAndRequiredFlagGroups(data map[string]map[string]bool) error { + keys := sortedKeys(data) + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + var set []string + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + + if len(set) == 0 { + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(set) + return fmt.Errorf("exactly one of the flags in the group [%v] must be set; none were set", flagList) + } + if len(set) > 1 { + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(set) + return fmt.Errorf("exactly one of the flags in the group [%v] must be set; %v were all set", flagList, set) + } + } + return nil +} + func sortedKeys(m map[string]map[string]bool) []string { keys := make([]string, len(m)) i := 0 diff --git a/flag_groups_test.go b/flag_groups_test.go index bf988d734..147f0a3b1 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -43,13 +43,14 @@ func TestValidateFlagGroups(t *testing.T) { // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsExclusive []string - subCmdFlagGroupsRequired []string - subCmdFlagGroupsExclusive []string - args []string - expectErr string + desc string + flagGroupsRequired []string + flagGroupsExclusive []string + flagGroupsExclusiveRequires []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsExclusive []string + args []string + expectErr string }{ { desc: "No flags no problem", @@ -67,6 +68,16 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a b c"}, args: []string{"--a=foo", "--b=foo"}, expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + }, { + desc: "Required exclusive group not satisfied", + flagGroupsExclusiveRequires: []string{"a b c"}, + args: []string{"--d=foo"}, + expectErr: "exactly one of the flags in the group [a b c] must be set; none were set", + }, { + desc: "Required exclusive group selected more than one", + flagGroupsExclusiveRequires: []string{"a b c"}, + args: []string{"--a=foo", "--b=foo"}, + expectErr: "exactly one of the flags in the group [a b c] must be set; [a b] were all set", }, { desc: "Multiple required flag group not satisfied returns first error", flagGroupsRequired: []string{"a b c", "a d"}, @@ -133,6 +144,9 @@ func TestValidateFlagGroups(t *testing.T) { for _, flagGroup := range tc.flagGroupsExclusive { c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.flagGroupsExclusiveRequires { + c.MarkFlagsMutuallyExclusiveAndRequired(strings.Split(flagGroup, " ")...) + } for _, flagGroup := range tc.subCmdFlagGroupsRequired { sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) }