Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update reflectx to allow for optional nested structs #900

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
188 changes: 188 additions & 0 deletions reflectx/reflect.go
Expand Up @@ -7,8 +7,11 @@
package reflectx

import (
"database/sql"
"fmt"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
)
Expand Down Expand Up @@ -201,6 +204,191 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in
return nil
}

// ObjectContext provides a single layer to abstract away
// nested struct scanning functionality
type ObjectContext struct {
value reflect.Value
}

func NewObjectContext() *ObjectContext {
return &ObjectContext{}
}

// NewRow updates the object reference.
// This ensures all columns point to the same object
func (o *ObjectContext) NewRow(value reflect.Value) {
o.value = value
}

// FieldForIndexes returns the value for address. If the address is a nested struct,
// a nestedFieldScanner is returned instead of the standard value reference
func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value {
if len(indexes) == 1 {
val := FieldByIndexes(o.value, indexes)
return val
}

obj := &nestedFieldScanner{
parent: o,
indexes: indexes,
}

v := reflect.ValueOf(obj).Elem()
return v
}

// nestedFieldScanner will only forward the Scan to the nested value if
// the database value is not nil.
type nestedFieldScanner struct {
parent *ObjectContext
indexes []int
}

// Scan implements sql.Scanner.
// This method largely mirrors the sql.convertAssign() method with some minor changes
func (o *nestedFieldScanner) Scan(src interface{}) error {
if src == nil {
return nil
}

dv := FieldByIndexes(o.parent.value, o.indexes)
// Dereference pointer fields to avoid double pointers **T
if dv.Kind() == reflect.Pointer {
dv.Set(reflect.New(dv.Type().Elem()))
dv = dv.Elem()
}
iface := dv.Addr().Interface()

if scan, ok := iface.(sql.Scanner); ok {
return scan.Scan(src)
}

sv := reflect.ValueOf(src)

// below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go
// with a few minor edits

if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(bytesClone(b)))
default:
dv.Set(sv)
}

return nil
}

if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
dv.Set(sv.Convert(dv.Type()))
return nil
}

// The following conversions use a string value as an intermediate representation
// to convert between various numeric types.
//
// This also allows scanning into user defined types such as "type Int int64".
// For symmetry, also check for string destination types.
switch dv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
case reflect.String:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
switch v := src.(type) {
case string:
dv.SetString(v)
return nil
case []byte:
dv.SetString(string(v))
return nil
}
}

return fmt.Errorf("don't know how to parse type %T -> %T", src, iface)
}

// returns internal conversion error if available
// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}

// converts value to it's string value
// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go
func asString(src interface{}) string {
switch v := src.(type) {
case string:
return v
case []byte:
return string(v)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(rv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(rv.Uint(), 10)
case reflect.Float64:
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
case reflect.Float32:
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
case reflect.Bool:
return strconv.FormatBool(rv.Bool())
}
return fmt.Sprintf("%v", src)
}

// bytesClone returns a copy of b[:len(b)].
// The result may have additional unused capacity.
// Clone(nil) returns nil.
//
// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version
func bytesClone(b []byte) []byte {
if b == nil {
return nil
}
return append([]byte{}, b...)
}

// FieldByIndexes returns a value for the field given by the struct traversal
// for the given value.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
Expand Down
16 changes: 11 additions & 5 deletions sqlx.go
Expand Up @@ -621,7 +621,8 @@ func (r *Rows) StructScan(dest interface{}) error {
r.started = true
}

err := fieldsByTraversal(v, r.fields, r.values, true)
octx := reflectx.NewObjectContext()
err := fieldsByTraversal(octx, v, r.fields, r.values, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -781,7 +782,9 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
}
values := make([]interface{}, len(columns))

err = fieldsByTraversal(v, fields, values, true)
octx := reflectx.NewObjectContext()

err = fieldsByTraversal(octx, v, fields, values, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -948,13 +951,14 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
}
values = make([]interface{}, len(columns))
octx := reflectx.NewObjectContext()

for rows.Next() {
// create a new struct type (which returns PtrTo) and indirect it
vp = reflect.New(base)
v = reflect.Indirect(vp)

err = fieldsByTraversal(v, fields, values, true)
err = fieldsByTraversal(octx, v, fields, values, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -1020,18 +1024,20 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}

octx.NewRow(v)

for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
f := octx.FieldForIndexes(traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
Expand Down
104 changes: 104 additions & 0 deletions sqlx_context_test.go
Expand Up @@ -642,6 +642,110 @@ func TestNamedQueryContext(t *testing.T) {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID)
}
}

rows.Close()

type Owner struct {
Email *string `db:"email"`
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
}

// Test optional nested structs with left join
type PlaceOwner struct {
Place Place `db:"place"`
Owner *Owner `db:"owner"`
}

pl = Place{
Name: sql.NullString{String: "the-house", Valid: true},
}

q4 := `INSERT INTO place (id, name) VALUES (2, :name)`
_, err = db.NamedExecContext(ctx, q4, pl)
if err != nil {
log.Fatal(err)
}

id = 2
pp.Place.ID = id

q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`
_, err = db.NamedExecContext(ctx, q5, pp)
if err != nil {
log.Fatal(err)
}

pp3 := &PlaceOwner{}
rows, err = db.NamedQueryContext(ctx, `
SELECT
placeperson.first_name "owner.first_name",
placeperson.last_name "owner.last_name",
placeperson.email "owner.email",
place.id AS "place.id",
place.name AS "place.name"
FROM place
LEFT JOIN placeperson ON false -- null left join
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp3)
if err != nil {
t.Error(err)
}
if pp3.Owner != nil {
t.Error("Expected `Owner`, to be nil")
}
if pp3.Place.Name.String != "the-house" {
t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String)
}
if pp3.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID)
}
}

rows.Close()

pp3 = &PlaceOwner{}
rows, err = db.NamedQueryContext(ctx, `
SELECT
placeperson.first_name "owner.first_name",
placeperson.last_name "owner.last_name",
placeperson.email "owner.email",
place.id AS "place.id",
place.name AS "place.name"
FROM place
left JOIN placeperson ON placeperson.place_id = place.id
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp3)
if err != nil {
t.Error(err)
}
if pp3.Owner == nil {
t.Error("Expected `Owner`, to not be nil")
}

if pp3.Owner.FirstName != "ben" {
t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName)
}
if pp3.Owner.LastName != "doe" {
t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName)
}
if pp3.Place.Name.String != "the-house" {
t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String)
}
if pp3.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID)
}
}
})
}

Expand Down