diff --git a/README.md b/README.md index 372db71..697cf80 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,20 @@ var CLI struct { For flags, multiple key+value pairs should be separated by `mapsep:"rune"` tag (defaults to `;`) eg. `--set="key1=value1;key2=value2"`. +## Pointers + +Pointers work like the underlying type, except that you can differentiate between the presence of the zero value and no value being supplied. + +For example: + +```go +var CLI struct { + Foo *int +} +``` + +Would produce a nil value for `Foo` if no `--foo` argument is supplied, but would have a pointer to the value 0 if the argument `--foo=0` was supplied. + ## Nested data structure Kong support a nested data structure as well with `embed:""`. You can combine `embed:""` with `prefix:""`: @@ -628,4 +642,4 @@ See the [section on hooks](#hooks-beforeresolve-beforeapply-afterapply-and-the-b ### Other options -The full set of options can be found [here](https://godoc.org/github.com/alecthomas/kong#Option). +The full set of options can be found [here](https://godoc.org/github.com/alecthomas/kong#Option). \ No newline at end of file diff --git a/context.go b/context.go index 45c412a..7097917 100644 --- a/context.go +++ b/context.go @@ -661,6 +661,22 @@ func (c *Context) Apply() (string, error) { return strings.Join(path, " "), nil } +func flipBoolValue(value reflect.Value) error { + if value.Kind() == reflect.Bool { + value.SetBool(!value.Bool()) + return nil + } + + if value.Kind() == reflect.Ptr { + if !value.IsNil() { + return flipBoolValue(value.Elem()) + } + return nil + } + + return fmt.Errorf("cannot negate a value of %s", value.Type().String()) +} + func (c *Context) parseFlag(flags []*Flag, match string) (err error) { candidates := []string{} for _, flag := range flags { @@ -689,7 +705,10 @@ func (c *Context) parseFlag(flags []*Flag, match string) (err error) { } if flag.Negated { value := c.getValue(flag.Value) - value.SetBool(!value.Bool()) + err := flipBoolValue(value) + if err != nil { + return err + } flag.Value.Apply(value) } c.Path = append(c.Path, &Path{Flag: flag}) @@ -889,6 +908,11 @@ func checkEnum(value *Value, target reflect.Value) error { case reflect.Map, reflect.Struct: return errors.New("enum can only be applied to a slice or value") + case reflect.Ptr: + if target.IsNil() { + return nil + } + return checkEnum(value, target.Elem()) default: enumSlice := value.EnumSlice() v := fmt.Sprintf("%v", target) diff --git a/kong_test.go b/kong_test.go index 6ce73f5..0705248 100644 --- a/kong_test.go +++ b/kong_test.go @@ -1542,3 +1542,234 @@ func TestPassthroughCmdOnlyStringArgs(t *testing.T) { _, err := kong.New(&cli) require.EqualError(t, err, ".Command: passthrough command command [ ...] must contain exactly one positional argument of []string type") } + +func TestStringPointer(t *testing.T) { + var cli struct { + Foo *string + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--foo", "wtf"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.Foo) + require.Equal(t, "wtf", *cli.Foo) +} + +func TestStringPointerNoValue(t *testing.T) { + var cli struct { + Foo *string + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.Nil(t, cli.Foo) +} + +func TestStringPointerDefault(t *testing.T) { + var cli struct { + Foo *string `default:"stuff"` + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.Foo) + require.Equal(t, "stuff", *cli.Foo) +} + +func TestStringPointerAliasNoValue(t *testing.T) { + type Foo string + var cli struct { + F *Foo + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.Nil(t, cli.F) +} + +func TestStringPointerAlias(t *testing.T) { + type Foo string + var cli struct { + F *Foo + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--f=value"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.F) + require.Equal(t, Foo("value"), *cli.F) +} + +func TestStringPointerEmptyValue(t *testing.T) { + var cli struct { + F *string + G *string + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--f", "", "--g="}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.F) + require.NotNil(t, cli.G) + require.Equal(t, "", *cli.F) + require.Equal(t, "", *cli.G) +} + +func TestIntPtr(t *testing.T) { + var cli struct { + F *int + G *int + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--f=6"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.F) + require.Nil(t, cli.G) + require.Equal(t, 6, *cli.F) +} + +func TestBoolPtr(t *testing.T) { + var cli struct { + X *bool + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--x"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.X) + require.Equal(t, true, *cli.X) +} + +func TestBoolPtrFalse(t *testing.T) { + var cli struct { + X *bool + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--x=false"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.X) + require.Equal(t, false, *cli.X) +} + +func TestBoolPtrNegated(t *testing.T) { + var cli struct { + X *bool `negatable:""` + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--no-x"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.X) + require.Equal(t, false, *cli.X) +} + +func TestNilNegatableBoolPtr(t *testing.T) { + var cli struct { + X *bool `negatable:""` + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.Nil(t, cli.X) +} + +func TestBoolPtrNil(t *testing.T) { + var cli struct { + X *bool + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.Nil(t, cli.X) +} + +func TestUnsupportedPtr(t *testing.T) { + //nolint:structcheck,unused + type Foo struct { + x int + y int + } + + var cli struct { + F *Foo + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--f=whatever"}) + require.Nil(t, ctx) + require.Error(t, err) + require.Equal(t, "--f: cannot find mapper for kong_test.Foo", err.Error()) +} + +func TestEnumPtr(t *testing.T) { + var cli struct { + X *string `enum:"A,B,C" default:"C"` + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{"--x=A"}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.X) + require.Equal(t, "A", *cli.X) +} + +func TestEnumPtrOmitted(t *testing.T) { + var cli struct { + X *string `enum:"A,B,C" default:"C"` + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.NotNil(t, cli.X) + require.Equal(t, "C", *cli.X) +} + +func TestEnumPtrOmittedNoDefault(t *testing.T) { + var cli struct { + X *string `enum:"A,B,C"` + } + k, err := kong.New(&cli) + require.NoError(t, err) + require.NotNil(t, k) + ctx, err := k.Parse([]string{}) + require.NoError(t, err) + require.NotNil(t, ctx) + require.Nil(t, cli.X) +} diff --git a/mapper.go b/mapper.go index e8778ee..ef538f0 100644 --- a/mapper.go +++ b/mapper.go @@ -274,7 +274,8 @@ func (r *Registry) RegisterDefaults() *Registry { RegisterName("path", pathMapper(r)). RegisterName("existingfile", existingFileMapper(r)). RegisterName("existingdir", existingDirMapper(r)). - RegisterName("counter", counterMapper()) + RegisterName("counter", counterMapper()). + RegisterKind(reflect.Ptr, ptrMapper(r)) } type boolMapper struct{} @@ -653,6 +654,22 @@ func existingDirMapper(r *Registry) MapperFunc { } } +func ptrMapper(r *Registry) MapperFunc { + return func(ctx *DecodeContext, target reflect.Value) error { + elem := reflect.New(target.Type().Elem()).Elem() + nestedMapper := r.ForValue(elem) + if nestedMapper == nil { + return fmt.Errorf("cannot find mapper for %v", target.Type().Elem().String()) + } + err := nestedMapper.Decode(ctx, elem) + if err != nil { + return err + } + target.Set(elem.Addr()) + return nil + } +} + func counterMapper() MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { if ctx.Scan.Peek().Type == FlagValueToken { diff --git a/tag.go b/tag.go index 8e159dd..b471613 100644 --- a/tag.go +++ b/tag.go @@ -169,9 +169,11 @@ func parseTag(parent reflect.Value, ft reflect.StructField) (*Tag, error) { func hydrateTag(t *Tag, typ reflect.Type) error { // nolint: gocyclo var typeName string var isBool bool + var isBoolPtr bool if typ != nil { typeName = typ.Name() isBool = typ.Kind() == reflect.Bool + isBoolPtr = typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Bool } var err error t.Cmd = t.Has("cmd") @@ -212,7 +214,7 @@ func hydrateTag(t *Tag, typ reflect.Type) error { // nolint: gocyclo t.EnvPrefix = t.Get("envprefix") t.Embed = t.Has("embed") negatable := t.Has("negatable") - if negatable && !isBool { + if negatable && !isBool && !isBoolPtr { return fmt.Errorf("negatable can only be set on booleans") } t.Negatable = negatable @@ -230,7 +232,7 @@ func hydrateTag(t *Tag, typ reflect.Type) error { // nolint: gocyclo } t.PlaceHolder = t.Get("placeholder") t.Enum = t.Get("enum") - scalarType := (typ == nil || !(typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map)) + scalarType := (typ == nil || !(typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map || typ.Kind() == reflect.Ptr)) if t.Enum != "" && !(t.Required || t.HasDefault) && scalarType { return fmt.Errorf("enum value is only valid if it is either required or has a valid default value") }