Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Anonymous field pointers #622

Merged
merged 3 commits into from Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion README.markdown
@@ -1,7 +1,6 @@
Redigo
======

[![Build Status](https://travis-ci.org/gomodule/redigo.svg?branch=master)](https://travis-ci.org/gomodule/redigo)
[![GoDoc](https://godoc.org/github.com/gomodule/redigo/redis?status.svg)](https://pkg.go.dev/github.com/gomodule/redigo/redis)

Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database.
Expand Down
48 changes: 48 additions & 0 deletions redis/reflect.go
@@ -0,0 +1,48 @@
package redis

import (
"reflect"
"runtime"
)

// methodName returns the name of the calling method,
// assumed to be two stack frames above.
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}

// mustBe panics if f's kind is not expected.
func mustBe(v reflect.Value, expected reflect.Kind) {
if v.Kind() != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: v.Kind()})
}
}

// fieldByIndexCreate returns the nested field corresponding
// to index creating elements that are nil when stepping through.
// It panics if v is not a struct.
func fieldByIndexCreate(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}

mustBe(v, reflect.Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}
v = v.Field(x)
}

return v
}
34 changes: 34 additions & 0 deletions redis/reflect_go117.go
@@ -0,0 +1,34 @@
//go:build !go1.18
// +build !go1.18

package redis

import (
"errors"
"reflect"
)

// fieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) {
if len(index) == 1 {
return v.Field(index[0]), nil
}

mustBe(v, reflect.Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
return reflect.Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.Type().Elem().Name())
}
v = v.Elem()
}
}
v = v.Field(x)
}

return v, nil
}
16 changes: 16 additions & 0 deletions redis/reflect_go118.go
@@ -0,0 +1,16 @@
//go:build go1.18
// +build go1.18

package redis

import (
"reflect"
)

// fieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func fieldByIndexErr(v reflect.Value, index []int) (reflect.Value, error) {
return v.FieldByIndexErr(index)
}
69 changes: 51 additions & 18 deletions redis/scan.go
Expand Up @@ -355,7 +355,13 @@ func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
return ss.m[string(name)]
}

func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec, seen map[reflect.Type]struct{}) error {
if _, ok := seen[t]; ok {
// Protect against infinite recursion.
return fmt.Errorf("recursive struct definition for %v", t)
}

seen[t] = struct{}{}
LOOP:
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
Expand All @@ -365,20 +371,21 @@ LOOP:
case f.Anonymous:
switch f.Type.Kind() {
case reflect.Struct:
compileStructSpec(f.Type, depth, append(index, i), ss)
if err := compileStructSpec(f.Type, depth, append(index, i), ss, seen); err != nil {
return err
}
case reflect.Ptr:
// TODO(steve): Protect against infinite recursion.
if f.Type.Elem().Kind() == reflect.Struct {
compileStructSpec(f.Type.Elem(), depth, append(index, i), ss)
if err := compileStructSpec(f.Type.Elem(), depth, append(index, i), ss, seen); err != nil {
return err
}
}
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")

var (
p string
)
var p string
first := true
for len(tag) > 0 {
i := strings.IndexByte(tag, ',')
Expand All @@ -402,10 +409,12 @@ LOOP:
}
}
}

d, found := depth[fs.name]
if !found {
d = 1 << 30
}

switch {
case len(index) == d:
// At same depth, remove from result.
Expand All @@ -428,33 +437,36 @@ LOOP:
}
}
}

return nil
}

var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
)

func structSpecForType(t reflect.Type) *structSpec {

func structSpecForType(t reflect.Type) (*structSpec, error) {
structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
if found {
return ss
return ss, nil
}

structSpecMutex.Lock()
defer structSpecMutex.Unlock()
ss, found = structSpecCache[t]
if found {
return ss
return ss, nil
}

ss = &structSpec{m: make(map[string]*fieldSpec)}
compileStructSpec(t, make(map[string]int), nil, ss)
if err := compileStructSpec(t, make(map[string]int), nil, ss, make(map[reflect.Type]struct{})); err != nil {
return nil, fmt.Errorf("compile struct: %s: %w", t, err)
}
structSpecCache[t] = ss
return ss
return ss, nil
}

var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct")
Expand All @@ -480,30 +492,38 @@ func ScanStruct(src []interface{}, dest interface{}) error {
if d.Kind() != reflect.Ptr || d.IsNil() {
return errScanStructValue
}

d = d.Elem()
if d.Kind() != reflect.Struct {
return errScanStructValue
}
ss := structSpecForType(d.Type())

if len(src)%2 != 0 {
return errors.New("redigo.ScanStruct: number of values not a multiple of 2")
}

ss, err := structSpecForType(d.Type())
if err != nil {
return fmt.Errorf("redigo.ScanStruct: %w", err)
}

for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}

name, ok := src[i].([]byte)
if !ok {
return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i)
}

fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {

if err := convertAssignValue(fieldByIndexCreate(d, fs.index), s); err != nil {
return fmt.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err)
}
}
Expand Down Expand Up @@ -555,7 +575,11 @@ func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error
return nil
}

ss := structSpecForType(t)
ss, err := structSpecForType(t)
if err != nil {
return fmt.Errorf("redigo.ScanSlice: %w", err)
}

fss := ss.l
if len(fieldNames) > 0 {
fss = make([]*fieldSpec, len(fieldNames))
Expand Down Expand Up @@ -618,6 +642,7 @@ func (args Args) Add(value ...interface{}) Args {
// for more information on the use of the 'redis' field tag.
//
// Other types are appended to args as is.
// panics if v includes a recursive anonymous struct.
func (args Args) AddFlat(v interface{}) Args {
rv := reflect.ValueOf(v)
switch rv.Kind() {
Expand Down Expand Up @@ -646,9 +671,17 @@ func (args Args) AddFlat(v interface{}) Args {
}

func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
ss, err := structSpecForType(v.Type())
if err != nil {
panic(fmt.Errorf("redigo.AddFlat: %w", err))
}

for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
fv, err := fieldByIndexErr(v, fs.index)
if err != nil {
// Nil item ignore.
continue
}
if fs.omitEmpty {
var empty = false
switch fv.Kind() {
Expand Down