Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for scanning joined assoc #290

Merged
merged 1 commit into from Jun 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 45 additions & 3 deletions document.go
Expand Up @@ -249,7 +249,11 @@ func (d Document) SetValue(field string, value interface{}) bool {
// Scanners returns slice of sql.Scanner for given fields.
func (d Document) Scanners(fields []string) []interface{} {
var (
result = make([]interface{}, len(fields))
result = make([]interface{}, len(fields))
assocRefs map[string]struct {
fields []string
indexes []int
}
)

for index, field := range fields {
Expand All @@ -264,11 +268,41 @@ func (d Document) Scanners(fields []string) []interface{} {
} else {
result[index] = Nullable(fv.Addr().Interface())
}
} else if split := strings.SplitN(field, ".", 2); len(split) == 2 {
if assocRefs == nil {
assocRefs = make(map[string]struct {
fields []string
indexes []int
})
}

refs := assocRefs[split[0]]
refs.fields = append(refs.fields, split[1])
refs.indexes = append(refs.indexes, index)
assocRefs[split[0]] = refs
} else {
result[index] = &sql.RawBytes{}
}
}

// get scanners from associations
for assocName, refs := range assocRefs {
if assoc, ok := d.association(assocName); ok && assoc.Type() == BelongsTo || assoc.Type() == HasOne {
var (
assocDoc, _ = assoc.Document()
assocScanners = assocDoc.Scanners(refs.fields)
)

for i, index := range refs.indexes {
result[index] = assocScanners[i]
}
} else {
for _, index := range refs.indexes {
result[index] = &sql.RawBytes{}
}
}
}

return result
}

Expand All @@ -294,12 +328,20 @@ func (d Document) Preload() []string {

// Association of this document with given name.
func (d Document) Association(name string) Association {
if assoc, ok := d.association(name); ok {
return assoc
}

panic("rel: no field named (" + name + ") in type " + d.rt.String() + " found ")
}

func (d Document) association(name string) (Association, bool) {
index, ok := d.data.index[name]
if !ok {
panic("rel: no field named (" + name + ") in type " + d.rt.String() + " found ")
return Association{}, false
}

return newAssociation(d.rv, index)
return newAssociation(d.rv, index), true
}

// Reset this document, this is a noop for compatibility with collection.
Expand Down
49 changes: 49 additions & 0 deletions document_test.go
Expand Up @@ -474,6 +474,55 @@ func TestDocument_Scanners(t *testing.T) {
assert.Equal(t, scanners, doc.Scanners(fields))
}

func TestDocument_Scanners_withAssoc(t *testing.T) {
var (
record = Transaction{
ID: 1,
BuyerID: 2,
Status: "SENT",
Buyer: User{
ID: 2,
Name: "user",
WorkAddress: &Address{
Street: "Takeshita-dori",
},
},
}
doc = NewDocument(&record)
fields = []string{"id", "user_id", "buyer.id", "buyer.name", "buyer.work_address.street", "status", "invalid_assoc.id"}
scanners = []interface{}{
Nullable(&record.ID),
Nullable(&record.BuyerID),
Nullable(&record.Buyer.ID),
Nullable(&record.Buyer.Name),
Nullable(&record.Buyer.WorkAddress.Street),
Nullable(&record.Status),
&sql.RawBytes{},
}
)

assert.Equal(t, scanners, doc.Scanners(fields))
}

func TestDocument_Scanners_withUnitializedAssoc(t *testing.T) {
var (
record = Transaction{}
doc = NewDocument(&record)
fields = []string{"id", "user_id", "buyer.id", "buyer.name", "status", "buyer.work_address.street"}
result = doc.Scanners(fields)
expected = []interface{}{
Nullable(&record.ID),
Nullable(&record.BuyerID),
Nullable(&record.Buyer.ID),
Nullable(&record.Buyer.Name),
Nullable(&record.Status),
Nullable(&record.Buyer.WorkAddress.Street),
}
)

assert.Equal(t, expected, result)
}

func TestDocument_ScannersInitPointers(t *testing.T) {
type Embedded1 struct {
ID int
Expand Down