Skip to content

Commit

Permalink
fix: Anonymous field pointers (#622)
Browse files Browse the repository at this point in the history
Fix panic using ArgsFlat or ScanStruct on structs with nil Anonymous
field pointers.

Catch the anonymous struct recursion and prevent it. In the case of
ScanStruct an error will be returned, in the case of ArgsFlat it will
panic with a nice error.
  • Loading branch information
stevenh committed Jul 6, 2022
1 parent d3b4cc3 commit f1e923c
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 71 deletions.
1 change: 0 additions & 1 deletion README.markdown
@@ -1,7 +1,6 @@
Redigo
======

[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo)
[![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis)

Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database.
Expand Down
48 changes: 48 additions & 0 deletions redis/reflect.go
@@ -0,0 +1,48 @@
package redis

import (
"reflect"
"runtime"
)

// methodName returns the name of the calling method,
// assumed to be two stack frames above.
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}

// mustBe panics if f's kind is not expected.
func mustBe(v reflect.Value, expected reflect.Kind) {
if v.Kind() != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()})
}
}

// fieldByIndexCreate returns the nested field corresponding
// to index creating elements that are nil when stepping through.
// It panics if v is not a struct.
func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}

mustBe(v, reflect.Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}
v = v.Field(x)
}

return v
}
34 changes: 34 additions & 0 deletions redis/reflect_go117.go
@@ -0,0 +1,34 @@
//go:build !go1.18
// +build !go1.18

package redis

import (
"errors"
"reflect"
)

// fieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) {
if len(index) == 1 {
return v.Field(index[0]), nil
}

mustBe(v, reflect.Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name())
}
v = v.Elem()
}
}
v = v.Field(x)
}

return v, nil
}
16 changes: 16 additions & 0 deletions redis/reflect_go118.go
@@ -0,0 +1,16 @@
//go:build go1.18
// +build go1.18

package redis

import (
"reflect"
)

// fieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) {
return v.FieldByIndexErr(index)
}
69 changes: 51 additions & 18 deletions redis/scan.go
Expand Up @@ -355,7 +355,13 @@ func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
return ss.m[string(name)]
}

func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec, seen map[reflect.Type]struct{}) error {
if _, ok := seen[t]; ok {
// Protect against infinite recursion.
return fmt.Errorf("recursive struct definition for %v", t)
}

seen[t] = struct{}{}
LOOP:
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
Expand All @@ -365,20 +371,21 @@ LOOP:
case f.Anonymous:
switch f.Type.Kind() {
case reflect.Struct:
compileStructSpec(f.Type, depth, append(index, i), ss)
if err := compileStructSpec(f.Type, depth, append(index, i), ss, seen); err != nil {
return err
}
case reflect.Ptr:
// TODO(steve): Protect against infinite recursion.
if f.Type.Elem().Kind() == reflect.Struct {
compileStructSpec(f.Type.Elem(), depth, append(index, i), ss)
if err := compileStructSpec(f.Type.Elem(), depth, append(index, i), ss, seen); err != nil {
return err
}
}
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")

var (
p string
)
var p string
first := true
for len(tag) > 0 {
i := strings.IndexByte(tag, ',')
Expand All @@ -402,10 +409,12 @@ LOOP:
}
}
}

d, found := depth[fs.name]
if !found {
d = 1 << 30
}

switch {
case len(index) == d:
// At same depth, remove from result.
Expand All @@ -428,33 +437,36 @@ LOOP:
}
}
}

return nil
}

var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
)

func structSpecForType(t reflect.Type) *structSpec {

func structSpecForType(t reflect.Type) (*structSpec, error) {
structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
if found {
return ss
return ss, nil
}

structSpecMutex.Lock()
defer structSpecMutex.Unlock()
ss, found = structSpecCache[t]
if found {
return ss
return ss, nil
}

ss = &structSpec{m: make(map[string]*fieldSpec)}
compileStructSpec(t, make(map[string]int), nil, ss)
if err := compileStructSpec(t, make(map[string]int), nil, ss, make(map[reflect.Type]struct{})); err != nil {
return nil, fmt.Errorf("compile struct: %s: %w", t, err)
}
structSpecCache[t] = ss
return ss
return ss, nil
}

var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct")
Expand All @@ -480,30 +492,38 @@ func ScanStruct(src []interface{}, dest interface{}) error {
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanStructValue
}

d = d.Elem()
if d.Kind() != reflect.Struct {
return errScanStructValue
}
ss := structSpecForType(d.Type())

if len(src)%2 != 0 {
return errors.New("redigo.ScanStruct: number of values not a multiple of 2")
}

ss, err := structSpecForType(d.Type())
if err != nil {
return fmt.Errorf("redigo.ScanStruct: %w", err)
}

for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}

name, ok := src[i].([]byte)
if !ok {
return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i)
}

fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {

if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil {
return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err)
}
}
Expand Down Expand Up @@ -555,7 +575,11 @@ func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error
return nil
}

ss := structSpecForType(t)
ss, err := structSpecForType(t)
if err != nil {
return fmt.Errorf("redigo.ScanSlice: %w", err)
}

fss := ss.l
if len(fieldNames) > 0 {
fss = make([]*fieldSpec, len(fieldNames))
Expand Down Expand Up @@ -618,6 +642,7 @@ func (args Args) Add(value ...interface{}) Args {
// for more information on the use of the 'redis' field tag.
//
// Other types are appended to args as is.
// panics if v includes a recursive anonymous struct.
func (args Args) AddFlat(v interface{}) Args {
rv := reflect.ValueOf(v)
switch rv.Kind() {
Expand Down Expand Up @@ -646,9 +671,17 @@ func (args Args) AddFlat(v interface{}) Args {
}

func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
ss, err := structSpecForType(v.Type())
if err != nil {
panic(fmt.Errorf("redigo.AddFlat: %w", err))
}

for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
fv, err := fieldByIndexErr(v, fs.index)
if err != nil {
// Nil item ignore.
continue
}
if fs.omitEmpty {
var empty = false
switch fv.Kind() {
Expand Down

0 comments on commit f1e923c

Please sign in to comment.