From 3915e7492d26337603ceca2f011ca9bd2c79d2b0 Mon Sep 17 00:00:00 2001 From: David Muir Sharnoff Date: Tue, 28 Dec 2021 18:16:16 -0800 Subject: [PATCH] time.Duration & flag.Value --- setters.go | 46 ++++++++++++ unpack.go | 33 +++++++++ unpack_test.go | 188 ++++++++++++++++++++++++++++--------------------- 3 files changed, 186 insertions(+), 81 deletions(-) create mode 100644 setters.go diff --git a/setters.go b/setters.go new file mode 100644 index 0000000..e58087e --- /dev/null +++ b/setters.go @@ -0,0 +1,46 @@ +package reflectutils + +import ( + "reflect" + "time" +) + +func init() { + RegisterStringSetter(time.ParseDuration) +} + +var settersByType = make(map[reflect.Type]reflect.Value) + +// RegisterStringSetter registers functions that can be used to transform +// strings into specific types. The fn argument must be a function that +// takes a string and returns an arbitrary type and an error. An example +// of such a function is time.ParseDuration. Any call to RegisterStringSetter +// with a value that is not a function of that sort will panic. +// +// RegisterStringSetter is not thread safe and should probably only be +// used during init(). +// +// These functions are used by MakeStringSetter() when there is an opportunity +// to do so. +func RegisterStringSetter(fn interface{}) { + v := reflect.ValueOf(fn) + if !v.IsValid() { + panic("call to RegisterStringSetter with an invalid value") + } + if v.Type().Kind() != reflect.Func { + panic("call to RegisterStringSetter with something other than a function") + } + if v.Type().NumIn() != 1 { + panic("call to RegisterStringSetter with something other than a function that takes one arg") + } + if v.Type().NumOut() != 2 { + panic("call to RegisterStringSetter with something other than a function that takes returns two values") + } + if v.Type().In(0) != reflect.TypeOf((*string)(nil)).Elem() { + panic("call to RegisterStringSetter with something other than a function that takes something other than string") + } + if v.Type().Out(1) != reflect.TypeOf((*error)(nil)).Elem() { + panic("call to RegisterStringSetter with something other than a function that returns something other than error") + } + settersByType[v.Type().Out(0)] = v +} diff --git a/unpack.go b/unpack.go index b85b58d..97a073b 100644 --- a/unpack.go +++ b/unpack.go @@ -2,6 +2,7 @@ package reflectutils import ( "encoding" + "flag" "reflect" "strconv" "strings" @@ -10,6 +11,7 @@ import ( ) var textUnmarshallerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +var flagValueType = reflect.TypeOf((*flag.Value)(nil)).Elem() type stringSetterOpts struct { split string @@ -46,7 +48,11 @@ func SliceAppend(b bool) StringSetterArg { // For arrays and slices, strings are split on comma to create the values for the // elements. // +// Any type that matches a type registered with RegisterStringSetter will be +// unpacked with the corresponding function. A string setter is pre-registered +// for time.Duration // Anything that implements encoding.TextUnmarshaler will be unpacked that way. +// Anything that implements flag.Value will be unpacked that way. // // Maps, structs, channels, interfaces, channels, and funcs are not supported unless // they happen to implent encoding.TextUnmarshaler. @@ -58,6 +64,16 @@ func MakeStringSetter(t reflect.Type, optArgs ...StringSetterArg) (func(target r for _, f := range optArgs { f(&opts) } + if setter, ok := settersByType[t]; ok { + return func(target reflect.Value, value string) error { + out := setter.Call([]reflect.Value{reflect.ValueOf(value)}) + if !out[1].IsNil() { + return out[1].Interface().(error) + } + target.Set(out[0]) + return nil + }, nil + } if t.AssignableTo(textUnmarshallerType) { return func(target reflect.Value, value string) error { p := reflect.New(t.Elem()) @@ -75,6 +91,23 @@ func MakeStringSetter(t reflect.Type, optArgs ...StringSetterArg) (func(target r return errors.WithStack(err) }, nil } + if t.AssignableTo(flagValueType) { + return func(target reflect.Value, value string) error { + p := reflect.New(t.Elem()) + target.Set(p) + err := target.Interface().(flag.Value).Set(value) + if err != nil { + return errors.WithStack(err) + } + return nil + }, nil + } + if reflect.PtrTo(t).AssignableTo(flagValueType) { + return func(target reflect.Value, value string) error { + err := target.Addr().Interface().(flag.Value).Set(value) + return errors.WithStack(err) + }, nil + } switch t.Kind() { case reflect.Ptr: setElem, err := MakeStringSetter(t.Elem()) diff --git a/unpack_test.go b/unpack_test.go index 0587862..abbf454 100644 --- a/unpack_test.go +++ b/unpack_test.go @@ -1,10 +1,13 @@ package reflectutils_test import ( + "encoding" + "flag" "fmt" "reflect" "strconv" "testing" + "time" "github.com/muir/reflectutils" @@ -19,98 +22,121 @@ func (fp *Foo) UnmarshalText(b []byte) error { return nil } +var _ encoding.TextUnmarshaler = func() *Foo { var x Foo; return &x }() + +type Bar string + +func (bp *Bar) Set(s string) error { + *bp = Bar(s + "/e") + return nil +} +func (bp Bar) String() string { + return "b/" + string(bp) +} + +var _ flag.Value = func() *Bar { var x Bar; return &x }() + func TestStringSetter(t *testing.T) { type tsType struct { - Int int `value:"38"` - Int8 int8 `value:"-9"` - Int16 int16 `value:"329"` - Int32 int32 `value:"-32902"` - Int64 int64 `value:"3292929"` - Uint uint `value:"202"` - Uint8 uint8 `value:"99"` - Uint16 uint16 `value:"3020"` - Uint32 uint32 `value:"92020"` - Uint64 uint64 `value:"320202"` - Float32 float32 `value:"3.9"` - Float64 float64 `value:"4.32e7" want:"4.32e+07"` - String string `value:"foo"` - Bool bool `value:"false"` - IntP *int `value:"-82"` - Int8P *int8 `value:"-2"` - Int16P *int16 `value:"-39"` - Int32P *int32 `value:"-329"` - Int64P *int64 `value:"-3292"` - UintP *uint `value:"239"` - Uint8P *uint8 `value:"92"` - Uint16P *uint16 `value:"330"` - Uint32P *uint32 `value:"239239"` - Uint64P *uint64 `value:"3923"` - Float32P *float32 `value:"3.299"` - Float64P *float64 `value:"9.2"` - StringP *string `value:"foop"` - Complex64 *complex64 `value:"4+3i" want:"(4+3i)"` - Complex128 *complex128 `value:"3.9+2.6i" want:"(3.9+2.6i)"` - BoolP *bool `value:"true"` - IntSlice []int `value:"3,9,-10" want:"[3 9 -10]"` - IntArray [2]int `value:"22,11" want:"[22 11]"` - Foo Foo `value:"foo" want:"~foo~"` - FooArray [2]Foo `value:"a,b,c" want:"[~a~ ~b,c~]"` - FooP *Foo `value:"foo" want:"~foo~"` - SS1 []string `value:"foo/bar" want:"[foo/bar]"` - SS2 []string `value:"foo/bar" want:"[foo bar]" split:"/"` - SS3 []string `value:"foo,bar" want:"[foo,bar]" split:""` - SS4 []string `value:"foo,bar" want:"[foo bar]" split:","` - SA1 [2]string `value:"foo/bar" want:"[foo/bar ]"` - SA2 [2]string `value:"foo/bar" want:"[foo bar]" split:"/"` - SA3 [2]string `value:"foo,bar" want:"[foo,bar ]" split:""` - SS5 []string `value:"foo" want:"[foo bar]" value2:"bar"` - SS6 []string `value:"foo" want:"[bar]" value2:"bar" sa:"f"` + Int int `value:"38"` + Int8 int8 `value:"-9"` + Int16 int16 `value:"329"` + Int32 int32 `value:"-32902"` + Int64 int64 `value:"3292929"` + Uint uint `value:"202"` + Uint8 uint8 `value:"99"` + Uint16 uint16 `value:"3020"` + Uint32 uint32 `value:"92020"` + Uint64 uint64 `value:"320202"` + Float32 float32 `value:"3.9"` + Float64 float64 `value:"4.32e7" want:"4.32e+07"` + String string `value:"foo"` + Bool bool `value:"false"` + IntP *int `value:"-82"` + Int8P *int8 `value:"-2"` + Int16P *int16 `value:"-39"` + Int32P *int32 `value:"-329"` + Int64P *int64 `value:"-3292"` + UintP *uint `value:"239"` + Uint8P *uint8 `value:"92"` + Uint16P *uint16 `value:"330"` + Uint32P *uint32 `value:"239239"` + Uint64P *uint64 `value:"3923"` + Float32P *float32 `value:"3.299"` + Float64P *float64 `value:"9.2"` + StringP *string `value:"foop"` + Complex64 *complex64 `value:"4+3i" want:"(4+3i)"` + Complex128 *complex128 `value:"3.9+2.6i" want:"(3.9+2.6i)"` + BoolP *bool `value:"true"` + IntSlice []int `value:"3,9,-10" want:"[3 9 -10]"` + IntArray [2]int `value:"22,11" want:"[22 11]"` + Foo Foo `value:"foo" want:"~foo~"` + FooArray [2]Foo `value:"a,b,c" want:"[~a~ ~b,c~]"` + FooP *Foo `value:"foo" want:"~foo~"` + Dur time.Duration `value:"30m" want:"30m0s"` + DurP *time.Duration `value:"15m" want:"15m0s"` + DurArray []time.Duration `value:"15m,45m" want:"[15m0s 45m0s]"` + Bar Bar `value:"bar" want:"b/bar/e"` + BarArray [2]Bar `value:"a,b,c" want:"[b/a/e b/b,c/e]"` + BarP *Bar `value:"bar" want:"b/bar/e"` + SS1 []string `value:"foo/bar" want:"[foo/bar]"` + SS2 []string `value:"foo/bar" want:"[foo bar]" split:"/"` + SS3 []string `value:"foo,bar" want:"[foo,bar]" split:""` + SS4 []string `value:"foo,bar" want:"[foo bar]" split:","` + SA1 [2]string `value:"foo/bar" want:"[foo/bar ]"` + SA2 [2]string `value:"foo/bar" want:"[foo bar]" split:"/"` + SA3 [2]string `value:"foo,bar" want:"[foo,bar ]" split:""` + SS5 []string `value:"foo" want:"[foo bar]" value2:"bar"` + SS6 []string `value:"foo" want:"[bar]" value2:"bar" sa:"f"` } var ts tsType vp := reflect.ValueOf(&ts) v := reflect.Indirect(vp) var count int reflectutils.WalkStructElements(v.Type(), func(f reflect.StructField) bool { - t.Logf("field %s, a %s", f.Name, f.Type) - value, ok := f.Tag.Lookup("value") - if !assert.Truef(t, ok, "input value for %s", f.Name) { - return true - } - want, ok := f.Tag.Lookup("want") - if !ok { - want = value - } - var opts []reflectutils.StringSetterArg - if split, ok := f.Tag.Lookup("split"); ok { - t.Log(" splitting on", split) - opts = append(opts, reflectutils.WithSplitOn(split)) - } - if sa, ok := f.Tag.Lookup("sa"); ok { - b, err := strconv.ParseBool(sa) - require.NoError(t, err, "parse sa") - t.Log(" slice append", b) - opts = append(opts, reflectutils.SliceAppend(b)) - } + t.Run(f.Name+"-"+f.Type.String(), func(t *testing.T) { + t.Logf("field %s, a %s", f.Name, f.Type) + value, ok := f.Tag.Lookup("value") + if !assert.Truef(t, ok, "input value for %s", f.Name) { + return + } + want, ok := f.Tag.Lookup("want") + if !ok { + want = value + } + var opts []reflectutils.StringSetterArg + if split, ok := f.Tag.Lookup("split"); ok { + t.Log(" splitting on", split) + opts = append(opts, reflectutils.WithSplitOn(split)) + } + if sa, ok := f.Tag.Lookup("sa"); ok { + b, err := strconv.ParseBool(sa) + require.NoError(t, err, "parse sa") + t.Log(" slice append", b) + opts = append(opts, reflectutils.SliceAppend(b)) + } - fn, err := reflectutils.MakeStringSetter(f.Type, opts...) - if !assert.NoErrorf(t, err, "make string setter for %s", f.Name) { - return true - } - e := v.FieldByIndex(f.Index) - err = fn(e, value) - if assert.NoError(t, err, "set %s to '%s'", f.Name, value) { - value2, ok := f.Tag.Lookup("value2") - if ok { - err := fn(e, value2) - assert.NoError(t, err, "set value2") + fn, err := reflectutils.MakeStringSetter(f.Type, opts...) + if !assert.NoErrorf(t, err, "make string setter for %s", f.Name) { + return } - ge := e - if f.Type.Kind() == reflect.Ptr { - ge = e.Elem() + e := v.FieldByIndex(f.Index) + err = fn(e, value) + if assert.NoError(t, err, "set %s to '%s'", f.Name, value) { + value2, ok := f.Tag.Lookup("value2") + if ok { + err := fn(e, value2) + assert.NoError(t, err, "set value2") + } + ge := e + if f.Type.Kind() == reflect.Ptr { + ge = e.Elem() + } + assert.Equalf(t, want, fmt.Sprintf("%+v", ge.Interface()), "got setting %s to '%s'", f.Name, value) } - assert.Equalf(t, want, fmt.Sprintf("%+v", ge.Interface()), "got setting %s to '%s'", f.Name, value) - } - count++ + count++ + return + }) return true }) assert.Equal(t, v.NumField(), count, "number of fields tested")