Skip to content

Commit

Permalink
Merge pull request #72 from tdakkota/perf/writeErr-check
Browse files Browse the repository at this point in the history
perf: fail early in string encoders
  • Loading branch information
tdakkota committed Feb 1, 2023
2 parents 047bbd6 + e0930c1 commit 9a31093
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 20 deletions.
13 changes: 13 additions & 0 deletions enc_stream_test.go
Expand Up @@ -101,3 +101,16 @@ func TestEncoder_ResetWriter(t *testing.T) {
require.Equal(t, expected, got.String())
}
}

// This benchmark is used to measure the overhead of ignoring errors.
func BenchmarkSkipError(b *testing.B) {
e := NewStreamingEncoder(io.Discard, 32)
e.w.stream.setError(errors.New("test"))

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
encodeObject(e)
}
}
11 changes: 11 additions & 0 deletions float_test.go
Expand Up @@ -10,6 +10,8 @@ import (
"testing"

"github.com/stretchr/testify/require"

"github.com/go-faster/errors"
)

// epsilon to compare floats.
Expand Down Expand Up @@ -152,6 +154,15 @@ func TestWriteFloat64(t *testing.T) {
should.Equal("1e-7", e.String())
}

func TestEncoder_FloatError(t *testing.T) {
e := NewStreamingEncoder(io.Discard, -1)
e.w.stream.setError(errors.New("foo"))

require.True(t, e.Float32(10))
require.True(t, e.Float64(10))
require.Error(t, e.Close())
}

func TestDecoder_FloatEOF(t *testing.T) {
d := GetDecoder()

Expand Down
4 changes: 2 additions & 2 deletions w_b64.go
Expand Up @@ -34,11 +34,11 @@ func (w *Writer) Base64(data []byte) bool {
}
e := stdbase64.NewEncoder(stdbase64.StdEncoding, s.writer)
if _, err := e.Write(data); err != nil {
s.writeErr = err
s.setError(err)
return true
}
if err := e.Close(); err != nil {
s.writeErr = err
s.setError(err)
return true
}
}
Expand Down
13 changes: 8 additions & 5 deletions w_float_bits.go
Expand Up @@ -20,14 +20,17 @@ func (w *Writer) Float(v float64, bits int) bool {
return w.Null()
}

if w.stream == nil {
switch s := w.stream; {
case s == nil:
w.Buf = floatAppend(w.Buf, v, bits)
return false
case s.fail():
return true
default:
tmp := make([]byte, 0, 32)
tmp = floatAppend(tmp, v, bits)
return writeStreamByteseq(w, tmp)
}

tmp := make([]byte, 0, 32)
tmp = floatAppend(tmp, v, bits)
return writeStreamByteseq(w, tmp)
}

func floatAppend(b []byte, v float64, bits int) []byte {
Expand Down
14 changes: 7 additions & 7 deletions w_str.go
Expand Up @@ -45,18 +45,18 @@ func writeStr[S byteseq.Byteseq](w *Writer, v S) (fail bool) {
var (
i = 0
length = len(v)
c byte
)
for i, c = range []byte(v) {
for ; i < length && !fail; i++ {
c := v[i]
if safeSet[c] != 0 {
goto slow
break
}
}
if i == length-1 {
return writeStreamByteseq(w, v) || w.byte('"')
fail = fail || writeStreamByteseq(w, v[:i])
if i == length {
return fail || w.byte('"')
}
slow:
return writeStreamByteseq(w, v[:i]) || strSlow[S](w, v[i:])
return fail || strSlow[S](w, v[i:])
}

func strSlow[S byteseq.Byteseq](w *Writer, v S) (fail bool) {
Expand Down
3 changes: 2 additions & 1 deletion w_str_escape.go
Expand Up @@ -124,12 +124,13 @@ func (w *Writer) ByteStrEscape(v []byte) bool {

func strEscape[S byteseq.Byteseq](w *Writer, v S) (fail bool) {
fail = w.byte('"')

// Fast path, probably does not require escaping.
var (
i = 0
length = len(v)
)
for ; i < length; i++ {
for ; i < length && !fail; i++ {
c := v[i]
if c >= utf8.RuneSelf || !(htmlSafeSet[c]) {
break
Expand Down
22 changes: 17 additions & 5 deletions w_stream.go
Expand Up @@ -44,18 +44,26 @@ func (s *streamState) Reset(w io.Writer) {
s.writeErr = nil
}

func (s *streamState) setError(err error) {
s.writeErr = err
}

func (s *streamState) fail() bool {
return s.writeErr != nil
}

func (s *streamState) flush(buf []byte) ([]byte, bool) {
if s.writeErr != nil {
if s.fail() {
return nil, true
}

var n int
n, s.writeErr = s.writer.Write(buf)
n, err := s.writer.Write(buf)
switch {
case s.writeErr != nil:
case err != nil:
s.setError(err)
return nil, true
case n != len(buf):
s.writeErr = io.ErrShortWrite
s.setError(io.ErrShortWrite)
return nil, true
default:
buf = buf[:0]
Expand All @@ -73,6 +81,10 @@ func writeStreamByteseq[S byteseq.Byteseq](w *Writer, s S) bool {
return false
}

if w.stream.fail() {
return true
}

for len(w.Buf)+len(s) > cap(w.Buf) {
var fail bool
w.Buf, fail = w.stream.flush(w.Buf)
Expand Down

0 comments on commit 9a31093

Please sign in to comment.