Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #657 incorrect encoding of embedded-recursive types and overlapping tags #659

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 36 additions & 3 deletions reflect_extension.go
Expand Up @@ -2,12 +2,13 @@ package jsoniter

import (
"fmt"
"github.com/modern-go/reflect2"
"reflect"
"sort"
"strings"
"unicode"
"unsafe"

"github.com/modern-go/reflect2"
)

var typeDecoders = map[string]ValDecoder{}
Expand Down Expand Up @@ -332,6 +333,10 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
}

func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor {
return _describeStruct(ctx, typ, nil)
}

func _describeStruct(ctx *ctx, typ reflect2.Type, parents []reflect2.Type) *StructDescriptor {
structType := typ.(*reflect2.UnsafeStructType)
embeddedBindings := []*Binding{}
bindings := []*Binding{}
Expand All @@ -347,7 +352,16 @@ func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor {
tagParts := strings.Split(tag, ",")
if field.Anonymous() && (tag == "" || tagParts[0] == "") {
if field.Type().Kind() == reflect.Struct {
structDescriptor := describeStruct(ctx, field.Type())
if isRecursive(parents, field.Type()) {
return nil
}
parents = append(parents, field.Type())

structDescriptor := _describeStruct(ctx, field.Type(), parents)
if structDescriptor == nil {
continue
}

for _, binding := range structDescriptor.Fields {
binding.levels = append([]int{i}, binding.levels...)
omitempty := binding.Encoder.(*structFieldEncoder).omitempty
Expand All @@ -359,7 +373,16 @@ func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor {
} else if field.Type().Kind() == reflect.Ptr {
ptrType := field.Type().(*reflect2.UnsafePtrType)
if ptrType.Elem().Kind() == reflect.Struct {
structDescriptor := describeStruct(ctx, ptrType.Elem())
if isRecursive(parents, field.Type()) {
return nil
}
parents = append(parents, field.Type())

structDescriptor := _describeStruct(ctx, ptrType.Elem(), parents)
if structDescriptor == nil {
continue
}

for _, binding := range structDescriptor.Fields {
binding.levels = append([]int{i}, binding.levels...)
omitempty := binding.Encoder.(*structFieldEncoder).omitempty
Expand Down Expand Up @@ -395,6 +418,16 @@ func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor {
}
return createStructDescriptor(ctx, typ, bindings, embeddedBindings)
}

func isRecursive(parents []reflect2.Type, parent reflect2.Type) bool {
for _, p := range parents {
if p == parent {
return true
}
}
return false
}

func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, embeddedBindings []*Binding) *StructDescriptor {
structDescriptor := &StructDescriptor{
Type: typ,
Expand Down
39 changes: 24 additions & 15 deletions reflect_struct_decoder.go
Expand Up @@ -12,22 +12,15 @@ import (
func decoderOfStruct(ctx *ctx, typ reflect2.Type) ValDecoder {
bindings := map[string]*Binding{}
structDescriptor := describeStruct(ctx, typ)
for _, binding := range structDescriptor.Fields {
for _, fromName := range binding.FromNames {
old := bindings[fromName]
if old == nil {
bindings[fromName] = binding
continue
}
ignoreOld, ignoreNew := resolveConflictBinding(ctx.frozenConfig, old, binding)
if ignoreOld {
delete(bindings, fromName)
}
if !ignoreNew {
bindings[fromName] = binding
}
}

flattenedBindings := flattenFrom(structDescriptor.Fields, ctx.frozenConfig)

orderedBindings := resolveBindings(flattenedBindings)

for _, b := range orderedBindings {
bindings[b.name] = b.binding
}

fields := map[string]*structFieldDecoder{}
for k, binding := range bindings {
fields[k] = binding.Decoder.(*structFieldDecoder)
Expand All @@ -44,6 +37,22 @@ func decoderOfStruct(ctx *ctx, typ reflect2.Type) ValDecoder {
return createStructDecoder(ctx, typ, fields)
}

func flattenFrom(bindings []*Binding, cfg *frozenConfig) []*binding {
flattened := make([]*binding, 0, len(bindings))

for _, b := range bindings {
for _, fromName := range b.FromNames {
flattened = append(flattened, &binding{
binding: b,
name: fromName,
hasTag: hasTag(b, cfg),
})
}
}

return flattened
}

func createStructDecoder(ctx *ctx, typ reflect2.Type, fields map[string]*structFieldDecoder) ValDecoder {
if ctx.disallowUnknownFields {
return &generalStructDecoder{typ: typ, fields: fields, disallowUnknownFields: true}
Expand Down
131 changes: 104 additions & 27 deletions reflect_struct_encoder.go
Expand Up @@ -2,48 +2,125 @@ package jsoniter

import (
"fmt"
"github.com/modern-go/reflect2"
"io"
"reflect"
"sort"
"strings"
"unsafe"

"github.com/modern-go/reflect2"
)

type binding struct {
binding *Binding
name string
hasTag bool
}

func encoderOfStruct(ctx *ctx, typ reflect2.Type) ValEncoder {
type bindingTo struct {
binding *Binding
toName string
ignored bool
}
orderedBindings := []*bindingTo{}

orderedBindings := []*binding{}
structDescriptor := describeStruct(ctx, typ)
for _, binding := range structDescriptor.Fields {
for _, toName := range binding.ToNames {
new := &bindingTo{
binding: binding,
toName: toName,
}
for _, old := range orderedBindings {
if old.toName != toName {
continue
}
old.ignored, new.ignored = resolveConflictBinding(ctx.frozenConfig, old.binding, new.binding)
}
orderedBindings = append(orderedBindings, new)
}
}

fields := flattenTo(structDescriptor.Fields, ctx.frozenConfig)

orderedBindings = resolveBindings(fields)

if len(orderedBindings) == 0 {
return &emptyStructEncoder{}
}

finalOrderedFields := []structFieldTo{}
for _, bindingTo := range orderedBindings {
if !bindingTo.ignored {
finalOrderedFields = append(finalOrderedFields, structFieldTo{
encoder: bindingTo.binding.Encoder.(*structFieldEncoder),
toName: bindingTo.toName,
finalOrderedFields = append(finalOrderedFields, structFieldTo{
encoder: bindingTo.binding.Encoder.(*structFieldEncoder),
toName: bindingTo.name,
})
}

return &structEncoder{typ, finalOrderedFields}
}

func flattenTo(bindings []*Binding, cfg *frozenConfig) []*binding {
flattened := make([]*binding, 0, len(bindings))

for _, b := range bindings {
for _, toName := range b.ToNames {
flattened = append(flattened, &binding{
binding: b,
name: toName,
hasTag: hasTag(b, cfg),
})
}
}
return &structEncoder{typ, finalOrderedFields}

return flattened
}

func hasTag(b *Binding, cfg *frozenConfig) bool {
before, _, _ := strings.Cut(b.Field.Tag().Get(cfg.getTagKey()), ",")
return before != ""
}

func resolveBindings(fields []*binding) []*binding {
sort.SliceStable(fields, func(i, j int) bool {
// As per std's encoding/json,
// it sorts fields by names, index depth(here we call it levels) and tags.
// We've already sorted fields by index order in describeStruct.
// By using stable sorting, we avoid sorting them again.
if fields[i].name != fields[j].name {
return fields[i].name < fields[j].name
}
if len(fields[i].binding.levels) != len(fields[j].binding.levels) {
return len(fields[i].binding.levels) < len(fields[j].binding.levels)
}
if fields[i].hasTag != fields[j].hasTag {
return fields[i].hasTag
}
return true // equal.
})

orderedBindings := trimOverlappingBindings(fields)

sort.Slice(orderedBindings, func(i, j int) bool {
left := orderedBindings[i].binding.levels
right := orderedBindings[j].binding.levels
k := 0
for {
if left[k] < right[k] {
return true
} else if left[k] > right[k] {
return false
}
k++
}
})

return orderedBindings
}

func trimOverlappingBindings(bindings []*binding) []*binding {
out := bindings[:0]
for nameRange, i := 0, 0; i < len(bindings); i += nameRange {
for nameRange = 1; i+nameRange < len(bindings); nameRange++ {
endOfRange := bindings[i+nameRange]
if endOfRange.name != bindings[i].name {
break
}
}
if nameRange == 1 { // only one field for that name
out = append(out, bindings[i])
} else {
fields := bindings[i : i+nameRange]
if len(fields[0].binding.levels) == len(fields[1].binding.levels) &&
fields[0].hasTag == fields[1].hasTag {
continue
}
out = append(out, fields[0])
}
}

return out
}

func createCheckIsEmpty(ctx *ctx, typ reflect2.Type) checkIsEmpty {
Expand Down
42 changes: 42 additions & 0 deletions type_tests/struct_embedded_test.go
Expand Up @@ -61,6 +61,9 @@ func init() {
(*SameLevel2Tagged)(nil),
(*EmbeddedPtr)(nil),
(*UnnamedLiteral)(nil),
(*EmbeddedRecursive)(nil),
(*EmbeddedRecursive2)(nil),
(*EmbeddedRecursive3)(nil),
)
}

Expand Down Expand Up @@ -236,3 +239,42 @@ type EmbeddedPtr struct {
type UnnamedLiteral struct {
_ struct{}
}

type EmbeddedRecursive struct {
Foo string
Recursive1
}

type Recursive1 struct {
R string
Recursive2
}

type Recursive2 struct {
Foo string
R string
RR string
*EmbeddedRecursive
}

type EmbeddedRecursive2 struct {
Foo string
Recursive1
Recursive3
}

type Recursive3 struct {
Foo string
RR string
*EmbeddedRecursive2
}

type Recursive4 struct {
Bar string
Recursive3
}

type EmbeddedRecursive3 struct {
Foo string
*EmbeddedRecursive3
}
26 changes: 26 additions & 0 deletions type_tests/struct_tags_test.go
Expand Up @@ -149,6 +149,32 @@ func init() {
(*struct {
Field bool `json:"中文"`
})(nil),
(*struct {
Foo string `json:"Bar"`
Bar string
})(nil),
(*struct {
Foo string `json:"Bar"`
Bar string `json:"Foo"`
})(nil),
(*struct {
Foo string
Bar string `json:"Foo"`
})(nil),
(*struct {
Foo string `json:"Bar"`
Bar string `json:"Bar"`
})(nil),
(*struct {
Foo string `json:"Bar"`
Bar string
Baz string `json:"Bar"`
})(nil),
(*struct {
Foo string
F string
EmbeddedOmitEmptyE
})(nil),
)
}

Expand Down