diff --git a/README.markdown b/README.markdown index d418cbb2..7b7234b3 100644 --- a/README.markdown +++ b/README.markdown @@ -1,7 +1,6 @@ Redigo ====== -[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo) [![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis) Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database. diff --git a/redis/reflect.go b/redis/reflect.go new file mode 100644 index 00000000..e135aed7 --- /dev/null +++ b/redis/reflect.go @@ -0,0 +1,48 @@ +package redis + +import ( + "reflect" + "runtime" +) + +// methodName returns the name of the calling method, +// assumed to be two stack frames above. +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +// mustBe panics if f's kind is not expected. +func mustBe(v reflect.Value, expected reflect.Kind) { + if v.Kind() != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()}) + } +} + +// fieldByIndexCreate returns the nested field corresponding +// to index creating elements that are nil when stepping through. +// It panics if v is not a struct. +func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v +} diff --git a/redis/reflect_go117.go b/redis/reflect_go117.go new file mode 100644 index 00000000..e985192a --- /dev/null +++ b/redis/reflect_go117.go @@ -0,0 +1,34 @@ +//go:build !go1.18 +// +build !go1.18 + +package redis + +import ( + "errors" + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + if len(index) == 1 { + return v.Field(index[0]), nil + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name()) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v, nil +} diff --git a/redis/reflect_go118.go b/redis/reflect_go118.go new file mode 100644 index 00000000..3356e76f --- /dev/null +++ b/redis/reflect_go118.go @@ -0,0 +1,16 @@ +//go:build go1.18 +// +build go1.18 + +package redis + +import ( + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + return v.FieldByIndexErr(index) +} diff --git a/redis/scan.go b/redis/scan.go index 379206ed..82121011 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -355,7 +355,13 @@ func (ss *structSpec) fieldSpec(name []byte) *fieldSpec { return ss.m[string(name)] } -func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) { +func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec, seen map[reflect.Type]struct{}) error { + if _, ok := seen[t]; ok { + // Protect against infinite recursion. + return fmt.Errorf("recursive struct definition for %v", t) + } + + seen[t] = struct{}{} LOOP: for i := 0; i < t.NumField(); i++ { f := t.Field(i) @@ -365,20 +371,21 @@ LOOP: case f.Anonymous: switch f.Type.Kind() { case reflect.Struct: - compileStructSpec(f.Type, depth, append(index, i), ss) + if err := compileStructSpec(f.Type, depth, append(index, i), ss, seen); err != nil { + return err + } case reflect.Ptr: - // TODO(steve): Protect against infinite recursion. if f.Type.Elem().Kind() == reflect.Struct { - compileStructSpec(f.Type.Elem(), depth, append(index, i), ss) + if err := compileStructSpec(f.Type.Elem(), depth, append(index, i), ss, seen); err != nil { + return err + } } } default: fs := &fieldSpec{name: f.Name} tag := f.Tag.Get("redis") - var ( - p string - ) + var p string first := true for len(tag) > 0 { i := strings.IndexByte(tag, ',') @@ -402,10 +409,12 @@ LOOP: } } } + d, found := depth[fs.name] if !found { d = 1 << 30 } + switch { case len(index) == d: // At same depth, remove from result. @@ -428,6 +437,8 @@ LOOP: } } } + + return nil } var ( @@ -435,26 +446,27 @@ var ( structSpecCache = make(map[reflect.Type]*structSpec) ) -func structSpecForType(t reflect.Type) *structSpec { - +func structSpecForType(t reflect.Type) (*structSpec, error) { structSpecMutex.RLock() ss, found := structSpecCache[t] structSpecMutex.RUnlock() if found { - return ss + return ss, nil } structSpecMutex.Lock() defer structSpecMutex.Unlock() ss, found = structSpecCache[t] if found { - return ss + return ss, nil } ss = &structSpec{m: make(map[string]*fieldSpec)} - compileStructSpec(t, make(map[string]int), nil, ss) + if err := compileStructSpec(t, make(map[string]int), nil, ss, make(map[reflect.Type]struct{})); err != nil { + return nil, fmt.Errorf("compile struct: %s: %w", t, err) + } structSpecCache[t] = ss - return ss + return ss, nil } var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct") @@ -480,30 +492,38 @@ func ScanStruct(src []interface{}, dest interface{}) error { if d.Kind() != reflect.Ptr || d.IsNil() { return errScanStructValue } + d = d.Elem() if d.Kind() != reflect.Struct { return errScanStructValue } - ss := structSpecForType(d.Type()) if len(src)%2 != 0 { return errors.New("redigo.ScanStruct: number of values not a multiple of 2") } + ss, err := structSpecForType(d.Type()) + if err != nil { + return fmt.Errorf("redigo.ScanStruct: %w", err) + } + for i := 0; i < len(src); i += 2 { s := src[i+1] if s == nil { continue } + name, ok := src[i].([]byte) if !ok { return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) } + fs := ss.fieldSpec(name) if fs == nil { continue } - if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { + + if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil { return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err) } } @@ -555,7 +575,11 @@ func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error return nil } - ss := structSpecForType(t) + ss, err := structSpecForType(t) + if err != nil { + return fmt.Errorf("redigo.ScanSlice: %w", err) + } + fss := ss.l if len(fieldNames) > 0 { fss = make([]*fieldSpec, len(fieldNames)) @@ -618,6 +642,7 @@ func (args Args) Add(value ...interface{}) Args { // for more information on the use of the 'redis' field tag. // // Other types are appended to args as is. +// panics if v includes a recursive anonymous struct. func (args Args) AddFlat(v interface{}) Args { rv := reflect.ValueOf(v) switch rv.Kind() { @@ -646,9 +671,17 @@ func (args Args) AddFlat(v interface{}) Args { } func flattenStruct(args Args, v reflect.Value) Args { - ss := structSpecForType(v.Type()) + ss, err := structSpecForType(v.Type()) + if err != nil { + panic(fmt.Errorf("redigo.AddFlat: %w", err)) + } + for _, fs := range ss.l { - fv := v.FieldByIndex(fs.index) + fv, err := fieldByIndexErr(v, fs.index) + if err != nil { + // Nil item ignore. + continue + } if fs.omitEmpty { var empty = false switch fv.Kind() { diff --git a/redis/scan_test.go b/redis/scan_test.go index d43c372e..53556ebb 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -233,12 +233,15 @@ type s1 struct { Sdp *durationScan `redis:"sdp"` } -var boolTrue = true +var ( + boolTrue = true + int5 = 5 +) var scanStructTests = []struct { - title string - reply []string - value interface{} + name string + reply []string + expected interface{} }{ {"basic", []string{ @@ -273,25 +276,54 @@ var scanStructTests = []struct { []string{}, &s1{}, }, + {"struct-anonymous-nil", + []string{"edi", "2"}, + &struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-early", + []string{"edi", "2"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-late", + []string{"edi", "2", "ed2i", "3", "edp2i", "4"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }, + }, } func TestScanStruct(t *testing.T) { for _, tt := range scanStructTests { + t.Run(tt.name, func(t *testing.T) { + reply := make([]interface{}, len(tt.reply)) + for i, v := range tt.reply { + reply[i] = []byte(v) + } - var reply []interface{} - for _, v := range tt.reply { - reply = append(reply, []byte(v)) - } - - value := reflect.New(reflect.ValueOf(tt.value).Type().Elem()) - - if err := redis.ScanStruct(reply, value.Interface()); err != nil { - t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err) - } - - if !reflect.DeepEqual(value.Interface(), tt.value) { - t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value) - } + value := reflect.New(reflect.ValueOf(tt.expected).Type().Elem()).Interface() + err := redis.ScanStruct(reply, value) + require.NoError(t, err) + require.Equal(t, tt.expected, value) + }) } } @@ -486,73 +518,182 @@ type Edp struct { EdpI int `redis:"edpi"` } +type Ed2 struct { + Ed2I int `redis:"ed2i"` + *Edp2 +} + +type Edp2 struct { + Edp2I int `redis:"edp2i"` + *Edp +} + +type Edpr1 struct { + Edpr1I int `redis:"edpr1i"` + *Edpr2 +} + +type Edpr2 struct { + Edpr2I int `redis:"edpr2i"` + *Edpr1 +} + var argsTests = []struct { title string - actual redis.Args + fn func() redis.Args expected redis.Args + panics bool }{ {"struct-ptr", - redis.Args{}.AddFlat(&struct { - I int `redis:"i"` - U uint `redis:"u"` - S string `redis:"s"` - P []byte `redis:"p"` - M map[string]string `redis:"m"` - Bt bool - Bf bool - PtrB *bool - PtrI *int - }{ - -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, - }), - redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true}, + func() redis.Args { + return redis.Args{}.AddFlat(&struct { + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + M map[string]string `redis:"m"` + Bt bool + Bf bool + PtrB *bool + PtrI *int + PtrI2 *int + }{ + -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5, + }) + }, + redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true, "PtrI2", 5}, + false, }, {"struct", - redis.Args{}.AddFlat(struct{ I int }{123}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ I int }{123}) + }, redis.Args{"I", 123}, + false, }, {"struct-with-RedisArg-direct", - redis.Args{}.AddFlat(struct{ T CustomTime }{CustomTime{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T CustomTime }{CustomTime{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"struct-with-RedisArg-direct-ptr", - redis.Args{}.AddFlat(struct{ T *CustomTime }{&CustomTime{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T *CustomTime }{&CustomTime{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"struct-with-RedisArg-ptr", - redis.Args{}.AddFlat(struct{ T *CustomTimePtr }{&CustomTimePtr{Time: time.Unix(1573231058, 0)}}), + func() redis.Args { + return redis.Args{}.AddFlat(struct{ T *CustomTimePtr }{&CustomTimePtr{Time: time.Unix(1573231058, 0)}}) + }, redis.Args{"T", int64(1573231058)}, + false, }, {"slice", - redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2), + func() redis.Args { + return redis.Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2) + }, redis.Args{1, "a", "b", "c", 2}, + false, }, {"struct-omitempty", - redis.Args{}.AddFlat(&struct { - Sdp *durationArg `redis:"Sdp,omitempty"` - }{ - nil, - }), + func() redis.Args { + return redis.Args{}.AddFlat(&struct { + Sdp *durationArg `redis:"Sdp,omitempty"` + }{ + nil, + }) + }, redis.Args{}, + false, }, {"struct-anonymous", - redis.Args{}.AddFlat(struct { - Ed - *Edp - }{ - Ed{EdI: 2}, - &Edp{EdpI: 3}, - }), + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed{EdI: 2}, + &Edp{EdpI: 3}, + }) + }, redis.Args{"edi", 2, "edpi", 3}, + false, + }, + {"struct-anonymous-nil", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }) + }, + redis.Args{"edi", 2}, + false, + }, + {"struct-anonymous-multi-nil-early", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }) + }, + redis.Args{"edi", 2}, + false, + }, + {"struct-anonymous-multi-nil-late", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }) + }, + redis.Args{"edi", 2, "ed2i", 3, "edp2i", 4}, + false, + }, + {"struct-recursive-ptr", + func() redis.Args { + return redis.Args{}.AddFlat(struct { + Edpr1 + }{ + Edpr1: Edpr1{ + Edpr1I: 1, + Edpr2: &Edpr2{ + Edpr2I: 2, + Edpr1: &Edpr1{ + Edpr1I: 10, + }, + }, + }, + }) + }, + redis.Args{}, + true, }, } func TestArgs(t *testing.T) { for _, tt := range argsTests { t.Run(tt.title, func(t *testing.T) { - if !reflect.DeepEqual(tt.actual, tt.expected) { - t.Fatalf("is %v, want %v", tt.actual, tt.expected) + if tt.panics { + require.Panics(t, func() { tt.fn() }) + return } + require.Equal(t, tt.expected, tt.fn()) }) } }