Fix+reduce allocations (#668)
Adds `WithDecodeBuffersBelow` to tweak the size below which buffer-backed readers are decoded synchronously instead of streamed.

Fixes #666 and generally reduces Reader allocations.
klauspost committed Sep 25, 2022
1 parent ef0aeb7 commit 3690e90
Showing 8 changed files with 119 additions and 49 deletions.
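
For illustration, a minimal usage sketch of the new option. The payload and the 1 MiB threshold are made up for the example; `NewReader`, `NewWriter`, `Reset` and `WithDecodeBuffersBelow` are the package APIs as changed in this commit. Buffer-backed readers smaller than the threshold are fully decoded in one synchronous `DecodeAll` pass when `Reset` is called.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"

	"github.com/klauspost/compress/zstd"
)

func main() {
	// Compress a small payload so the example is self-contained.
	var buf bytes.Buffer
	enc, err := zstd.NewWriter(&buf)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := enc.Write([]byte("hello zstd")); err != nil {
		log.Fatal(err)
	}
	if err := enc.Close(); err != nil {
		log.Fatal(err)
	}

	// Inputs below 1 MiB (illustrative value; the default is 128 KiB) are
	// decoded synchronously in one DecodeAll call instead of going through
	// the streaming goroutines.
	dec, err := zstd.NewReader(nil, zstd.WithDecodeBuffersBelow(1<<20))
	if err != nil {
		log.Fatal(err)
	}
	defer dec.Close()

	if err := dec.Reset(&buf); err != nil {
		log.Fatal(err)
	}
	out, err := io.ReadAll(dec)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", out)
}
```

Larger inputs simply fall back to the concurrent streaming path; nothing changes in the calling code.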
21 changes: 15 additions & 6 deletions zstd/decoder.go
@@ -35,6 +35,7 @@ type Decoder struct {
br readerWrapper
enabled bool
inFrame bool
dstBuf []byte
}

frame *frameDec
@@ -187,21 +188,23 @@ func (d *Decoder) Reset(r io.Reader) error {
}

// If in-memory buffer and below the sync-decode limit, do sync decoding anyway.
if bb, ok := r.(byter); ok && bb.Len() < 5<<20 {
if bb, ok := r.(byter); ok && bb.Len() < d.o.decodeBufsBelow && !d.o.limitToCap {
bb2 := bb
if debugDecoder {
println("*bytes.Buffer detected, doing sync decode, len:", bb.Len())
}
b := bb2.Bytes()
var dst []byte
if cap(d.current.b) > 0 {
dst = d.current.b
if cap(d.syncStream.dstBuf) > 0 {
dst = d.syncStream.dstBuf[:0]
}

dst, err := d.DecodeAll(b, dst[:0])
dst, err := d.DecodeAll(b, dst)
if err == nil {
err = io.EOF
}
// Save output buffer
d.syncStream.dstBuf = dst
d.current.b = dst
d.current.err = err
d.current.flushed = true
@@ -216,6 +219,7 @@ func (d *Decoder) Reset(r io.Reader) error {
d.current.err = nil
d.current.flushed = false
d.current.d = nil
d.syncStream.dstBuf = nil

// Ensure no-one else is still running...
d.streamWg.Wait()
@@ -680,6 +684,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
if debugDecoder {
println("Async 1: new history, recent:", block.async.newHist.recentOffsets)
}
hist.reset()
hist.decoders = block.async.newHist.decoders
hist.recentOffsets = block.async.newHist.recentOffsets
hist.windowSize = block.async.newHist.windowSize
@@ -711,6 +716,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
seqExecute <- block
}
close(seqExecute)
hist.reset()
}()

var wg sync.WaitGroup
@@ -734,6 +740,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
if debugDecoder {
println("Async 2: new history")
}
hist.reset()
hist.windowSize = block.async.newHist.windowSize
hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer
if block.async.newHist.dict != nil {
@@ -815,13 +822,14 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
if debugDecoder {
println("decoder goroutines finished")
}
hist.reset()
}()

var hist history
decodeStream:
for {
var hist history
var hasErr bool

hist.reset()
decodeBlock := func(block *blockDec) {
if hasErr {
if block != nil {
@@ -937,5 +945,6 @@
}
close(seqDecode)
wg.Wait()
hist.reset()
d.frame.history.b = frameHistCache
}
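
The `Reset` change above stashes the sync output in `syncStream.dstBuf` and hands it back to `DecodeAll` on the next `Reset`, so repeated resets on small buffers stop allocating a fresh output slice each time. Callers of `DecodeAll` can apply the same reuse pattern themselves; a minimal sketch, assuming `frames` holds independent zstd frames and `handle` is the caller's consumer (both hypothetical):

```go
package zstdexample

import "github.com/klauspost/compress/zstd"

// decodeMany decodes each frame while reusing one output buffer, the same
// allocation-reduction idea the sync Reset path applies via syncStream.dstBuf.
func decodeMany(dec *zstd.Decoder, frames [][]byte, handle func([]byte)) error {
	var dst []byte
	for _, f := range frames {
		var err error
		// DecodeAll appends to dst[:0]; once dst has grown large enough,
		// later frames decode without a fresh allocation.
		dst, err = dec.DecodeAll(f, dst[:0])
		if err != nil {
			return err
		}
		handle(dst) // the buffer is reused, so handle must not retain dst
	}
	return nil
}
```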
34 changes: 24 additions & 10 deletions zstd/decoder_options.go
@@ -14,21 +14,23 @@ type DOption func(*decoderOptions) error

// options retains accumulated state of multiple options.
type decoderOptions struct {
lowMem bool
concurrent int
maxDecodedSize uint64
maxWindowSize uint64
dicts []dict
ignoreChecksum bool
limitToCap bool
lowMem bool
concurrent int
maxDecodedSize uint64
maxWindowSize uint64
dicts []dict
ignoreChecksum bool
limitToCap bool
decodeBufsBelow int
}

func (o *decoderOptions) setDefault() {
*o = decoderOptions{
// use less ram: true for now, but may change.
lowMem: true,
concurrent: runtime.GOMAXPROCS(0),
maxWindowSize: MaxWindowSize,
lowMem: true,
concurrent: runtime.GOMAXPROCS(0),
maxWindowSize: MaxWindowSize,
decodeBufsBelow: 128 << 10,
}
if o.concurrent > 4 {
o.concurrent = 4
@@ -126,6 +128,18 @@ func WithDecodeAllCapLimit(b bool) DOption {
}
}

// WithDecodeBuffersBelow will fully decode readers that have a
// `Bytes() []byte` and `Len() int` interface similar to bytes.Buffer,
// when their remaining length is below the given size.
// This typically reduces allocations, but keeps the entire decompressed object in memory.
// Note that WithDecodeAllCapLimit disables this, as does a size of 0 or less.
// Default is 128KiB.
func WithDecodeBuffersBelow(size int) DOption {
return func(o *decoderOptions) error {
o.decodeBufsBelow = size
return nil
}
}

// IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption {
return func(o *decoderOptions) error {
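
Because the sync path only needs the `Bytes() []byte` and `Len() int` methods that the doc comment above names, any in-memory reader exposing them qualifies, not just `*bytes.Buffer`. A hypothetical sketch of such a type; the `memReader` name and its fields are invented for illustration:

```go
package zstdexample

import "io"

// memReader behaves like a read-only bytes.Buffer: Bytes and Len report the
// unread remainder, so small payloads passed to Decoder.Reset can take the
// fully-decoded sync path described above.
type memReader struct {
	b   []byte
	off int
}

func (m *memReader) Read(p []byte) (int, error) {
	if m.off >= len(m.b) {
		return 0, io.EOF
	}
	n := copy(p, m.b[m.off:])
	m.off += n
	return n, nil
}

func (m *memReader) Bytes() []byte { return m.b[m.off:] }
func (m *memReader) Len() int      { return len(m.b) - m.off }
```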
55 changes: 43 additions & 12 deletions zstd/decoder_test.go
@@ -1157,12 +1157,18 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error)

func BenchmarkDecoder_DecoderSmall(b *testing.B) {
zr := testCreateZipReader("testdata/benchdecoder.zip", b)
dec, err := NewReader(nil)
dec, err := NewReader(nil, WithDecodeBuffersBelow(1<<30))
if err != nil {
b.Fatal(err)
return
}
defer dec.Close()
dec2, err := NewReader(nil, WithDecodeBuffersBelow(0))
if err != nil {
b.Fatal(err)
return
}
defer dec2.Close()
for _, tt := range zr.File {
if !strings.HasSuffix(tt.Name, ".zst") {
continue
@@ -1183,6 +1189,7 @@ func BenchmarkDecoder_DecoderSmall(b *testing.B) {
in = append(in, in...)
// 8x
in = append(in, in...)

err = dec.Reset(bytes.NewBuffer(in))
if err != nil {
b.Fatal(err)
@@ -1191,19 +1198,43 @@
if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(got)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = dec.Reset(bytes.NewBuffer(in))
if err != nil {
b.Fatal(err)
b.Run("buffered", func(b *testing.B) {
b.SetBytes(int64(len(got)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = dec.Reset(bytes.NewBuffer(in))
if err != nil {
b.Fatal(err)
}
n, err := io.Copy(io.Discard, dec)
if err != nil {
b.Fatal(err)
}
if int(n) != len(got) {
b.Fatalf("want %d, got %d", len(got), n)
}

}
_, err := io.Copy(io.Discard, dec)
if err != nil {
b.Fatal(err)
})
b.Run("unbuffered", func(b *testing.B) {
b.SetBytes(int64(len(got)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = dec2.Reset(bytes.NewBuffer(in))
if err != nil {
b.Fatal(err)
}
n, err := io.Copy(io.Discard, dec2)
if err != nil {
b.Fatal(err)
}
if int(n) != len(got) {
b.Fatalf("want %d, got %d", len(got), n)
}
}
}
})
})
}
}
4 changes: 2 additions & 2 deletions zstd/framedec.go
@@ -343,7 +343,7 @@ func (d *frameDec) consumeCRC() error {
return nil
}

// runDecoder will create a sync decoder that will decode a block of data.
// runDecoder will run the decoder for the remainder of the frame.
func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
saved := d.history.b

@@ -369,7 +369,7 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen)
}
if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen {
if !d.o.limitToCap && uint64(cap(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst)
21 changes: 9 additions & 12 deletions zstd/history.go
@@ -37,24 +37,21 @@ func (h *history) reset() {
h.ignoreBuffer = 0
h.error = false
h.recentOffsets = [3]int{1, 4, 8}
if f := h.decoders.litLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
}
if f := h.decoders.offsets.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
}
if f := h.decoders.matchLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
}
h.decoders.freeDecoders()
h.decoders = sequenceDecs{br: h.decoders.br}
h.freeHuffDecoder()
h.huffTree = nil
h.dict = nil
//printf("history created: %+v (l: %d, c: %d)", *h, len(h.b), cap(h.b))
}

func (h *history) freeHuffDecoder() {
if h.huffTree != nil {
if h.dict == nil || h.dict.litEnc != h.huffTree {
huffDecoderPool.Put(h.huffTree)
h.huffTree = nil
}
}
h.huffTree = nil
h.dict = nil
//printf("history created: %+v (l: %d, c: %d)", *h, len(h.b), cap(h.b))
}

func (h *history) setDict(dict *dict) {
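
The history changes above route decoder teardown through `freeDecoders` and `freeHuffDecoder`, which return table state to `fseDecoderPool` and `huffDecoderPool` and nil the fields so a pooled object cannot be reached through a stale reference. A minimal, generic sketch of that pooling pattern; the `tableState` type, pool name and sizes are invented, while the package's real pools hold its own FSE and Huffman decoder types:

```go
package zstdexample

import "sync"

// tableState stands in for a decoder table that is expensive to allocate.
type tableState struct {
	table []uint16
}

// statePool hands out reusable tableState values instead of allocating a
// new one per block.
var statePool = sync.Pool{
	New: func() any { return &tableState{table: make([]uint16, 1<<11)} },
}

// free returns the state to the pool and clears the caller's reference,
// mirroring how freeDecoders nils the fse fields after Put.
func free(s **tableState) {
	if *s == nil {
		return
	}
	statePool.Put(*s)
	*s = nil
}
```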
22 changes: 20 additions & 2 deletions zstd/seqdec.go
@@ -99,6 +99,21 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) erro
return nil
}

func (s *sequenceDecs) freeDecoders() {
if f := s.litLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
s.litLengths.fse = nil
}
if f := s.offsets.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
s.offsets.fse = nil
}
if f := s.matchLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
s.matchLengths.fse = nil
}
}

// execute will execute the decoded sequence with the provided history.
// The sequence must be evaluated before being sent.
func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
@@ -299,7 +314,10 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
}
size := ll + ml + len(out)
if size-startSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", size-startSize, maxBlockSize)
if size-startSize == 424242 {
panic("here")
}
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}
if size > cap(out) {
// Not enough size, which can happen under high volume block streaming conditions
@@ -411,7 +429,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {

// Check if space for literals
if size := len(s.literals) + len(s.out) - startSize; size > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", size, maxBlockSize)
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}

// Add final literals
7 changes: 4 additions & 3 deletions zstd/seqdec_amd64.go
@@ -139,15 +139,16 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
if debugDecoder {
println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize)
}
return true, fmt.Errorf("output (%d) bigger than max block size (%d)", size-startSize, maxBlockSize)
return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)

default:
return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode)
}

s.seqSize += ctx.litRemain
if s.seqSize > maxBlockSize {
return true, fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)

}
err := br.close()
if err != nil {
@@ -289,7 +290,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {

s.seqSize += ctx.litRemain
if s.seqSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}
err := br.close()
if err != nil {
4 changes: 2 additions & 2 deletions zstd/seqdec_generic.go
@@ -111,7 +111,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
}
s.seqSize += ll + ml
if s.seqSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}
litRemain -= ll
if litRemain < 0 {
@@ -149,7 +149,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
}
s.seqSize += litRemain
if s.seqSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}
err := br.close()
if err != nil {
