diff --git a/flag_bool.go b/flag_bool.go index b21d5163c9..287da85640 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -1,11 +1,55 @@ package cli import ( + "errors" "flag" "fmt" "strconv" ) +// boolValue needs to implement the boolFlag internal interface in flag +// to be able to capture bool fields and values +// type boolFlag interface { +// Value +// IsBoolFlag() bool +// } +type boolValue struct { + destination *bool + count *int +} + +func newBoolValue(val bool, p *bool, count *int) *boolValue { + *p = val + return &boolValue{ + destination: p, + count: count, + } +} + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + err = errors.New("parse error") + return err + } + *b.destination = v + if b.count != nil { + *b.count = *b.count + 1 + } + return err +} + +func (b *boolValue) Get() interface{} { return *b.destination } + +func (b *boolValue) String() string { + if b.destination != nil { + return strconv.FormatBool(*b.destination) + } + return strconv.FormatBool(false) +} + +func (b *boolValue) IsBoolFlag() bool { return true } + // TakesValue returns true of the flag takes a value, otherwise false func (f *BoolFlag) TakesValue() bool { return false @@ -56,11 +100,14 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error { } for _, name := range f.Names() { + var value flag.Value if f.Destination != nil { - set.BoolVar(f.Destination, name, f.Value, f.Usage) - continue + value = newBoolValue(f.Value, f.Destination, f.Count) + } else { + t := new(bool) + value = newBoolValue(f.Value, t, f.Count) } - set.Bool(name, f.Value, f.Usage) + set.Var(value, name, f.Usage) } return nil diff --git a/flag_test.go b/flag_test.go index e46b1eff45..c0f98864bd 100644 --- a/flag_test.go +++ b/flag_test.go @@ -62,6 +62,19 @@ func TestBoolFlagValueFromContext(t *testing.T) { expect(t, ff.Get(ctx), false) } +func TestBoolFlagApply_SetsCount(t *testing.T) { + v := false + count := 0 + fl := BoolFlag{Name: "wat", Aliases: []string{"W", "huh"}, Destination: &v, Count: &count} + set := flag.NewFlagSet("test", 0) + _ = fl.Apply(set) + + err := set.Parse([]string{"--wat", "-W", "--huh"}) + expect(t, err, nil) + expect(t, v, true) + expect(t, count, 3) +} + func TestFlagsFromEnv(t *testing.T) { newSetFloat64Slice := func(defaults ...float64) Float64Slice { s := NewFloat64Slice(defaults...)