From 11eeec82e5b36511323db835d57d257e8594b719 Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Tue, 23 Mar 2021 14:11:36 -0400 Subject: [PATCH] Add count option for bool flags --- flag_bool.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++--- flag_test.go | 13 +++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/flag_bool.go b/flag_bool.go index 8bd582094f..3b637d3fe3 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -1,6 +1,7 @@ package cli import ( + "errors" "flag" "fmt" "strconv" @@ -19,8 +20,52 @@ type BoolFlag struct { DefaultText string Destination *bool HasBeenSet bool + Count *int } +// boolValue needs to implement the boolFlag internal interface in flag +// to be able to capture bool fields and values +// type boolFlag interface { +// Value +// IsBoolFlag() bool +// } +type boolValue struct { + destination *bool + count *int +} + +func newBoolValue(val bool, p *bool, count *int) *boolValue { + *p = val + return &boolValue{ + destination: p, + count: count, + } +} + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + err = errors.New("parse error") + return err + } + *b.destination = v + if b.count != nil { + *b.count = *b.count + 1 + } + return err +} + +func (b *boolValue) Get() interface{} { return *b.destination } + +func (b *boolValue) String() string { + if b.destination != nil { + return strconv.FormatBool(*b.destination) + } + return strconv.FormatBool(false) +} + +func (b *boolValue) IsBoolFlag() bool { return true } + // IsSet returns whether or not the flag has been set through env or file func (f *BoolFlag) IsSet() bool { return f.HasBeenSet @@ -79,11 +124,14 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error { } for _, name := range f.Names() { + var value flag.Value if f.Destination != nil { - set.BoolVar(f.Destination, name, f.Value, f.Usage) - continue + value = newBoolValue(f.Value, f.Destination, f.Count) + } else { + t := new(bool) + value = newBoolValue(f.Value, t, f.Count) } - set.Bool(name, f.Value, f.Usage) + set.Var(value, name, f.Usage) } return nil diff --git a/flag_test.go b/flag_test.go index e46270d034..01ab553d27 100644 --- a/flag_test.go +++ b/flag_test.go @@ -51,6 +51,19 @@ func TestBoolFlagApply_SetsAllNames(t *testing.T) { expect(t, v, true) } +func TestBoolFlagApply_SetsCount(t *testing.T) { + v := false + count := 0 + fl := BoolFlag{Name: "wat", Aliases: []string{"W", "huh"}, Destination: &v, Count: &count} + set := flag.NewFlagSet("test", 0) + _ = fl.Apply(set) + + err := set.Parse([]string{"--wat", "-W", "--huh"}) + expect(t, err, nil) + expect(t, v, true) + expect(t, count, 3) +} + func TestFlagsFromEnv(t *testing.T) { newSetIntSlice := func(defaults ...int) IntSlice { s := NewIntSlice(defaults...)