Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change eindirect behave match with indirect from decode #358

Merged
merged 5 commits into from Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
136 changes: 69 additions & 67 deletions encode.go
Expand Up @@ -64,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
"\x7f", `\u007f`,
)

var (
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
)

// Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML.
type Marshaler interface {
Expand Down Expand Up @@ -154,12 +160,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
// If we can marshal the type to text, then we use that. This prevents the
// encoder for handling these types as generic structs (or whatever the
// underlying type of a TextMarshaler is).
switch t := rv.Interface().(type) {
case encoding.TextMarshaler, Marshaler:
switch {
case isMarshaler(rv):
enc.writeKeyValue(key, rv, false)
return
case Primitive: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(t.undecoded))
case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
return
}

Expand Down Expand Up @@ -318,7 +324,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
elem := eindirect(rv.Index(i))
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
Expand All @@ -332,7 +338,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
trv := eindirect(rv.Index(i))
if isNil(trv) {
continue
}
Expand All @@ -357,7 +363,7 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) {
}

func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
switch rv := eindirect(rv); rv.Kind() {
switch rv.Kind() {
case reflect.Map:
enc.eMap(key, rv, inline)
case reflect.Struct:
Expand All @@ -379,7 +385,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) {
if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
Expand All @@ -389,7 +395,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var writeMapKeys = func(mapKeys []string, trailC bool) {
sort.Strings(mapKeys)
for i, mapKey := range mapKeys {
val := rv.MapIndex(reflect.ValueOf(mapKey))
val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey)))
if isNil(val) {
continue
}
Expand Down Expand Up @@ -417,6 +423,13 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {

const is32Bit = (32 << (^uint(0) >> 63)) == 32

func pointerTo(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return pointerTo(t.Elem())
}
return t
}

func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table then all keys under it will be in that
Expand All @@ -433,35 +446,25 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
isEmbed := f.Anonymous && pointerTo(f.Type).Kind() == reflect.Struct
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
continue
}
opts := getOptions(f.Tag)
if opts.skip {
continue
}

frv := rv.Field(i)
frv := eindirect(rv.Field(i))

// Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does.
//
// Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
if getOptions(f.Tag).name == "" {
addFields(t, frv, append(start, f.Index...))
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
}
continue
}
if isEmbed {
if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
addFields(frv.Type(), frv, append(start, f.Index...))
continue
}
}

Expand All @@ -487,7 +490,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
writeFields := func(fields [][]int) {
for _, fieldIndex := range fields {
fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := rv.FieldByIndex(fieldIndex)
fieldVal := eindirect(rv.FieldByIndex(fieldIndex))

if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
Expand Down Expand Up @@ -540,6 +543,21 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}

if rv.Kind() == reflect.Struct {
if rv.Type() == timeType {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
}

if isMarshaler(rv) {
return tomlString
}

switch rv.Kind() {
case reflect.Bool:
return tomlBool
Expand All @@ -561,42 +579,14 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
if _, ok := rv.Interface().(time.Time); ok {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
default:
if isMarshaler(rv) {
return tomlString
}

encPanic(errors.New("unsupported type: " + rv.Kind().String()))
panic("unreachable")
}
}

func isMarshaler(rv reflect.Value) bool {
switch rv.Interface().(type) {
case encoding.TextMarshaler:
return true
case Marshaler:
return true
}

// Someone used a pointer receiver: we can make it work for pointer values.
if rv.CanAddr() {
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok {
return true
}
if _, ok := rv.Addr().Interface().(Marshaler); ok {
return true
}
}
return false
return rv.Type().Implements(marshalText) || rv.Type().Implements(marshalToml)
}

// isTableArray reports if all entries in the array or slice are a table.
Expand All @@ -605,19 +595,19 @@ func isTableArray(arr reflect.Value) bool {
return false
}

/// Don't allow nil.
ret := true
for i := 0; i < arr.Len(); i++ {
if tomlTypeOfGo(arr.Index(i)) == nil {
tt := tomlTypeOfGo(eindirect(arr.Index(i)))
// Don't allow nil.
if tt == nil {
encPanic(errArrayNilElement)
}
}

for i := 0; i < arr.Len(); i++ {
if !typeEqual(tomlHash, tomlTypeOfGo(arr.Index(i))) {
return false
if ret && !typeEqual(tomlHash, tt) {
ret = false
}
}
return true
return ret
}

type tagOptions struct {
Expand Down Expand Up @@ -715,13 +705,25 @@ func encPanic(err error) {
panic(tomlEncodeError{err})
}

// Resolve any level of pointers to the actual value (e.g. **string → string).
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface {
if isMarshaler(v) {
return v
}
if v.CanAddr() { /// Special case for marshalers; see #358.
if pv := v.Addr(); isMarshaler(pv) {
return pv
}
}
return v
}

if v.IsNil() {
return v
}

return eindirect(v.Elem())
}

func isNil(rv reflect.Value) bool {
Expand Down