From a2600feb72643f8e13b185465d8fcaa7e0408bfc Mon Sep 17 00:00:00 2001 From: John Schnake Date: Wed, 13 Apr 2022 08:27:26 -0500 Subject: [PATCH] Updates for the persistent/local mixing problem --- flag_groups.go | 25 ++++++++++++++++----- flag_groups_test.go | 54 ++++++++++++++++++++++++++++++++++++--------- user_guide.md | 1 + 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index 1e9d45691..1dba424ad 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -72,8 +72,8 @@ func (c *Command) validateFlagGroups() error { groupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(pflag, requiredAsGroup, groupStatus) - processFlagForGroupAnnotation(pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) + processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -85,14 +85,29 @@ func (c *Command) validateFlagGroups() error { return nil } -func processFlagForGroupAnnotation(pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { +func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { + for _, fname := range flagnames { + f := fs.Lookup(fname) + if f == nil { + return false + } + } + return true +} + +func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { groupInfo, found := pflag.Annotations[annotation] if found { for _, group := range groupInfo { if groupStatus[group] == nil { - groupStatus[group] = map[string]bool{} - flagnames := strings.Split(group, " ") + + // Only consider this flag group at all if all the flags are defined. + if !hasAllFlags(flags, flagnames...) { + continue + } + + groupStatus[group] = map[string]bool{} for _, name := range flagnames { groupStatus[group][name] = false } diff --git a/flag_groups_test.go b/flag_groups_test.go index d4053dd0b..404ede562 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -31,16 +31,24 @@ func TestValidateFlagGroups(t *testing.T) { for _, v := range []string{"e", "f", "g"} { c.PersistentFlags().String(v, "", "") } + subC := &Command{ + Use: "subcmd", + Run: func(cmd *Command, args []string) { + }} + subC.Flags().String("subonly", "", "") + c.AddCommand(subC) return c } // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsExclusive []string - args []string - expectErr string + desc string + flagGroupsRequired []string + flagGroupsExclusive []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsExclusive []string + args []string + expectErr string }{ { desc: "No flags no problem", @@ -66,46 +74,70 @@ func TestValidateFlagGroups(t *testing.T) { }, { desc: "Multiple exclusive flag group not satisfied returns first error", flagGroupsExclusive: []string{"a b c", "a d"}, - args: []string{"testcmd", "--a=foo", "--c=foo", "--d=foo"}, + args: []string{"--a=foo", "--c=foo", "--d=foo"}, expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, }, { desc: "Validation of required groups occurs on groups in sorted order", flagGroupsRequired: []string{"a d", "a b", "a c"}, - args: []string{"testcmd", "--a=foo"}, + args: []string{"--a=foo"}, expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`, }, { desc: "Validation of exclusive groups occurs on groups in sorted order", flagGroupsExclusive: []string{"a d", "a b", "a c"}, - args: []string{"testcmd", "--a=foo", "--b=foo", "--c=foo"}, + args: []string{"--a=foo", "--b=foo", "--c=foo"}, expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, }, { desc: "Persistent flags utilize both features and can fail required groups", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, - args: []string{"testcmd", "--a=foo", "--f=foo", "--g=foo"}, + args: []string{"--a=foo", "--f=foo", "--g=foo"}, expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`, }, { desc: "Persistent flags utilize both features and can fail mutually exclusive groups", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, - args: []string{"testcmd", "--a=foo", "--e=foo", "--f=foo", "--g=foo"}, + args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"}, expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`, }, { desc: "Persistent flags utilize both features and can pass", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, - args: []string{"testcmd", "--a=foo", "--e=foo", "--f=foo"}, + args: []string{"--a=foo", "--e=foo", "--f=foo"}, + }, { + desc: "Subcmds can use required groups using inherited flags", + subCmdFlagGroupsRequired: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + }, { + desc: "Subcmds can use exclusive groups using inherited flags", + subCmdFlagGroupsExclusive: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + expectErr: "if any flags in the group [e subonly] are set none of the others can be; [e subonly] were all set", + }, { + desc: "Subcmds can use exclusive groups using inherited flags and pass", + subCmdFlagGroupsExclusive: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo"}, + }, { + desc: "Flag groups not applied if not found on invoked command", + subCmdFlagGroupsRequired: []string{"e subonly"}, + args: []string{"--e=foo"}, }, } for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { c := getCmd() + sub := c.Commands()[0] for _, flagGroup := range tc.flagGroupsRequired { c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) } for _, flagGroup := range tc.flagGroupsExclusive { c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.subCmdFlagGroupsRequired { + sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.subCmdFlagGroupsExclusive { + sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + } c.SetArgs(tc.args) err := c.Execute() switch { diff --git a/user_guide.md b/user_guide.md index e9b8a4eed..56a1e9c60 100644 --- a/user_guide.md +++ b/user_guide.md @@ -320,6 +320,7 @@ rootCmd.MarkFlagsMutuallyExclusive("json", "yaml") In both of these cases: - both local and persistent flags can be used + - **NOTE:** the group is only enforced on commands where every flag is defined - a flag may appear in multiple groups - a group may contain any number of flags