From 1abf95b1f07bda449f356eb1941851a93f0b2ea8 Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Tue, 23 Mar 2021 14:11:36 -0400 Subject: [PATCH] Add count option for bool flags --- flag_bool.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++--- flag_test.go | 12 ++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/flag_bool.go b/flag_bool.go index bc9ea35d08..4a6b1f3e6e 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -1,6 +1,7 @@ package cli import ( + "errors" "flag" "fmt" "strconv" @@ -19,8 +20,49 @@ type BoolFlag struct { DefaultText string Destination *bool HasBeenSet bool + Count int } +// boolValue needs to implement the boolFlag internal interface in flag +// to be able to capture bool fields and values +// type boolFlag interface { +// Value +// IsBoolFlag() bool +// } +type boolValue struct { + destination *bool + count *int +} + +func newBoolValue(val bool, p *bool, count *int) *boolValue { + *p = val + return &boolValue{ + destination: p, + count: count, + } +} + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + err = errors.New("parse error") + } + *b.destination = v + *b.count = *b.count + 1 + return err +} + +func (b *boolValue) Get() interface{} { return *b.destination } + +func (b *boolValue) String() string { + if b.destination != nil { + return strconv.FormatBool(*b.destination) + } + return strconv.FormatBool(false) +} + +func (b *boolValue) IsBoolFlag() bool { return true } + // IsSet returns whether or not the flag has been set through env or file func (f *BoolFlag) IsSet() bool { return f.HasBeenSet @@ -74,11 +116,13 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error { } for _, name := range f.Names() { + var value flag.Value if f.Destination != nil { - set.BoolVar(f.Destination, name, f.Value, f.Usage) - continue + value = newBoolValue(f.Value, f.Destination, &f.Count) + } else { + value = newBoolValue(f.Value, &f.Value, &f.Count) } - set.Bool(name, f.Value, f.Usage) + set.Var(value, name, f.Usage) } return nil diff --git a/flag_test.go b/flag_test.go index b3b0d7c587..217befd4a0 100644 --- a/flag_test.go +++ b/flag_test.go @@ -51,6 +51,18 @@ func TestBoolFlagApply_SetsAllNames(t *testing.T) { expect(t, v, true) } +func TestBoolFlagApply_SetsCount(t *testing.T) { + v := false + fl := BoolFlag{Name: "wat", Aliases: []string{"W", "huh"}, Destination: &v} + set := flag.NewFlagSet("test", 0) + _ = fl.Apply(set) + + err := set.Parse([]string{"--wat", "-W", "--huh"}) + expect(t, err, nil) + expect(t, v, true) + expect(t, fl.Count, 3) +} + func TestFlagsFromEnv(t *testing.T) { newSetIntSlice := func(defaults ...int) IntSlice { s := NewIntSlice(defaults...)