Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuzzing setup and fixes #755

Merged
merged 18 commits into from Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
@@ -1,3 +1,4 @@
* text=auto

benchmark/benchmark.toml text eol=lf
testdata/** text eol=lf
18 changes: 13 additions & 5 deletions ci.sh
Expand Up @@ -76,7 +76,8 @@ cover() {
fi

pushd "$dir"
go test -covermode=atomic -coverprofile=coverage.out ./...
go test -covermode=atomic -coverpkg=./... -coverprofile=coverage.out.tmp ./...
cat coverage.out.tmp | grep -v testsuite | grep -v tomltestgen | grep -v gotoml-test-decoder > coverage.out
go tool cover -func=coverage.out
popd

Expand All @@ -103,16 +104,23 @@ coverage() {

echo ""

target_pct="$(cat ${target_out} |sed -E 's/.*total.*\t([0-9.]+)%/\1/;t;d')"
head_pct="$(cat ${head_out} |sed -E 's/.*total.*\t([0-9.]+)%/\1/;t;d')"
target_pct="$(tail -n2 ${target_out} | head -n1 | sed -E 's/.*total.*\t([0-9.]+)%.*/\1/')"
head_pct="$(tail -n2 ${head_out} | head -n1 | sed -E 's/.*total.*\t([0-9.]+)%/\1/')"
echo "Results: ${target} ${target_pct}% HEAD ${head_pct}%"

delta_pct=$(echo "$head_pct - $target_pct" | bc -l)
echo "Delta: ${delta_pct}"

if [[ $delta_pct = \-* ]]; then
echo "Regression!";
return 1
echo "Regression!";

target_diff="${output_dir}/target.diff.txt"
head_diff="${output_dir}/head.diff.txt"
cat "${target_out}" | grep -E '^github.com/pelletier/go-toml' | tr -s "\t " | cut -f 2,3 | sort > "${target_diff}"
cat "${head_out}" | grep -E '^github.com/pelletier/go-toml' | tr -s "\t " | cut -f 2,3 | sort > "${head_diff}"

diff --side-by-side --suppress-common-lines "${target_diff}" "${head_diff}"
return 1
fi
return 0
;;
Expand Down
6 changes: 5 additions & 1 deletion decode.go
Expand Up @@ -130,7 +130,11 @@ func parseDateTime(b []byte) (time.Time, error) {
}

seconds := direction * (hours*3600 + minutes*60)
zone = time.FixedZone("", seconds)
if seconds == 0 {
zone = time.UTC
} else {
zone = time.FixedZone("", seconds)
}
b = b[dateTimeByteLen:]
}

Expand Down
56 changes: 56 additions & 0 deletions fuzz_test.go
@@ -0,0 +1,56 @@
//go:build go1.18
// +build go1.18

package toml_test

import (
"io/ioutil"
"strings"
"testing"

"github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/require"
)

func FuzzUnmarshal(f *testing.F) {
file, err := ioutil.ReadFile("benchmark/benchmark.toml")
if err != nil {
panic(err)
}
f.Add(file)

f.Fuzz(func(t *testing.T, b []byte) {
if strings.Contains(string(b), "nan") {
// Current limitation of testify.
// https://github.com/stretchr/testify/issues/624
t.Skip("can't compare NaNs")
}

t.Log("INITIAL DOCUMENT ===========================")
t.Log(string(b))

var v interface{}
err := toml.Unmarshal(b, &v)
if err != nil {
return
}

t.Log("DECODED VALUE ===========================")
t.Logf("%#+v", v)

encoded, err := toml.Marshal(v)
if err != nil {
t.Fatalf("cannot marshal unmarshaled document: %s", err)
}

t.Log("ENCODED DOCUMENT ===========================")
t.Log(string(encoded))

var v2 interface{}
err = toml.Unmarshal(encoded, &v2)
if err != nil {
t.Fatalf("failed round trip: %s", err)
}
require.Equal(t, v, v2)
})
}
108 changes: 62 additions & 46 deletions marshaler.go
Expand Up @@ -208,11 +208,20 @@ func (ctx *encoderCtx) isRoot() bool {
}

