Skip to content

Commit

Permalink
Allow MarshalTOML and MarshalText to be used on the document type itself
Browse files Browse the repository at this point in the history
Fixes #383
  • Loading branch information
arp242 committed May 19, 2023
1 parent 2967a1e commit d56d9f6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
8 changes: 7 additions & 1 deletion encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ func NewEncoder(w io.Writer) *Encoder {
// document.
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))

// XXX

if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
Expand Down Expand Up @@ -693,8 +696,11 @@ func (enc *Encoder) newline() {
// v v v v vv
// key = {k = 1, k2 = 2}
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
/// Marshaler used on top-level document; call eElement() to just call
/// Marshal{TOML,Text}.
if len(key) == 0 {
encPanic(errNoKey)
enc.eElement(val)
return
}
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
Expand Down
37 changes: 37 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,43 @@ c = 3
}
}

type (
Doc1 struct{ N string }
Doc2 struct{ N string }
)

func (d Doc1) MarshalTOML() ([]byte, error) { return []byte(`marshal_toml = "` + d.N + `"`), nil }
func (d Doc2) MarshalText() ([]byte, error) { return []byte(`marshal_text = "` + d.N + `"`), nil }

// MarshalTOML and MarshalText on the top level type, rather than a field.
func TestMarshalDoc(t *testing.T) {
t.Run("toml", func(t *testing.T) {
var buf bytes.Buffer
err := NewEncoder(&buf).Encode(Doc1{"asd"})
if err != nil {
t.Fatal(err)
}

want := `marshal_toml = "asd"`
if want != buf.String() {
t.Errorf("\nhave: %s\nwant: %s\n", buf.String(), want)
}
})

t.Run("text", func(t *testing.T) {
var buf bytes.Buffer
err := NewEncoder(&buf).Encode(Doc2{"asd"})
if err != nil {
t.Fatal(err)
}

want := `"marshal_text = \"asd\""`
if want != buf.String() {
t.Errorf("\nhave: %s\nwant: %s\n", buf.String(), want)
}
})
}

func encodeExpected(t *testing.T, label string, val interface{}, want string, wantErr error) {
t.Helper()
t.Run(label, func(t *testing.T) {
Expand Down

0 comments on commit d56d9f6

Please sign in to comment.