diff --git a/README.markdown b/README.markdown index d418cbb2..7b7234b3 100644 --- a/README.markdown +++ b/README.markdown @@ -1,7 +1,6 @@ Redigo ====== -[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo) [![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis) Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database. diff --git a/redis/reflect.go b/redis/reflect.go new file mode 100644 index 00000000..e135aed7 --- /dev/null +++ b/redis/reflect.go @@ -0,0 +1,48 @@ +package redis + +import ( + "reflect" + "runtime" +) + +// methodName returns the name of the calling method, +// assumed to be two stack frames above. +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +// mustBe panics if f's kind is not expected. +func mustBe(v reflect.Value, expected reflect.Kind) { + if v.Kind() != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()}) + } +} + +// fieldByIndexCreate returns the nested field corresponding +// to index creating elements that are nil when stepping through. +// It panics if v is not a struct. +func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v +} diff --git a/redis/reflect_go117.go b/redis/reflect_go117.go new file mode 100644 index 00000000..e143ae41 --- /dev/null +++ b/redis/reflect_go117.go @@ -0,0 +1,34 @@ +//go:build go1.17 && !go1.18 +// +build go1.17,!go1.18 + +package redis + +import ( + "errors" + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + if len(index) == 1 { + return v.Field(index[0]), nil + } + + mustBe(v, reflect.Struct) + for i, x := range index { + if i > 0 { + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name()) + } + v = v.Elem() + } + } + v = v.Field(x) + } + + return v, nil +} diff --git a/redis/reflect_go118.go b/redis/reflect_go118.go new file mode 100644 index 00000000..3356e76f --- /dev/null +++ b/redis/reflect_go118.go @@ -0,0 +1,16 @@ +//go:build go1.18 +// +build go1.18 + +package redis + +import ( + "reflect" +) + +// fieldByIndexErr returns the nested field corresponding to index. +// It returns an error if evaluation requires stepping through a nil +// pointer, but panics if it must step through a field that +// is not a struct. +func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) { + return v.FieldByIndexErr(index) +} diff --git a/redis/scan.go b/redis/scan.go index 379206ed..27809aea 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -376,9 +376,7 @@ LOOP: fs := &fieldSpec{name: f.Name} tag := f.Tag.Get("redis") - var ( - p string - ) + var p string first := true for len(tag) > 0 { i := strings.IndexByte(tag, ',') @@ -402,10 +400,12 @@ LOOP: } } } + d, found := depth[fs.name] if !found { d = 1 << 30 } + switch { case len(index) == d: // At same depth, remove from result. @@ -436,7 +436,6 @@ var ( ) func structSpecForType(t reflect.Type) *structSpec { - structSpecMutex.RLock() ss, found := structSpecCache[t] structSpecMutex.RUnlock() @@ -480,30 +479,34 @@ func ScanStruct(src []interface{}, dest interface{}) error { if d.Kind() != reflect.Ptr || d.IsNil() { return errScanStructValue } + d = d.Elem() if d.Kind() != reflect.Struct { return errScanStructValue } - ss := structSpecForType(d.Type()) if len(src)%2 != 0 { return errors.New("redigo.ScanStruct: number of values not a multiple of 2") } + ss := structSpecForType(d.Type()) for i := 0; i < len(src); i += 2 { s := src[i+1] if s == nil { continue } + name, ok := src[i].([]byte) if !ok { return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) } + fs := ss.fieldSpec(name) if fs == nil { continue } - if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil { + + if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil { return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err) } } @@ -648,7 +651,11 @@ func (args Args) AddFlat(v interface{}) Args { func flattenStruct(args Args, v reflect.Value) Args { ss := structSpecForType(v.Type()) for _, fs := range ss.l { - fv := v.FieldByIndex(fs.index) + fv, err := fieldByIndexErr(v, fs.index) + if err != nil { + // Nil item ignore. + continue + } if fs.omitEmpty { var empty = false switch fv.Kind() { diff --git a/redis/scan_test.go b/redis/scan_test.go index d43c372e..61541715 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -233,12 +233,15 @@ type s1 struct { Sdp *durationScan `redis:"sdp"` } -var boolTrue = true +var ( + boolTrue = true + int5 = 5 +) var scanStructTests = []struct { - title string - reply []string - value interface{} + name string + reply []string + expected interface{} }{ {"basic", []string{ @@ -273,25 +276,54 @@ var scanStructTests = []struct { []string{}, &s1{}, }, + {"struct-anonymous-nil", + []string{"edi", "2"}, + &struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-early", + []string{"edi", "2"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }, + }, + {"struct-anonymous-multi-nil-late", + []string{"edi", "2", "ed2i", "3", "edp2i", "4"}, + &struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }, + }, } func TestScanStruct(t *testing.T) { for _, tt := range scanStructTests { + t.Run(tt.name, func(t *testing.T) { + reply := make([]interface{}, len(tt.reply)) + for i, v := range tt.reply { + reply[i] = []byte(v) + } - var reply []interface{} - for _, v := range tt.reply { - reply = append(reply, []byte(v)) - } - - value := reflect.New(reflect.ValueOf(tt.value).Type().Elem()) - - if err := redis.ScanStruct(reply, value.Interface()); err != nil { - t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err) - } - - if !reflect.DeepEqual(value.Interface(), tt.value) { - t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value) - } + value := reflect.New(reflect.ValueOf(tt.expected).Type().Elem()).Interface() + err := redis.ScanStruct(reply, value) + require.NoError(t, err) + require.Equal(t, tt.expected, value) + }) } } @@ -486,6 +518,16 @@ type Edp struct { EdpI int `redis:"edpi"` } +type Ed2 struct { + Ed2I int `redis:"ed2i"` + *Edp2 +} + +type Edp2 struct { + Edp2I int `redis:"edp2i"` + *Edp +} + var argsTests = []struct { title string actual redis.Args @@ -493,19 +535,20 @@ var argsTests = []struct { }{ {"struct-ptr", redis.Args{}.AddFlat(&struct { - I int `redis:"i"` - U uint `redis:"u"` - S string `redis:"s"` - P []byte `redis:"p"` - M map[string]string `redis:"m"` - Bt bool - Bf bool - PtrB *bool - PtrI *int + I int `redis:"i"` + U uint `redis:"u"` + S string `redis:"s"` + P []byte `redis:"p"` + M map[string]string `redis:"m"` + Bt bool + Bf bool + PtrB *bool + PtrI *int + PtrI2 *int }{ - -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, + -1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5, }), - redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true}, + redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true, "PtrI2", 5}, }, {"struct", redis.Args{}.AddFlat(struct{ I int }{123}), @@ -545,14 +588,45 @@ var argsTests = []struct { }), redis.Args{"edi", 2, "edpi", 3}, }, + {"struct-anonymous-nil", + redis.Args{}.AddFlat(struct { + Ed + *Edp + }{ + Ed: Ed{EdI: 2}, + }), + redis.Args{"edi", 2}, + }, + {"struct-anonymous-multi-nil-early", + redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + }), + redis.Args{"edi", 2}, + }, + {"struct-anonymous-multi-nil-late", + redis.Args{}.AddFlat(struct { + Ed + *Ed2 + }{ + Ed: Ed{EdI: 2}, + Ed2: &Ed2{ + Ed2I: 3, + Edp2: &Edp2{ + Edp2I: 4, + }, + }, + }), + redis.Args{"edi", 2, "ed2i", 3, "edp2i", 4}, + }, } func TestArgs(t *testing.T) { for _, tt := range argsTests { t.Run(tt.title, func(t *testing.T) { - if !reflect.DeepEqual(tt.actual, tt.expected) { - t.Fatalf("is %v, want %v", tt.actual, tt.expected) - } + require.Equal(t, tt.expected, tt.actual) }) } }