From 7f26272eae515e0c5f8cf275e74a1e29bb59c903 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Fri, 1 Jul 2022 15:01:24 +0100 Subject: [PATCH 1/3] fix: Anonymous field pointers Fix panic using ArgsFlat or ScanStruct on structs with nil Anonymous field pointers. Also: * Remove unused travis-ci link from README. Fixes: #621 --- README.markdown | 1 - redis/reflect.go | 48 ++++++++++++++ redis/reflect_go117.go | 34 ++++++++++ redis/reflect_go118.go | 16 +++++ redis/scan.go | 21 ++++--- redis/scan_test.go | 138 +++++++++++++++++++++++++++++++---------- 6 files changed, 218 insertions(+), 40 deletions(-) create mode 100644 redis/reflect.go create mode 100644 redis/reflect_go117.go create mode 100644 redis/reflect_go118.go diff --git a/README.markdown b/README.markdown index d418cbb2..7b7234b3 100644 --- a/README.markdown +++ b/README.markdown @@ -1,7 +1,6 @@ Redigo ====== -[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo) [![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis) Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database. diff --git a/redis/reflect.go b/redis/reflect.go new file mode 100644 index 00000000..e135aed7 --- /dev/null +++ b/redis/reflect.go @@ -0,0 +1,48 @@ +package redis + +import ( + "reflect" + "runtime" +) + +// methodName returns the name of the calling method, +// assumed to be two stack frames above. +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +// mustBe panics if f's kind is not expected. +func mustBe(v reflect.Value, expected reflect.Kind) { + if v.Kind() != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()}) + } +} + +// fieldByIndexCreate returns the nested field corresponding +// to index creating elements that are nil when stepping through. +// It panics if v is not a struct. +func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v +} diff --git a/redis/reflect_go117.go b/redis/reflect_go117.go new file mode 100644 index 00000000..e143ae41 --- /dev/null +++ b/redis/reflect_go117.go @@ -0,0 +1,34 @@ +//go:build go1.17 && !go1.18 +// +build go1.17,!go1.18 + +package redis + +import ( + "errors" + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + if len(index) == 1 { + return v.Field(index[0]), nil + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name()) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v, nil +} diff --git a/redis/reflect_go118.go b/redis/reflect_go118.go new file mode 100644 index 00000000..3356e76f --- /dev/null +++ b/redis/reflect_go118.go @@ -0,0 +1,16 @@ +//go:build go1.18 +// +build go1.18 + +package redis + +import ( + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + return v.FieldByIndexErr(index) +} diff --git a/redis/scan.go b/redis/scan.go index 379206ed..27809aea 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -376,9 +376,7 @@ LOOP: fs := &fieldSpec{name: f.Name} tag := f.Tag.Get("redis") - var ( - p string - ) + var p string first := true for len(tag) > 0 { i := strings.IndexByte(tag, ',') @@ -402,10 +400,12 @@ LOOP: } } } + d, found := depth[fs.name] if !found { d = 1 << 30 } + switch { case len(index) == d: // At same depth, remove from result. @@ -436,7 +436,6 @@ var ( ) func structSpecForType(t reflect.Type) *structSpec { - structSpecMutex.RLock() ss, found := structSpecCache[t] structSpecMutex.RUnlock() @@ -480,30 +479,34 @@ func ScanStruct(src []interface{}, dest interface{}) error { if d.Kind() != reflect.Ptr || d.IsNil() { return errScanStructValue } + d = d.Elem() if d.Kind() != reflect.Struct { return errScanStructValue } - ss := structSpecForType(d.Type()) if len(src)%2 != 0 { return errors.New("redigo.ScanStruct: number of values not a multiple of 2") } + ss := structSpecForType(d.Type()) for i := 0; i < len(src); i += 2 { s := src[i+1] if s == nil { continue } + name, ok := src[i].([]byte) if !ok { return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) } + fs := ss.fieldSpec(name) if fs == nil { continue } - if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { + + if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil { return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err) } } @@ -648,7 +651,11 @@ func (args Args) AddFlat(v interface{}) Args { func flattenStruct(args Args, v reflect.Value) Args { ss := structSpecForType(v.Type()) for _, fs := range ss.l { - fv := v.FieldByIndex(fs.index) + fv, err := fieldByIndexErr(v, fs.index) + if err != nil { + // Nil item ignore. + continue + } if fs.omitEmpty { var empty = false switch fv.Kind() { diff --git a/redis/scan_test.go b/redis/scan_test.go index d43c372e..61541715 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -233,12 +233,15 @@ type s1 struct { Sdp *durationScan `redis:"sdp"` } -var boolTrue = true +var ( + boolTrue = true + int5 = 5 +) var scanStructTests = []struct { - title string - reply []string - value interface{} + name string + reply []string + expected interface{} }{ {"basic", []string{ @@ -273,25 +276,54 @@ var scanStructTests = []struct { []string{}, &s1{}, }, + {"struct-anonymous-nil", + []string{"edi", "2"}, + &struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-early", + []string{"edi", "2"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-late", + []string{"edi", "2", "ed2i", "3", "edp2i", "4"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }, + }, } func TestScanStruct(t *testing.T) { for _, tt := range scanStructTests { + t.Run(tt.name, func(t *testing.T) { + reply := make([]interface{}, len(tt.reply)) + for i, v := range tt.reply { + reply[i] = []byte(v) + } - var reply []interface{} - for _, v := range tt.reply { - reply = append(reply, []byte(v)) - } - - value := reflect.New(reflect.ValueOf(tt.value).Type().Elem()) - - if err := redis.ScanStruct(reply, value.Interface()); err != nil { - t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err) - } - - if !reflect.DeepEqual(value.Interface(), tt.value) { - t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value) - } + value := reflect.New(reflect.ValueOf(tt.expected).Type().Elem()).Interface() + err := redis.ScanStruct(reply, value) + require.NoError(t, err) + require.Equal(t, tt.expected, value) + }) } } @@ -486,6 +518,16 @@ type Edp struct { EdpI int `redis:"edpi"` } +type Ed2 struct { + Ed2I int `redis:"ed2i"` + *Edp2 +} + +type Edp2 struct { + Edp2I int `redis:"edp2i"` + *Edp +} + var argsTests = []struct { title string actual redis.Args @@ -493,19 +535,20 @@ var argsTests = []struct { }{ {"struct-ptr", redis.Args{}.AddFlat(&struct { - I int `redis:"i"` - U uint `redis:"u"` - S string `redis:"s"` - P []byte `redis:"p"` - M map[string]string `redis:"m"` - Bt bool - Bf bool - PtrB *bool - PtrI *int + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + M map[string]string `redis:"m"` + Bt bool + Bf bool + PtrB *bool + PtrI *int + PtrI2 *int }{ - -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, + -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5, }), - redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true}, + redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true, "PtrI2", 5}, }, {"struct", redis.Args{}.AddFlat(struct{ I int }{123}), @@ -545,14 +588,45 @@ var argsTests = []struct { }), redis.Args{"edi", 2, "edpi", 3}, }, + {"struct-anonymous-nil", + redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }), + redis.Args{"edi", 2}, + }, + {"struct-anonymous-multi-nil-early", + redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }), + redis.Args{"edi", 2}, + }, + {"struct-anonymous-multi-nil-late", + redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }), + redis.Args{"edi", 2, "ed2i", 3, "edp2i", 4}, + }, } func TestArgs(t *testing.T) { for _, tt := range argsTests { t.Run(tt.title, func(t *testing.T) { - if !reflect.DeepEqual(tt.actual, tt.expected) { - t.Fatalf("is %v, want %v", tt.actual, tt.expected) - } + require.Equal(t, tt.expected, tt.actual) }) } } From fee9988a7149f6795118086b47b4cd5c59d71e40 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Mon, 4 Jul 2022 14:22:25 +0100 Subject: [PATCH 2/3] fix: builds on go before 1.18 Remove the incorrect go1.17 tag so that builds on go versions below v1.18 work as expected. --- redis/reflect_go117.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/reflect_go117.go b/redis/reflect_go117.go index e143ae41..e985192a 100644 --- a/redis/reflect_go117.go +++ b/redis/reflect_go117.go @@ -1,5 +1,5 @@ -//go:build go1.17 && !go1.18 -// +build go1.17,!go1.18 +//go:build !go1.18 +// +build !go1.18 package redis From 58130415db19a47ede70103837dd58c8578fac0c Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Tue, 5 Jul 2022 21:30:31 +0100 Subject: [PATCH 3/3] fix: anonymous struct recursion Catch the case where anonymous struct recursion and prevent it. In the case of ScanStruct an error will be returned, in the case of ArgsFlat it will panic with a nice error. --- redis/scan.go | 50 ++++++++++--- redis/scan_test.go | 179 +++++++++++++++++++++++++++++++-------------- 2 files changed, 161 insertions(+), 68 deletions(-) diff --git a/redis/scan.go b/redis/scan.go index 27809aea..82121011 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -355,7 +355,13 @@ func (ss *structSpec) fieldSpec(name []byte) *fieldSpec { return ss.m[string(name)] } -func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) { +func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec, seen map[reflect.Type]struct{}) error { + if _, ok := seen[t]; ok { + // Protect against infinite recursion. + return fmt.Errorf("recursive struct definition for %v", t) + } + + seen[t] = struct{}{} LOOP: for i := 0; i < t.NumField(); i++ { f := t.Field(i) @@ -365,11 +371,14 @@ LOOP: case f.Anonymous: switch f.Type.Kind() { case reflect.Struct: - compileStructSpec(f.Type, depth, append(index, i), ss) + if err := compileStructSpec(f.Type, depth, append(index, i), ss, seen); err != nil { + return err + } case reflect.Ptr: - // TODO(steve): Protect against infinite recursion. if f.Type.Elem().Kind() == reflect.Struct { - compileStructSpec(f.Type.Elem(), depth, append(index, i), ss) + if err := compileStructSpec(f.Type.Elem(), depth, append(index, i), ss, seen); err != nil { + return err + } } } default: @@ -428,6 +437,8 @@ LOOP: } } } + + return nil } var ( @@ -435,25 +446,27 @@ var ( structSpecCache = make(map[reflect.Type]*structSpec) ) -func structSpecForType(t reflect.Type) *structSpec { +func structSpecForType(t reflect.Type) (*structSpec, error) { structSpecMutex.RLock() ss, found := structSpecCache[t] structSpecMutex.RUnlock() if found { - return ss + return ss, nil } structSpecMutex.Lock() defer structSpecMutex.Unlock() ss, found = structSpecCache[t] if found { - return ss + return ss, nil } ss = &structSpec{m: make(map[string]*fieldSpec)} - compileStructSpec(t, make(map[string]int), nil, ss) + if err := compileStructSpec(t, make(map[string]int), nil, ss, make(map[reflect.Type]struct{})); err != nil { + return nil, fmt.Errorf("compile struct: %s: %w", t, err) + } structSpecCache[t] = ss - return ss + return ss, nil } var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct") @@ -489,7 +502,11 @@ func ScanStruct(src []interface{}, dest interface{}) error { return errors.New("redigo.ScanStruct: number of values not a multiple of 2") } - ss := structSpecForType(d.Type()) + ss, err := structSpecForType(d.Type()) + if err != nil { + return fmt.Errorf("redigo.ScanStruct: %w", err) + } + for i := 0; i < len(src); i += 2 { s := src[i+1] if s == nil { @@ -558,7 +575,11 @@ func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error return nil } - ss := structSpecForType(t) + ss, err := structSpecForType(t) + if err != nil { + return fmt.Errorf("redigo.ScanSlice: %w", err) + } + fss := ss.l if len(fieldNames) > 0 { fss = make([]*fieldSpec, len(fieldNames)) @@ -621,6 +642,7 @@ func (args Args) Add(value ...interface{}) Args { // for more information on the use of the 'redis' field tag. // // Other types are appended to args as is. +// panics if v includes a recursive anonymous struct. func (args Args) AddFlat(v interface{}) Args { rv := reflect.ValueOf(v) switch rv.Kind() { @@ -649,7 +671,11 @@ func (args Args) AddFlat(v interface{}) Args { } func flattenStruct(args Args, v reflect.Value) Args { - ss := structSpecForType(v.Type()) + ss, err := structSpecForType(v.Type()) + if err != nil { + panic(fmt.Errorf("redigo.AddFlat: %w", err)) + } + for _, fs := range ss.l { fv, err := fieldByIndexErr(v, fs.index) if err != nil { diff --git a/redis/scan_test.go b/redis/scan_test.go index 61541715..53556ebb 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -528,105 +528,172 @@ type Edp2 struct { *Edp } +type Edpr1 struct { + Edpr1I int `redis:"edpr1i"` + *Edpr2 +} + +type Edpr2 struct { + Edpr2I int `redis:"edpr2i"` + *Edpr1 +} + var argsTests = []struct { title string - actual redis.Args + fn func() redis.Args expected redis.Args + panics bool }{ {"struct-ptr", - redis.Args{}.AddFlat(&struct { - I int `redis:"i"` - U uint `redis:"u"` - S string `redis:"s"` - P []byte `redis:"p"` - M map[string]string `redis:"m"` - Bt bool - Bf bool - PtrB *bool - PtrI *int - PtrI2 *int - }{ - -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5, - }), + func() redis.Args { + return redis.Args{}.AddFlat(&struct { + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + M map[string]string `redis:"m"` + Bt bool + Bf bool + PtrB *bool + PtrI *int + PtrI2 *int + }{ + -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5, + }) + }, redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true, "PtrI2", 5}, + false, }, {"struct", - redis.Args{}.AddFlat(struct{ I int }{123}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ I int }{123}) + }, redis.Args{"I", 123}, + false, }, {"struct-with-RedisArg-direct", - redis.Args{}.AddFlat(struct{ T CustomTime }{CustomTime{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T CustomTime }{CustomTime{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"struct-with-RedisArg-direct-ptr", - redis.Args{}.AddFlat(struct{ T *CustomTime }{&CustomTime{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T *CustomTime }{&CustomTime{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"struct-with-RedisArg-ptr", - redis.Args{}.AddFlat(struct{ T *CustomTimePtr }{&CustomTimePtr{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T *CustomTimePtr }{&CustomTimePtr{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"slice", - redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2), + func() redis.Args { + return redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2) + }, redis.Args{1, "a", "b", "c", 2}, + false, }, {"struct-omitempty", - redis.Args{}.AddFlat(&struct { - Sdp *durationArg `redis:"Sdp,omitempty"` - }{ - nil, - }), + func() redis.Args { + return redis.Args{}.AddFlat(&struct { + Sdp *durationArg `redis:"Sdp,omitempty"` + }{ + nil, + }) + }, redis.Args{}, + false, }, {"struct-anonymous", - redis.Args{}.AddFlat(struct { - Ed - *Edp - }{ - Ed{EdI: 2}, - &Edp{EdpI: 3}, - }), + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed{EdI: 2}, + &Edp{EdpI: 3}, + }) + }, redis.Args{"edi", 2, "edpi", 3}, + false, }, {"struct-anonymous-nil", - redis.Args{}.AddFlat(struct { - Ed - *Edp - }{ - Ed: Ed{EdI: 2}, - }), + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }) + }, redis.Args{"edi", 2}, + false, }, {"struct-anonymous-multi-nil-early", - redis.Args{}.AddFlat(struct { - Ed - *Ed2 - }{ - Ed: Ed{EdI: 2}, - }), + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }) + }, redis.Args{"edi", 2}, + false, }, {"struct-anonymous-multi-nil-late", - redis.Args{}.AddFlat(struct { - Ed - *Ed2 - }{ - Ed: Ed{EdI: 2}, - Ed2: &Ed2{ - Ed2I: 3, - Edp2: &Edp2{ - Edp2I: 4, + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, }, - }, - }), + }) + }, redis.Args{"edi", 2, "ed2i", 3, "edp2i", 4}, + false, + }, + {"struct-recursive-ptr", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Edpr1 + }{ + Edpr1: Edpr1{ + Edpr1I: 1, + Edpr2: &Edpr2{ + Edpr2I: 2, + Edpr1: &Edpr1{ + Edpr1I: 10, + }, + }, + }, + }) + }, + redis.Args{}, + true, }, } func TestArgs(t *testing.T) { for _, tt := range argsTests { t.Run(tt.title, func(t *testing.T) { - require.Equal(t, tt.expected, tt.actual) + if tt.panics { + require.Panics(t, func() { tt.fn() }) + return + } + require.Equal(t, tt.expected, tt.fn()) }) } }