From f1e923c7e2cbb9bc8c294c858074a7c2b8a695cb Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Wed, 6 Jul 2022 12:23:14 +0100 Subject: [PATCH] fix: Anonymous field pointers (#622) Fix panic using ArgsFlat or ScanStruct on structs with nil Anonymous field pointers. Catch the anonymous struct recursion and prevent it. In the case of ScanStruct an error will be returned, in the case of ArgsFlat it will panic with a nice error. --- README.markdown | 1 - redis/reflect.go | 48 ++++++++ redis/reflect_go117.go | 34 ++++++ redis/reflect_go118.go | 16 +++ redis/scan.go | 69 +++++++++--- redis/scan_test.go | 245 ++++++++++++++++++++++++++++++++--------- 6 files changed, 342 insertions(+), 71 deletions(-) create mode 100644 redis/reflect.go create mode 100644 redis/reflect_go117.go create mode 100644 redis/reflect_go118.go diff --git a/README.markdown b/README.markdown index d418cbb2..7b7234b3 100644 --- a/README.markdown +++ b/README.markdown @@ -1,7 +1,6 @@ Redigo ====== -[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo) [![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis) Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database. diff --git a/redis/reflect.go b/redis/reflect.go new file mode 100644 index 00000000..e135aed7 --- /dev/null +++ b/redis/reflect.go @@ -0,0 +1,48 @@ +package redis + +import ( + "reflect" + "runtime" +) + +// methodName returns the name of the calling method, +// assumed to be two stack frames above. +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +// mustBe panics if f's kind is not expected. +func mustBe(v reflect.Value, expected reflect.Kind) { + if v.Kind() != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()}) + } +} + +// fieldByIndexCreate returns the nested field corresponding +// to index creating elements that are nil when stepping through. +// It panics if v is not a struct. +func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v +} diff --git a/redis/reflect_go117.go b/redis/reflect_go117.go new file mode 100644 index 00000000..e985192a --- /dev/null +++ b/redis/reflect_go117.go @@ -0,0 +1,34 @@ +//go:build !go1.18 +// +build !go1.18 + +package redis + +import ( + "errors" + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + if len(index) == 1 { + return v.Field(index[0]), nil + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name()) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v, nil +} diff --git a/redis/reflect_go118.go b/redis/reflect_go118.go new file mode 100644 index 00000000..3356e76f --- /dev/null +++ b/redis/reflect_go118.go @@ -0,0 +1,16 @@ +//go:build go1.18 +// +build go1.18 + +package redis + +import ( + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + return v.FieldByIndexErr(index) +} diff --git a/redis/scan.go b/redis/scan.go index 379206ed..82121011 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -355,7 +355,13 @@ func (ss *structSpec) fieldSpec(name []byte) *fieldSpec { return ss.m[string(name)] } -func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) { +func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec, seen map[reflect.Type]struct{}) error { + if _, ok := seen[t]; ok { + // Protect against infinite recursion. + return fmt.Errorf("recursive struct definition for %v", t) + } + + seen[t] = struct{}{} LOOP: for i := 0; i < t.NumField(); i++ { f := t.Field(i) @@ -365,20 +371,21 @@ LOOP: case f.Anonymous: switch f.Type.Kind() { case reflect.Struct: - compileStructSpec(f.Type, depth, append(index, i), ss) + if err := compileStructSpec(f.Type, depth, append(index, i), ss, seen); err != nil { + return err + } case reflect.Ptr: - // TODO(steve): Protect against infinite recursion. if f.Type.Elem().Kind() == reflect.Struct { - compileStructSpec(f.Type.Elem(), depth, append(index, i), ss) + if err := compileStructSpec(f.Type.Elem(), depth, append(index, i), ss, seen); err != nil { + return err + } } } default: fs := &fieldSpec{name: f.Name} tag := f.Tag.Get("redis") - var ( - p string - ) + var p string first := true for len(tag) > 0 { i := strings.IndexByte(tag, ',') @@ -402,10 +409,12 @@ LOOP: } } } + d, found := depth[fs.name] if !found { d = 1 << 30 } + switch { case len(index) == d: // At same depth, remove from result. @@ -428,6 +437,8 @@ LOOP: } } } + + return nil } var ( @@ -435,26 +446,27 @@ var ( structSpecCache = make(map[reflect.Type]*structSpec) ) -func structSpecForType(t reflect.Type) *structSpec { - +func structSpecForType(t reflect.Type) (*structSpec, error) { structSpecMutex.RLock() ss, found := structSpecCache[t] structSpecMutex.RUnlock() if found { - return ss + return ss, nil } structSpecMutex.Lock() defer structSpecMutex.Unlock() ss, found = structSpecCache[t] if found { - return ss + return ss, nil } ss = &structSpec{m: make(map[string]*fieldSpec)} - compileStructSpec(t, make(map[string]int), nil, ss) + if err := compileStructSpec(t, make(map[string]int), nil, ss, make(map[reflect.Type]struct{})); err != nil { + return nil, fmt.Errorf("compile struct: %s: %w", t, err) + } structSpecCache[t] = ss - return ss + return ss, nil } var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct") @@ -480,30 +492,38 @@ func ScanStruct(src []interface{}, dest interface{}) error { if d.Kind() != reflect.Ptr || d.IsNil() { return errScanStructValue } + d = d.Elem() if d.Kind() != reflect.Struct { return errScanStructValue } - ss := structSpecForType(d.Type()) if len(src)%2 != 0 { return errors.New("redigo.ScanStruct: number of values not a multiple of 2") } + ss, err := structSpecForType(d.Type()) + if err != nil { + return fmt.Errorf("redigo.ScanStruct: %w", err) + } + for i := 0; i < len(src); i += 2 { s := src[i+1] if s == nil { continue } + name, ok := src[i].([]byte) if !ok { return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) } + fs := ss.fieldSpec(name) if fs == nil { continue } - if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { + + if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil { return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err) } } @@ -555,7 +575,11 @@ func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error return nil } - ss := structSpecForType(t) + ss, err := structSpecForType(t) + if err != nil { + return fmt.Errorf("redigo.ScanSlice: %w", err) + } + fss := ss.l if len(fieldNames) > 0 { fss = make([]*fieldSpec, len(fieldNames)) @@ -618,6 +642,7 @@ func (args Args) Add(value ...interface{}) Args { // for more information on the use of the 'redis' field tag. // // Other types are appended to args as is. +// panics if v includes a recursive anonymous struct. func (args Args) AddFlat(v interface{}) Args { rv := reflect.ValueOf(v) switch rv.Kind() { @@ -646,9 +671,17 @@ func (args Args) AddFlat(v interface{}) Args { } func flattenStruct(args Args, v reflect.Value) Args { - ss := structSpecForType(v.Type()) + ss, err := structSpecForType(v.Type()) + if err != nil { + panic(fmt.Errorf("redigo.AddFlat: %w", err)) + } + for _, fs := range ss.l { - fv := v.FieldByIndex(fs.index) + fv, err := fieldByIndexErr(v, fs.index) + if err != nil { + // Nil item ignore. + continue + } if fs.omitEmpty { var empty = false switch fv.Kind() { diff --git a/redis/scan_test.go b/redis/scan_test.go index d43c372e..53556ebb 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -233,12 +233,15 @@ type s1 struct { Sdp *durationScan `redis:"sdp"` } -var boolTrue = true +var ( + boolTrue = true + int5 = 5 +) var scanStructTests = []struct { - title string - reply []string - value interface{} + name string + reply []string + expected interface{} }{ {"basic", []string{ @@ -273,25 +276,54 @@ var scanStructTests = []struct { []string{}, &s1{}, }, + {"struct-anonymous-nil", + []string{"edi", "2"}, + &struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-early", + []string{"edi", "2"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-late", + []string{"edi", "2", "ed2i", "3", "edp2i", "4"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }, + }, } func TestScanStruct(t *testing.T) { for _, tt := range scanStructTests { + t.Run(tt.name, func(t *testing.T) { + reply := make([]interface{}, len(tt.reply)) + for i, v := range tt.reply { + reply[i] = []byte(v) + } - var reply []interface{} - for _, v := range tt.reply { - reply = append(reply, []byte(v)) - } - - value := reflect.New(reflect.ValueOf(tt.value).Type().Elem()) - - if err := redis.ScanStruct(reply, value.Interface()); err != nil { - t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err) - } - - if !reflect.DeepEqual(value.Interface(), tt.value) { - t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value) - } + value := reflect.New(reflect.ValueOf(tt.expected).Type().Elem()).Interface() + err := redis.ScanStruct(reply, value) + require.NoError(t, err) + require.Equal(t, tt.expected, value) + }) } } @@ -486,73 +518,182 @@ type Edp struct { EdpI int `redis:"edpi"` } +type Ed2 struct { + Ed2I int `redis:"ed2i"` + *Edp2 +} + +type Edp2 struct { + Edp2I int `redis:"edp2i"` + *Edp +} + +type Edpr1 struct { + Edpr1I int `redis:"edpr1i"` + *Edpr2 +} + +type Edpr2 struct { + Edpr2I int `redis:"edpr2i"` + *Edpr1 +} + var argsTests = []struct { title string - actual redis.Args + fn func() redis.Args expected redis.Args + panics bool }{ {"struct-ptr", - redis.Args{}.AddFlat(&struct { - I int `redis:"i"` - U uint `redis:"u"` - S string `redis:"s"` - P []byte `redis:"p"` - M map[string]string `redis:"m"` - Bt bool - Bf bool - PtrB *bool - PtrI *int - }{ - -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, - }), - redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true}, + func() redis.Args { + return redis.Args{}.AddFlat(&struct { + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + M map[string]string `redis:"m"` + Bt bool + Bf bool + PtrB *bool + PtrI *int + PtrI2 *int + }{ + -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5, + }) + }, + redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true, "PtrI2", 5}, + false, }, {"struct", - redis.Args{}.AddFlat(struct{ I int }{123}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ I int }{123}) + }, redis.Args{"I", 123}, + false, }, {"struct-with-RedisArg-direct", - redis.Args{}.AddFlat(struct{ T CustomTime }{CustomTime{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T CustomTime }{CustomTime{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"struct-with-RedisArg-direct-ptr", - redis.Args{}.AddFlat(struct{ T *CustomTime }{&CustomTime{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T *CustomTime }{&CustomTime{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"struct-with-RedisArg-ptr", - redis.Args{}.AddFlat(struct{ T *CustomTimePtr }{&CustomTimePtr{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T *CustomTimePtr }{&CustomTimePtr{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"slice", - redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2), + func() redis.Args { + return redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2) + }, redis.Args{1, "a", "b", "c", 2}, + false, }, {"struct-omitempty", - redis.Args{}.AddFlat(&struct { - Sdp *durationArg `redis:"Sdp,omitempty"` - }{ - nil, - }), + func() redis.Args { + return redis.Args{}.AddFlat(&struct { + Sdp *durationArg `redis:"Sdp,omitempty"` + }{ + nil, + }) + }, redis.Args{}, + false, }, {"struct-anonymous", - redis.Args{}.AddFlat(struct { - Ed - *Edp - }{ - Ed{EdI: 2}, - &Edp{EdpI: 3}, - }), + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed{EdI: 2}, + &Edp{EdpI: 3}, + }) + }, redis.Args{"edi", 2, "edpi", 3}, + false, + }, + {"struct-anonymous-nil", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }) + }, + redis.Args{"edi", 2}, + false, + }, + {"struct-anonymous-multi-nil-early", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }) + }, + redis.Args{"edi", 2}, + false, + }, + {"struct-anonymous-multi-nil-late", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }) + }, + redis.Args{"edi", 2, "ed2i", 3, "edp2i", 4}, + false, + }, + {"struct-recursive-ptr", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Edpr1 + }{ + Edpr1: Edpr1{ + Edpr1I: 1, + Edpr2: &Edpr2{ + Edpr2I: 2, + Edpr1: &Edpr1{ + Edpr1I: 10, + }, + }, + }, + }) + }, + redis.Args{}, + true, }, } func TestArgs(t *testing.T) { for _, tt := range argsTests { t.Run(tt.title, func(t *testing.T) { - if !reflect.DeepEqual(tt.actual, tt.expected) { - t.Fatalf("is %v, want %v", tt.actual, tt.expected) + if tt.panics { + require.Panics(t, func() { tt.fn() }) + return } + require.Equal(t, tt.expected, tt.fn()) }) } }