Skip to content

Commit

Permalink
fix: Anonymous field pointers
Browse files Browse the repository at this point in the history
Fix panic using ArgsFlat or ScanStruct on structs with nil Anonymous
field pointers.

Also:
* Remove unused travis-ci link from README.

Fixes: #621
  • Loading branch information
stevenh committed Jul 1, 2022
1 parent 5b789c6 commit e1fef94
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 40 deletions.
1 change: 0 additions & 1 deletion README.markdown
@@ -1,7 +1,6 @@
Redigo
======

[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo)
[![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis)

Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database.
Expand Down
48 changes: 48 additions & 0 deletions redis/reflect.go
@@ -0,0 +1,48 @@
package redis

import (
"reflect"
"runtime"
)

// methodName returns the name of the calling method,
// assumed to be two stack frames above.
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}

// mustBe panics if f's kind is not expected.
func mustBe(v reflect.Value, expected reflect.Kind) {
if v.Kind() != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()})
}
}

// fieldByIndexCreate returns the nested field corresponding
// to index creating elements that are nil when stepping through.
// It panics if v is not a struct.
func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}

mustBe(v, reflect.Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}
v = v.Field(x)
}

return v
}
34 changes: 34 additions & 0 deletions redis/reflect_go117.go
@@ -0,0 +1,34 @@
//go:build go1.17 && !go1.18
// +build go1.17,!go1.18

package redis

import (
"errors"
"reflect"
)

// fieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) {
if len(index) == 1 {
return v.Field(index[0]), nil
}

mustBe(v, reflect.Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name())
}
v = v.Elem()
}
}
v = v.Field(x)
}

return v, nil
}
16 changes: 16 additions & 0 deletions redis/reflect_go118.go
@@ -0,0 +1,16 @@
//go:build go1.18
// +build go1.18

package redis

import (
"reflect"
)

// fieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) {
return v.FieldByIndexErr(index)
}
21 changes: 14 additions & 7 deletions redis/scan.go
Expand Up @@ -376,9 +376,7 @@ LOOP:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")

var (
p string
)
var p string
first := true
for len(tag) > 0 {
i := strings.IndexByte(tag, ',')
Expand All @@ -402,10 +400,12 @@ LOOP:
}
}
}

d, found := depth[fs.name]
if !found {
d = 1 << 30
}

switch {
case len(index) == d:
// At same depth, remove from result.
Expand Down Expand Up @@ -436,7 +436,6 @@ var (
)

func structSpecForType(t reflect.Type) *structSpec {

structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
Expand Down Expand Up @@ -480,30 +479,34 @@ func ScanStruct(src []interface{}, dest interface{}) error {
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanStructValue
}

d = d.Elem()
if d.Kind() != reflect.Struct {
return errScanStructValue
}
ss := structSpecForType(d.Type())

if len(src)%2 != 0 {
return errors.New("redigo.ScanStruct: number of values not a multiple of 2")
}

ss := structSpecForType(d.Type())
for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}

name, ok := src[i].([]byte)
if !ok {
return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i)
}

fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {

if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil {
return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err)
}
}
Expand Down Expand Up @@ -648,7 +651,11 @@ func (args Args) AddFlat(v interface{}) Args {
func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
fv, err := fieldByIndexErr(v, fs.index)
if err != nil {
// Nil item ignore.
continue
}
if fs.omitEmpty {
var empty = false
switch fv.Kind() {
Expand Down
138 changes: 106 additions & 32 deletions redis/scan_test.go
Expand Up @@ -233,12 +233,15 @@ type s1 struct {
Sdp *durationScan `redis:"sdp"`
}

var boolTrue = true
var (
boolTrue = true
int5 = 5
)

var scanStructTests = []struct {
title string
reply []string
value interface{}
name string
reply []string
expected interface{}
}{
{"basic",
[]string{
Expand Down Expand Up @@ -273,25 +276,54 @@ var scanStructTests = []struct {
[]string{},
&s1{},
},
{"struct-anonymous-nil",
[]string{"edi", "2"},
&struct {
Ed
*Edp
}{
Ed: Ed{EdI: 2},
},
},
{"struct-anonymous-multi-nil-early",
[]string{"edi", "2"},
&struct {
Ed
*Ed2
}{
Ed: Ed{EdI: 2},
},
},
{"struct-anonymous-multi-nil-late",
[]string{"edi", "2", "ed2i", "3", "edp2i", "4"},
&struct {
Ed
*Ed2
}{
Ed: Ed{EdI: 2},
Ed2: &Ed2{
Ed2I: 3,
Edp2: &Edp2{
Edp2I: 4,
},
},
},
},
}

func TestScanStruct(t *testing.T) {
for _, tt := range scanStructTests {
t.Run(tt.name, func(t *testing.T) {
reply := make([]interface{}, len(tt.reply))
for i, v := range tt.reply {
reply[i] = []byte(v)
}

var reply []interface{}
for _, v := range tt.reply {
reply = append(reply, []byte(v))
}

value := reflect.New(reflect.ValueOf(tt.value).Type().Elem())

if err := redis.ScanStruct(reply, value.Interface()); err != nil {
t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err)
}

if !reflect.DeepEqual(value.Interface(), tt.value) {
t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value)
}
value := reflect.New(reflect.ValueOf(tt.expected).Type().Elem()).Interface()
err := redis.ScanStruct(reply, value)
require.NoError(t, err)
require.Equal(t, tt.expected, value)
})
}
}

Expand Down Expand Up @@ -486,26 +518,37 @@ type Edp struct {
EdpI int `redis:"edpi"`
}

type Ed2 struct {
Ed2I int `redis:"ed2i"`
*Edp2
}

type Edp2 struct {
Edp2I int `redis:"edp2i"`
*Edp
}

var argsTests = []struct {
title string
actual redis.Args
expected redis.Args
}{
{"struct-ptr",
redis.Args{}.AddFlat(&struct {
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
M map[string]string `redis:"m"`
Bt bool
Bf bool
PtrB *bool
PtrI *int
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
M map[string]string `redis:"m"`
Bt bool
Bf bool
PtrB *bool
PtrI *int
PtrI2 *int
}{
-1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil,
-1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false, &boolTrue, nil, &int5,
}),
redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true},
redis.Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false, "PtrB", true, "PtrI2", 5},
},
{"struct",
redis.Args{}.AddFlat(struct{ I int }{123}),
Expand Down Expand Up @@ -545,14 +588,45 @@ var argsTests = []struct {
}),
redis.Args{"edi", 2, "edpi", 3},
},
{"struct-anonymous-nil",
redis.Args{}.AddFlat(struct {
Ed
*Edp
}{
Ed: Ed{EdI: 2},
}),
redis.Args{"edi", 2},
},
{"struct-anonymous-multi-nil-early",
redis.Args{}.AddFlat(struct {
Ed
*Ed2
}{
Ed: Ed{EdI: 2},
}),
redis.Args{"edi", 2},
},
{"struct-anonymous-multi-nil-late",
redis.Args{}.AddFlat(struct {
Ed
*Ed2
}{
Ed: Ed{EdI: 2},
Ed2: &Ed2{
Ed2I: 3,
Edp2: &Edp2{
Edp2I: 4,
},
},
}),
redis.Args{"edi", 2, "ed2i", 3, "edp2i", 4},
},
}

func TestArgs(t *testing.T) {
for _, tt := range argsTests {
t.Run(tt.title, func(t *testing.T) {
if !reflect.DeepEqual(tt.actual, tt.expected) {
t.Fatalf("is %v, want %v", tt.actual, tt.expected)
}
require.Equal(t, tt.expected, tt.actual)
})
}
}
Expand Down

0 comments on commit e1fef94

Please sign in to comment.