From 5f38f45579cabc5b46f8cbe6f3509b7adf439fd1 Mon Sep 17 00:00:00 2001 From: John Schnake Date: Sun, 3 Apr 2022 14:27:40 -0500 Subject: [PATCH] Add ability to mark flags as required or exclusive as a group This change adds two features for dealing with flags: - requiring flags be provided as a group (or not at all) - requiring flags be mutually exclusive of each other By utilizing the flag annotations we can mark which flag groups a flag is a part of and during the parsing process we track which ones we have seen or not. A flag may be a part of multiple groups. The list of flags and the type of group (required together or exclusive) make it a unique group. Signed-off-by: John Schnake --- command.go | 4 ++ flag_groups.go | 144 ++++++++++++++++++++++++++++++++++++++++++++ flag_groups_test.go | 89 +++++++++++++++++++++++++++ go.sum | 8 +-- 4 files changed, 238 insertions(+), 7 deletions(-) create mode 100644 flag_groups.go create mode 100644 flag_groups_test.go diff --git a/command.go b/command.go index ee5365bcb..0f4511f38 100644 --- a/command.go +++ b/command.go @@ -863,6 +863,10 @@ func (c *Command) execute(a []string) (err error) { if err := c.validateRequiredFlags(); err != nil { return err } + if err := c.validateFlagGroups(); err != nil { + return err + } + if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { return err diff --git a/flag_groups.go b/flag_groups.go new file mode 100644 index 000000000..8faad500b --- /dev/null +++ b/flag_groups.go @@ -0,0 +1,144 @@ +// Copyright © 2022 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "fmt" + "sort" + "strings" + + flag "github.com/spf13/pflag" +) + +const ( + RequiredAsGroup = "cobra_annotation_required_if_others_set" + MutuallyExclusive = "cobra_annotation_mutually_exclusive" +) + +func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f.Annotations == nil { + f.Annotations = map[string][]string{} + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + f.Annotations[RequiredAsGroup] = append(f.Annotations[RequiredAsGroup], strings.Join(flagNames, " ")) + } +} + +func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f.Annotations == nil { + f.Annotations = map[string][]string{} + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + f.Annotations[MutuallyExclusive] = append(f.Annotations[MutuallyExclusive], strings.Join(flagNames, " ")) + } +} + +// validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic. +func (c *Command) validateFlagGroups() error { + if c.DisableFlagParsing { + return nil + } + + flags := c.Flags() + + // groupStatus format is the list of flags as a unique ID, + // then a map of each flag name and whether it is set or not. + groupStatus := map[string]map[string]bool{} + mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + flags.VisitAll(func(pflag *flag.Flag) { + processFlagForGroupAnnotation(pflag, RequiredAsGroup, groupStatus) + processFlagForGroupAnnotation(pflag, MutuallyExclusive, mutuallyExclusiveGroupStatus) + }) + + errs := validateRequiredFlagGroups(groupStatus) + errsExclusive := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus) + errs = append(errs, errsExclusive...) + + if len(errs) > 0 { + return combineErrors(errs, `, `) + } + return nil +} + +func processFlagForGroupAnnotation(pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { + groupInfo, found := pflag.Annotations[annotation] + if found { + for _, group := range groupInfo { + if groupStatus[group] == nil { + groupStatus[group] = map[string]bool{} + + flagnames := strings.Split(group, " ") + for _, name := range flagnames { + groupStatus[group][name] = false + } + } + + groupStatus[group][pflag.Name] = pflag.Changed + } + } +} + +func validateRequiredFlagGroups(data map[string]map[string]bool) []error { + var errs []error + for flagList, flagnameAndStatus := range data { + unset := []string{} + for flagname, isSet := range flagnameAndStatus { + if !isSet { + unset = append(unset, flagname) + } + } + if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { + continue + } + sort.Strings(unset) + errs = append(errs, fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)) + } + return errs +} + +func validateExclusiveFlagGroups(data map[string]map[string]bool) []error { + var errs []error + for flagList, flagnameAndStatus := range data { + var set []string + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + if len(set) == 0 || len(set) == 1 { + continue + } + sort.Strings(set) + errs = append(errs, fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)) + } + return errs +} + +// combineErrors will squash the errors together into a single error with there messages +// joined by the given seperator. +func combineErrors(errs []error, sep string) error { + if len(errs) == 0 { + return nil + } + var msgs []string + for _, e := range errs { + msgs = append(msgs, e.Error()) + } + sort.Strings(msgs) + return fmt.Errorf(strings.Join(msgs, sep)) +} diff --git a/flag_groups_test.go b/flag_groups_test.go new file mode 100644 index 000000000..30508d948 --- /dev/null +++ b/flag_groups_test.go @@ -0,0 +1,89 @@ +// Copyright © 2022 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "strings" + "testing" +) + +func TestValidateFlagGroups(t *testing.T) { + getCmd := func() *Command { + c := &Command{ + Use: "testcmd", + Run: func(cmd *Command, args []string) { + }} + // Define lots of flags to utilize for testing. + for _, v := range []string{"a", "b", "c", "d", "e", "f", "g"} { + c.Flags().String(v, "", "") + } + return c + } + + // Each test case uses a unique command from the function above. + testcases := []struct { + desc string + flagGroupsRequired []string + flagGroupsExclusive []string + args []string + expectErr string + }{ + { + desc: "No flags no problem", + }, { + desc: "No flags no problem even with conflicting groups", + flagGroupsRequired: []string{"a b"}, + flagGroupsExclusive: []string{"a b"}, + }, { + desc: "Required flag group not satisfied", + flagGroupsRequired: []string{"a b c"}, + args: []string{"--a=foo"}, + expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", + }, { + desc: "Exclusive flag group not satisfied", + flagGroupsExclusive: []string{"a b c"}, + args: []string{"--a=foo", "--b=foo"}, + expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + }, { + desc: "Multiple required flag group not satisfied", + flagGroupsRequired: []string{"a b c", "a d"}, + args: []string{"--c=foo", "--d=foo"}, + expectErr: "if any flags in the group [a b c] are set they must all be set; missing [a b], if any flags in the group [a d] are set they must all be set; missing [a]", + }, { + desc: "Multiple exclusive flag group not satisfied", + flagGroupsExclusive: []string{"a b c", "a d"}, + args: []string{"testcmd", "--a=foo", "--c=foo", "--d=foo"}, + expectErr: "if any flags in the group [a b c] are set none of the others can be; [a c] were all set, if any flags in the group [a d] are set none of the others can be; [a d] were all set", + }, + } + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + c := getCmd() + for _, flagGroup := range tc.flagGroupsRequired { + c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.flagGroupsExclusive { + c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + } + c.SetArgs(tc.args) + err := c.Execute() + switch { + case err == nil && len(tc.expectErr) > 0: + t.Errorf("Expected error %q but got nil", tc.expectErr) + case err != nil && err.Error() != tc.expectErr: + t.Errorf("Expected error %q but got %q", tc.expectErr, err) + } + }) + } +} diff --git a/go.sum b/go.sum index 431058ed0..6d5345968 100644 --- a/go.sum +++ b/go.sum @@ -2,17 +2,11 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKY github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=