From 88580309d81a096d6cc9ac437b8a75b1eb824aaf Mon Sep 17 00:00:00 2001 From: Fs02 Date: Sat, 4 Jun 2022 11:10:09 +0900 Subject: [PATCH] Add support for scanning joined assoc --- document.go | 48 ++++++++++++++++++++++++++++++++++++++++++++--- document_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/document.go b/document.go index 19f4a271..2ef4975a 100644 --- a/document.go +++ b/document.go @@ -249,7 +249,11 @@ func (d Document) SetValue(field string, value interface{}) bool { // Scanners returns slice of sql.Scanner for given fields. func (d Document) Scanners(fields []string) []interface{} { var ( - result = make([]interface{}, len(fields)) + result = make([]interface{}, len(fields)) + assocRefs map[string]struct { + fields []string + indexes []int + } ) for index, field := range fields { @@ -264,11 +268,41 @@ func (d Document) Scanners(fields []string) []interface{} { } else { result[index] = Nullable(fv.Addr().Interface()) } + } else if split := strings.SplitN(field, ".", 2); len(split) == 2 { + if assocRefs == nil { + assocRefs = make(map[string]struct { + fields []string + indexes []int + }) + } + + refs := assocRefs[split[0]] + refs.fields = append(refs.fields, split[1]) + refs.indexes = append(refs.indexes, index) + assocRefs[split[0]] = refs } else { result[index] = &sql.RawBytes{} } } + // get scanners from associations + for assocName, refs := range assocRefs { + if assoc, ok := d.association(assocName); ok && assoc.Type() == BelongsTo || assoc.Type() == HasOne { + var ( + assocDoc, _ = assoc.Document() + assocScanners = assocDoc.Scanners(refs.fields) + ) + + for i, index := range refs.indexes { + result[index] = assocScanners[i] + } + } else { + for _, index := range refs.indexes { + result[index] = &sql.RawBytes{} + } + } + } + return result } @@ -294,12 +328,20 @@ func (d Document) Preload() []string { // Association of this document with given name. func (d Document) Association(name string) Association { + if assoc, ok := d.association(name); ok { + return assoc + } + + panic("rel: no field named (" + name + ") in type " + d.rt.String() + " found ") +} + +func (d Document) association(name string) (Association, bool) { index, ok := d.data.index[name] if !ok { - panic("rel: no field named (" + name + ") in type " + d.rt.String() + " found ") + return Association{}, false } - return newAssociation(d.rv, index) + return newAssociation(d.rv, index), true } // Reset this document, this is a noop for compatibility with collection. diff --git a/document_test.go b/document_test.go index 58176f2b..b8a19749 100644 --- a/document_test.go +++ b/document_test.go @@ -474,6 +474,55 @@ func TestDocument_Scanners(t *testing.T) { assert.Equal(t, scanners, doc.Scanners(fields)) } +func TestDocument_Scanners_withAssoc(t *testing.T) { + var ( + record = Transaction{ + ID: 1, + BuyerID: 2, + Status: "SENT", + Buyer: User{ + ID: 2, + Name: "user", + WorkAddress: &Address{ + Street: "Takeshita-dori", + }, + }, + } + doc = NewDocument(&record) + fields = []string{"id", "user_id", "buyer.id", "buyer.name", "buyer.work_address.street", "status", "invalid_assoc.id"} + scanners = []interface{}{ + Nullable(&record.ID), + Nullable(&record.BuyerID), + Nullable(&record.Buyer.ID), + Nullable(&record.Buyer.Name), + Nullable(&record.Buyer.WorkAddress.Street), + Nullable(&record.Status), + &sql.RawBytes{}, + } + ) + + assert.Equal(t, scanners, doc.Scanners(fields)) +} + +func TestDocument_Scanners_withUnitializedAssoc(t *testing.T) { + var ( + record = Transaction{} + doc = NewDocument(&record) + fields = []string{"id", "user_id", "buyer.id", "buyer.name", "status", "buyer.work_address.street"} + result = doc.Scanners(fields) + expected = []interface{}{ + Nullable(&record.ID), + Nullable(&record.BuyerID), + Nullable(&record.Buyer.ID), + Nullable(&record.Buyer.Name), + Nullable(&record.Status), + Nullable(&record.Buyer.WorkAddress.Street), + } + ) + + assert.Equal(t, expected, result) +} + func TestDocument_ScannersInitPointers(t *testing.T) { type Embedded1 struct { ID int