func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if !v.IsZero() {
i, ok := v.Interface().(time.Time)
if ok {
return i.AppendFormat(b, time.RFC3339), nil
i := v.Interface()

switch x := i.(type) {
case time.Time:
if x.Nanosecond() > 0 {
return x.AppendFormat(b, time.RFC3339Nano), nil
}
return x.AppendFormat(b, time.RFC3339), nil
case LocalTime:
return append(b, x.String()...), nil
case LocalDate:
return append(b, x.String()...), nil
case LocalDateTime:
return append(b, x.String()...), nil
}

hasTextMarshaler := v.Type().Implements(textMarshalerType)
Expand Down Expand Up @@ -260,16 +269,31 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.String:
b = enc.encodeString(b, v.String(), ctx.options)
case reflect.Float32:
if math.Trunc(v.Float()) == v.Float() {
b = strconv.AppendFloat(b, v.Float(), 'f', 1, 32)
f := v.Float()

if math.IsNaN(f) {
b = append(b, "nan"...)
} else if f > math.MaxFloat32 {
b = append(b, "inf"...)
} else if f < -math.MaxFloat32 {
b = append(b, "-inf"...)
} else if math.Trunc(f) == f {
b = strconv.AppendFloat(b, f, 'f', 1, 32)
} else {
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32)
b = strconv.AppendFloat(b, f, 'f', -1, 32)
}
case reflect.Float64:
if math.Trunc(v.Float()) == v.Float() {
b = strconv.AppendFloat(b, v.Float(), 'f', 1, 64)
f := v.Float()
if math.IsNaN(f) {
b = append(b, "nan"...)
} else if f > math.MaxFloat64 {
b = append(b, "inf"...)
} else if f < -math.MaxFloat64 {
b = append(b, "-inf"...)
} else if math.Trunc(f) == f {
b = strconv.AppendFloat(b, f, 'f', 1, 64)
} else {
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 64)
b = strconv.AppendFloat(b, f, 'f', -1, 64)
}
case reflect.Bool:
if v.Bool() {
Expand Down Expand Up @@ -300,10 +324,6 @@ func isNil(v reflect.Value) bool {
func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) {
var err error

if !ctx.hasKey {
panic("caller of encodeKv should have set the key in the context")
}

if (ctx.options.omitempty || options.omitempty) && isEmptyValue(v) {
return b, nil
}
Expand All @@ -313,12 +333,7 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
}

b = enc.indent(ctx.indent, b)

b, err = enc.encodeKey(b, ctx.key)
if err != nil {
return nil, err
}

b = enc.encodeKey(b, ctx.key)
b = append(b, " = "...)

// create a copy of the context because the value of a KV shouldn't
Expand Down Expand Up @@ -365,7 +380,13 @@ func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byt
}

func needsQuoting(v string) bool {
return strings.ContainsAny(v, "'\b\f\n\r\t")
// TODO: vectorize
for _, b := range []byte(v) {
if b == '\'' || b == '\r' || b == '\n' || invalidAscii(b) {
return true
}
}
return false
}

// caller should have checked that the string does not contain new lines or ' .
Expand Down Expand Up @@ -437,7 +458,7 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
return b
}

// called should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
// caller should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte {
return append(b, v...)
}
Expand All @@ -453,20 +474,11 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)

b = append(b, '[')

var err error

b, err = enc.encodeKey(b, ctx.parentKey[0])
if err != nil {
return nil, err
}
b = enc.encodeKey(b, ctx.parentKey[0])

for _, k := range ctx.parentKey[1:] {
b = append(b, '.')

b, err = enc.encodeKey(b, k)
if err != nil {
return nil, err
}
b = enc.encodeKey(b, k)
}

b = append(b, "]\n"...)
Expand All @@ -475,33 +487,37 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)
}

//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
func (enc *Encoder) encodeKey(b []byte, k string) []byte {
needsQuotation := false
cannotUseLiteral := false

if len(k) == 0 {
return append(b, "''"...)
}

for _, c := range k {
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
continue
}

if c == '\n' {
return nil, fmt.Errorf("toml: new line characters in keys are not supported")
}

if c == literalQuote {
cannotUseLiteral = true
}

needsQuotation = true
}

if needsQuotation && needsQuoting(k) {
cannotUseLiteral = true
}

switch {
case cannotUseLiteral:
return enc.encodeQuotedString(false, b, k), nil
return enc.encodeQuotedString(false, b, k)
case needsQuotation:
return enc.encodeLiteralString(b, k), nil
return enc.encodeLiteralString(b, k)
default:
return enc.encodeUnquotedKey(b, k), nil
return enc.encodeUnquotedKey(b, k)
}
}

Expand Down Expand Up @@ -803,6 +819,9 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) bool {
}

func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) bool {
if ctx.insideKv {
return false
}
t := v.Type()

if t.Kind() == reflect.Interface {
Expand Down Expand Up @@ -848,7 +867,6 @@ func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]by
func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
ctx.shiftKey()

var err error
scratch := make([]byte, 0, 64)
scratch = append(scratch, "[["...)

Expand All @@ -857,10 +875,7 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
scratch = append(scratch, '.')
}

scratch, err = enc.encodeKey(scratch, k)
if err != nil {
return nil, err
}
scratch = enc.encodeKey(scratch, k)
}

scratch = append(scratch, "]]\n"...)
Expand All @@ -869,6 +884,7 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
for i := 0; i < v.Len(); i++ {
b = append(b, scratch...)

var err error
b, err = enc.encode(b, ctx, v.Index(i))
if err != nil {
return nil, err
Expand Down