From 1a14c30993f62bb02102029f81a91fa442025424 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 18 Jan 2022 21:21:58 +0100 Subject: [PATCH 01/31] Refactor zstd decoder --- zstd/blockdec.go | 298 ++++++++++++++----------------------------- zstd/decoder.go | 233 ++++++++++++++++++++------------- zstd/decoder_test.go | 2 +- zstd/framedec.go | 50 +------- zstd/history.go | 23 +++- zstd/seqdec.go | 234 ++++++++++++++++++++++++++++++++- 6 files changed, 482 insertions(+), 358 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index dc587b2c94..a40284e48e 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -76,16 +76,19 @@ type blockDec struct { // Window size of the block. WindowSize uint64 - history chan *history - input chan struct{} - result chan decodeOutput - err error - decWG sync.WaitGroup + err error // Frame to use for singlethreaded decoding. // Should not be used by the decoder itself since parent may be another frame. localFrame *frameDec + async struct { + newHist *history + literals []byte + seqData []byte + sequence []seqVals + } + // Block is RLE, this is the size. RLESize uint32 tmp [4]byte @@ -108,13 +111,8 @@ func (b *blockDec) String() string { func newBlockDec(lowMem bool) *blockDec { b := blockDec{ - lowMem: lowMem, - result: make(chan decodeOutput, 1), - input: make(chan struct{}, 1), - history: make(chan *history, 1), + lowMem: lowMem, } - b.decWG.Add(1) - go b.startDecoder() return &b } @@ -192,85 +190,14 @@ func (b *blockDec) sendErr(err error) { b.Last = true b.Type = blockTypeReserved b.err = err - b.input <- struct{}{} } // Close will release resources. // Closed blockDec cannot be reset. func (b *blockDec) Close() { - close(b.input) - close(b.history) - close(b.result) - b.decWG.Wait() -} - -// decodeAsync will prepare decoding the block when it receives input. -// This will separate output and history. -func (b *blockDec) startDecoder() { - defer b.decWG.Done() - for range b.input { - //println("blockDec: Got block input") - switch b.Type { - case blockTypeRLE: - if cap(b.dst) < int(b.RLESize) { - if b.lowMem { - b.dst = make([]byte, b.RLESize) - } else { - b.dst = make([]byte, maxBlockSize) - } - } - o := decodeOutput{ - d: b, - b: b.dst[:b.RLESize], - err: nil, - } - v := b.data[0] - for i := range o.b { - o.b[i] = v - } - hist := <-b.history - hist.append(o.b) - b.result <- o - case blockTypeRaw: - o := decodeOutput{ - d: b, - b: b.data, - err: nil, - } - hist := <-b.history - hist.append(o.b) - b.result <- o - case blockTypeCompressed: - b.dst = b.dst[:0] - err := b.decodeCompressed(nil) - o := decodeOutput{ - d: b, - b: b.dst, - err: err, - } - if debugDecoder { - println("Decompressed to", len(b.dst), "bytes, error:", err) - } - b.result <- o - case blockTypeReserved: - // Used for returning errors. - <-b.history - b.result <- decodeOutput{ - d: b, - b: nil, - err: b.err, - } - default: - panic("Invalid block type") - } - if debugDecoder { - println("blockDec: Finished block") - } - } } -// decodeAsync will prepare decoding the block when it receives the history. -// If history is provided, it will not fetch it from the channel. +// decodeBuf func (b *blockDec) decodeBuf(hist *history) error { switch b.Type { case blockTypeRLE: @@ -310,25 +237,12 @@ func (b *blockDec) decodeBuf(hist *history) error { } } -// decodeCompressed will start decompressing a block. 
-// If no history is supplied the decoder will decodeAsync as much as possible -// before fetching from blockDec.history -func (b *blockDec) decodeCompressed(hist *history) error { - in := b.data - delayedHistory := hist == nil - - if delayedHistory { - // We must always grab history. - defer func() { - if hist == nil { - <-b.history - } - }() - } +func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain []byte, err error) { // There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header if len(in) < 2 { - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } + litType := literalsBlockType(in[0] & 3) var litRegenSize int var litCompSize int @@ -349,7 +263,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { // Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes. if len(in) < 3 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12) in = in[3:] @@ -360,7 +274,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { // Both Regenerated_Size and Compressed_Size use 10 bits (0-1023). if len(in) < 3 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) litRegenSize = int(n & 1023) @@ -371,7 +285,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { fourStreams = true if len(in) < 4 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) litRegenSize = int(n & 16383) @@ -381,7 +295,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { fourStreams = true if len(in) < 5 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28) litRegenSize = int(n & 262143) @@ -392,13 +306,12 @@ func (b *blockDec) decodeCompressed(hist *history) error { if debugDecoder { println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams) } - var literals []byte - var huff *huff0.Scratch + switch litType { case literalsBlockRaw: if len(in) < litRegenSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } literals = in[:litRegenSize] in = in[litRegenSize:] @@ -406,7 +319,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { case literalsBlockRLE: if len(in) < 1 { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } if cap(b.literalBuf) < litRegenSize { if b.lowMem { @@ -417,7 +330,6 @@ func (b *blockDec) decodeCompressed(hist *history) error { b.literalBuf = make([]byte, litRegenSize) } else { b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize) - } } } @@ -433,7 +345,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { case literalsBlockTreeless: if len(in) < litCompSize { 
println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } // Store compressed literals, so we defer decoding until we get history. literals = in[:litCompSize] @@ -441,15 +353,41 @@ func (b *blockDec) decodeCompressed(hist *history) error { if debugDecoder { printf("Found %d compressed literals\n", litCompSize) } + huff := hist.huffTree + if huff == nil { + return nil, in, errors.New("literal block was treeless, but no history was defined") + } + // Ensure we have space to store it. + if cap(b.literalBuf) < litRegenSize { + if b.lowMem { + b.literalBuf = make([]byte, 0, litRegenSize) + } else { + b.literalBuf = make([]byte, 0, maxCompressedLiteralSize) + } + } + var err error + // Use our out buffer. + if fourStreams { + literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) + } else { + literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals) + } + // Make sure we don't leak our literals buffer + if err != nil { + println("decompressing literals:", err) + return nil, in, err + } + if len(literals) != litRegenSize { + return nil, in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) + } + case literalsBlockCompressed: if len(in) < litCompSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize) - return ErrBlockTooSmall + return nil, in, ErrBlockTooSmall } literals = in[:litCompSize] in = in[litCompSize:] - huff = huffDecoderPool.Get().(*huff0.Scratch) - var err error // Ensure we have space to store it. if cap(b.literalBuf) < litRegenSize { if b.lowMem { @@ -458,14 +396,20 @@ func (b *blockDec) decodeCompressed(hist *history) error { b.literalBuf = make([]byte, 0, maxCompressedLiteralSize) } } + huff := hist.huffTree if huff == nil { - huff = &huff0.Scratch{} + huff = huffDecoderPool.Get().(*huff0.Scratch) + if huff == nil { + huff = &huff0.Scratch{} + } } + var err error huff, literals, err = huff0.ReadTable(literals, huff) if err != nil { println("reading huffman table:", err) - return err + return nil, in, err } + hist.huffTree = huff // Use our out buffer. if fourStreams { literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) @@ -474,17 +418,29 @@ func (b *blockDec) decodeCompressed(hist *history) error { } if err != nil { println("decoding compressed literals:", err) - return err + return nil, in, err } // Make sure we don't leak our literals buffer if len(literals) != litRegenSize { - return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) + return nil, in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) } if debugDecoder { printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize) } } + return literals, in, nil +} +// decodeCompressed will start decompressing a block. 
+func (b *blockDec) decodeCompressed(hist *history) error { + in := b.data + literals, in, err := b.decodeLiterals(in, hist) + if err != nil { + return err + } +} + +func (b *blockDec) decodeSequences(in, literals []byte, hist *history) error { // Decode Sequences // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section if len(in) < 1 { @@ -512,7 +468,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { in = in[3:] } - var seqs = &sequenceDecs{} + var seqs = &hist.decoders if nSeqs > 0 { if len(in) < 1 { return ErrBlockTooSmall @@ -586,128 +542,61 @@ func (b *blockDec) decodeCompressed(hist *history) error { in = br.unread() } - // Wait for history. - // All time spent after this is critical since it is strictly sequential. - if hist == nil { - hist = <-b.history - if hist.error { - return ErrDecoderClosed - } - } - - // Decode treeless literal block. - if litType == literalsBlockTreeless { - // TODO: We could send the history early WITHOUT the stream history. - // This would allow decoding treeless literals before the byte history is available. - // Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless. - // So not much obvious gain here. - - if hist.huffTree == nil { - return errors.New("literal block was treeless, but no history was defined") - } - // Ensure we have space to store it. - if cap(b.literalBuf) < litRegenSize { - if b.lowMem { - b.literalBuf = make([]byte, 0, litRegenSize) - } else { - b.literalBuf = make([]byte, 0, maxCompressedLiteralSize) - } - } - var err error - // Use our out buffer. - huff = hist.huffTree - if fourStreams { - literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) - } else { - literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals) - } - // Make sure we don't leak our literals buffer - if err != nil { - println("decompressing literals:", err) - return err - } - if len(literals) != litRegenSize { - return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) - } - } else { - if hist.huffTree != nil && huff != nil { - if hist.dict == nil || hist.dict.litEnc != hist.huffTree { - huffDecoderPool.Put(hist.huffTree) - } - hist.huffTree = nil - } - } - if huff != nil { - hist.huffTree = huff - } if debugDecoder { println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.") } if nSeqs == 0 { // Decompressed content is defined entirely as Literals Section content. - b.dst = append(b.dst, literals...) - if delayedHistory { - hist.append(literals) - } return nil } - seqs, err := seqs.mergeHistory(&hist.decoders) - if err != nil { - return err - } - if debugDecoder { - println("History merged ok") - } br := &bitReader{} if err := br.init(in); err != nil { return err } - // TODO: Investigate if sending history without decoders are faster. - // This would allow the sequences to be decoded async and only have to construct stream history. - // If only recent offsets were not transferred, this would be an obvious win. - // Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded. - - hbytes := hist.b - if len(hbytes) > hist.windowSize { - hbytes = hbytes[len(hbytes)-hist.windowSize:] - // We do not need history any more. 
- if hist.dict != nil { - hist.dict.content = nil - } - } - if err := seqs.initialize(br, hist, literals, b.dst); err != nil { println("initializing sequences:", err) return err } - err = seqs.decode(nSeqs, br, hbytes) + err := seqs.decode(nSeqs, br) if err != nil { return err } if !br.finished() { return fmt.Errorf("%d extra bits on block, should be 0", br.remain()) } - err = br.close() if err != nil { printf("Closing sequences: %v, %+v\n", err, *br) } + return err +} + +func (b *blockDec) executeSequences(hist *history) error { + hbytes := hist.b + if len(hbytes) > hist.windowSize { + hbytes = hbytes[len(hbytes)-hist.windowSize:] + // We do not need history anymore. + if hist.dict != nil { + hist.dict.content = nil + } + } + err = seqs.execute(hist.b) + if err != nil { + return err + } + if len(b.data) > maxCompressedBlockSize { return fmt.Errorf("compressed block size too large (%d)", len(b.data)) } // Set output and release references. b.dst = seqs.out - seqs.out, seqs.literals, seqs.hist = nil, nil, nil + seqs.out, seqs.literals = nil, nil - if !delayedHistory { - // If we don't have delayed history, no need to update. - hist.recentOffsets = seqs.prevOffset - return nil - } + hist.recentOffsets = seqs.prevOffset if b.Last { // if last block we don't care about history. println("Last block, no history returned") @@ -715,7 +604,6 @@ func (b *blockDec) decodeCompressed(hist *history) error { return nil } hist.append(b.dst) - hist.recentOffsets = seqs.prevOffset if debugDecoder { println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.") } diff --git a/zstd/decoder.go b/zstd/decoder.go index f430f58b57..044bb99db4 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -5,6 +5,7 @@ package zstd import ( + "context" "errors" "io" "sync" @@ -22,9 +23,6 @@ type Decoder struct { // Unreferenced decoders, ready for use. decoders chan *blockDec - // Streams ready to be decoded. - stream chan decodeStream - // Current read position used for Reader functionality. current decoderState @@ -46,7 +44,7 @@ type decoderState struct { output chan decodeOutput // cancel remaining output. - cancel chan struct{} + cancel context.CancelFunc flushed bool } @@ -196,24 +194,17 @@ func (d *Decoder) Reset(r io.Reader) error { return nil } - if d.stream == nil { - d.stream = make(chan decodeStream, 1) - d.streamWg.Add(1) - go d.startStreamDecoder(d.stream) - } + ctx, cancel := context.WithCancel(context.Background()) + d.streamWg.Add(1) + go d.startStreamDecoder(ctx, r, d.current.output) // Remove current block. d.current.decodeOutput = decodeOutput{} d.current.err = nil - d.current.cancel = make(chan struct{}) + d.current.cancel = cancel d.current.flushed = false d.current.d = nil - d.stream <- decodeStream{ - r: r, - output: d.current.output, - cancel: d.current.cancel, - } return nil } @@ -221,7 +212,7 @@ func (d *Decoder) Reset(r io.Reader) error { func (d *Decoder) drainOutput() { if d.current.cancel != nil { println("cancelling current") - close(d.current.cancel) + d.current.cancel() d.current.cancel = nil } if d.current.d != nil { @@ -327,7 +318,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { return dst, ErrDecoderSizeExceeded } if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 { - // Never preallocate moe than 1 GB up front. + // Never preallocate more than 1 GB up front. 
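		// (FrameContentSize comes straight from the frame header and has not yet
		// been verified against the actual payload, so it is treated as a hint
		// here and the speculative allocation is capped rather than trusted.)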
if cap(dst)-len(dst) < int(frame.FrameContentSize) { dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)) copy(dst2, dst) @@ -402,10 +393,10 @@ func (d *Decoder) Close() { return } d.drainOutput() - if d.stream != nil { - close(d.stream) + if d.current.cancel != nil { + d.current.cancel() d.streamWg.Wait() - d.stream = nil + d.current.cancel = nil } if d.decoders != nil { close(d.decoders) @@ -470,86 +461,144 @@ type decodeStream struct { var errEndOfStream = errors.New("end-of-stream") // Create Decoder: -// Spawn n block decoders. These accept tasks to decode a block. -// Create goroutine that handles stream processing, this will send history to decoders as they are available. -// Decoders update the history as they decode. -// When a block is returned: -// a) history is sent to the next decoder, -// b) content written to CRC. -// c) return data to WRITER. -// d) wait for next block to return data. -// Once WRITTEN, the decoders reused by the writer frame decoder for re-use. -func (d *Decoder) startStreamDecoder(inStream chan decodeStream) { +// ASYNC: +// Spawn 3 go routines. +// 1: Decode block and literals. Receives hufftree and seqdecs, returns seqdecs and huff tree. +// 2: Wait for recentOffsets if needed. Decode sequences, send recentOffsets. +// 3: Wait for stream history, execute sequences, send stream history. +func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) { defer d.streamWg.Done() frame := newFrameDec(d.o) - for stream := range inStream { - if debugDecoder { - println("got new stream") + + br := readerWrapper{r: r} + + // TODO: Needed? + frame.initAsync() + + var seqPrepare = make(chan *blockDec, d.o.concurrent) + var seqDecode = make(chan *blockDec, d.o.concurrent) + var seqExecute = make(chan *blockDec, d.o.concurrent) + // Async 1: Prepare blocks... + go func() { + var hist history + for block := range seqPrepare { + if block.err != nil { + seqDecode <- block + continue + } + if block.async.newHist != nil { + hist.huffTree = block.async.newHist.huffTree + } + literals, remain, err := block.decodeLiterals(block.data, &hist) + block.err = err + if err == nil { + block.async.literals = literals + block.async.seqData = remain + } + seqDecode <- block } - br := readerWrapper{r: stream.r} - decodeStream: - for { - frame.history.reset() - err := frame.reset(&br) - if debugDecoder && err != nil { - println("Frame decoder returned", err) + close(seqDecode) + }() + // Async 2: Decode sequences... + go func() { + var hist history + + for block := range seqDecode { + if block.err != nil { + seqExecute <- block + continue } - if err == nil && frame.DictionaryID != nil { - dict, ok := d.dicts[*frame.DictionaryID] - if !ok { - err = ErrUnknownDictionary - } else { - frame.history.setDict(&dict) - } + if block.async.newHist != nil { + hist.decoders = block.async.newHist.decoders + hist.recentOffsets = block.async.newHist.recentOffsets } - if err != nil { - stream.output <- decodeOutput{ - err: err, - } - break + hist.decoders.literals = block.async.literals + block.err = block.decodeSequences(block.async.seqData, block.async.literals, &hist) + seqExecute <- block + } + close(seqExecute) + }() + // Async 3: Execute sequences... 
+ go func() { + var hist history + + for block := range seqExecute { + out := decodeOutput{err: block.err, d: block} + if block.err != nil { + output <- out + continue } - if debugDecoder { - println("starting frame decoder") + if block.async.newHist != nil { + hist.dict = block.async.newHist.dict + hist.b = append(hist.b[:0], block.async.newHist.b...) } - // This goroutine will forward history between frames. - frame.frameDone.Add(1) - frame.initAsync() - - go frame.startDecoder(stream.output) - decodeFrame: - // Go through all blocks of the frame. - for { - dec := <-d.decoders - select { - case <-stream.cancel: - if !frame.sendErr(dec, io.EOF) { - // To not let the decoder dangle, send it back. - stream.output <- decodeOutput{d: dec} - } - break decodeStream - default: - } - err := frame.next(dec) - switch err { - case io.EOF: - // End of current frame, no error - println("EOF on next block") - break decodeFrame - case nil: - continue - default: - println("block decoder returned", err) - break decodeStream - } + } + close(output) + }() + +decodeStream: + for { + var historySent bool + frame.history.reset() + err := frame.reset(&br) + switch err { + case io.EOF: + break decodeStream + default: + dec := <-d.decoders + dec.sendErr(err) + seqPrepare <- dec + break decodeStream + case nil: + } + if debugDecoder && err != nil { + println("Frame decoder returned", err) + } + if frame.DictionaryID != nil { + dict, ok := d.dicts[*frame.DictionaryID] + if !ok { + err = ErrUnknownDictionary + } else { + frame.history.setDict(&dict) } - // All blocks have started decoding, check if there are more frames. - println("waiting for done") - frame.frameDone.Wait() - println("done waiting...") - } - frame.frameDone.Wait() - println("Sending EOS") - stream.output <- decodeOutput{err: errEndOfStream} + } + if err != nil { + output <- decodeOutput{ + err: err, + } + return + } + decodeFrame: + // Go through all blocks of the frame. + for { + dec := <-d.decoders + select { + case <-ctx.Done(): + dec.sendErr(ctx.Err()) + seqPrepare <- dec + break decodeFrame + default: + } + err := frame.next(dec) + if !historySent { + h := frame.history + dec.async.newHist = &h + historySent = true + } + seqPrepare <- dec + switch err { + case io.EOF: + // End of current frame, no error + println("EOF on next block") + break decodeFrame + case nil: + continue + default: + println("block decoder returned", err) + break decodeStream + } + } } + close(seqPrepare) } diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 5af4c406ca..7faaf49219 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1362,7 +1362,7 @@ func BenchmarkDecoderSilesia(b *testing.B) { } func BenchmarkDecoderEnwik9(b *testing.B) { - fn := "testdata/enwik9-1.zst" + fn := "testdata/enwik8.zst" data, err := ioutil.ReadFile(fn) if err != nil { if os.IsNotExist(err) { diff --git a/zstd/framedec.go b/zstd/framedec.go index 989c79f8c3..133dea147a 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -10,7 +10,6 @@ import ( "errors" "hash" "io" - "sync" "github.com/klauspost/compress/zstd/internal/xxhash" ) @@ -22,9 +21,6 @@ type frameDec struct { WindowSize uint64 - // In order queue of blocks being decoded. - decoding chan *blockDec - // Frame history passed between blocks history history @@ -34,15 +30,10 @@ type frameDec struct { bBuf byteBuf FrameContentSize uint64 - frameDone sync.WaitGroup DictionaryID *uint32 HasCheckSum bool SingleSegment bool - - // asyncRunning indicates whether the async routine processes input on 'decoding'. 
- asyncRunningMu sync.Mutex - asyncRunning bool } const ( @@ -282,43 +273,12 @@ func (d *frameDec) next(block *blockDec) error { if err != nil { println("block error:", err) // Signal the frame decoder we have a problem. - d.sendErr(block, err) + block.sendErr(err) return err } - block.input <- struct{}{} - if debugDecoder { - println("next block:", block) - } - d.asyncRunningMu.Lock() - defer d.asyncRunningMu.Unlock() - if !d.asyncRunning { - return nil - } - if block.Last { - // We indicate the frame is done by sending io.EOF - d.decoding <- block - return io.EOF - } - d.decoding <- block return nil } -// sendEOF will queue an error block on the frame. -// This will cause the frame decoder to return when it encounters the block. -// Returns true if the decoder was added. -func (d *frameDec) sendErr(block *blockDec, err error) bool { - d.asyncRunningMu.Lock() - defer d.asyncRunningMu.Unlock() - if !d.asyncRunning { - return false - } - - println("sending error", err.Error()) - block.sendErr(err) - d.decoding <- block - return true -} - // checkCRC will check the checksum if the frame has one. // Will return ErrCRCMismatch if crc check failed, otherwise nil. func (d *frameDec) checkCRC() error { @@ -364,18 +324,13 @@ func (d *frameDec) initAsync() { if cap(d.history.b) < d.history.maxSize { d.history.b = make([]byte, 0, d.history.maxSize) } - if cap(d.decoding) < d.o.concurrent { - d.decoding = make(chan *blockDec, d.o.concurrent) - } if debugDecoder { h := d.history printf("history init. len: %d, cap: %d", len(h.b), cap(h.b)) } - d.asyncRunningMu.Lock() - d.asyncRunning = true - d.asyncRunningMu.Unlock() } +/* // startDecoder will start decoding blocks and write them to the writer. // The decoder will stop as soon as an error occurs or at end of frame. // When the frame has finished decoding the *bufio.Reader @@ -470,6 +425,7 @@ func (d *frameDec) startDecoder(output chan decodeOutput) { block = next } } +*/ // runDecoder will create a sync decoder that will decode a block of data. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { diff --git a/zstd/history.go b/zstd/history.go index f783e32d25..44f79fa9e2 100644 --- a/zstd/history.go +++ b/zstd/history.go @@ -10,14 +10,23 @@ import ( // history contains the information transferred between blocks. type history struct { - b []byte - huffTree *huff0.Scratch + // Needed first, if needed. + huffTree *huff0.Scratch + + // Needed second, if needed... + decoders sequenceDecs + + // Maybe needed... recentOffsets [3]int - decoders sequenceDecs - windowSize int - maxSize int - error bool - dict *dict + + // Needed last... + b []byte + bCh chan []byte + + windowSize int + maxSize int + error bool + dict *dict } // reset will reset the history to initial state of a frame. 
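The history struct above is grouped by pipeline stage on purpose: stage 1 only needs huffTree, stage 2 needs the fse decoders and recentOffsets, and stage 3 needs the byte history b. The decoder.go changes in this patch wire those stages together as one goroutine per stage, connected by buffered channels and fed by the frame-reading loop, so blocks overlap in the stages while staying in order. A minimal, self-contained sketch of that shape (placeholder block type and stub stage bodies, not the library's actual code):

// pipeline_sketch.go — a stripped-down model of the three decode stages wired
// together with channels; the block type and the stage bodies are placeholders,
// not the real decoder code.
package main

import "fmt"

type block struct {
	id  int
	err error
}

// stage runs one pipeline step: it owns whatever per-stage state it needs,
// processes blocks in arrival order and forwards them downstream.
func stage(in <-chan *block, out chan<- *block, work func(*block)) {
	defer close(out)
	for b := range in {
		if b.err == nil {
			work(b)
		}
		out <- b
	}
}

func main() {
	prepare := make(chan *block, 4) // stage 1: literals / huffman table
	decode := make(chan *block, 4)  // stage 2: sequences / fse decoders
	execute := make(chan *block, 4) // stage 3: history / output bytes
	output := make(chan *block, 4)

	go stage(prepare, decode, func(b *block) { /* decode literals */ })
	go stage(decode, execute, func(b *block) { /* prepare + decode sequences */ })
	go stage(execute, output, func(b *block) { /* execute against history */ })

	// The frame reader feeds blocks in order; one goroutine per stage and
	// FIFO channels keep them in order on the way out.
	go func() {
		defer close(prepare)
		for i := 0; i < 3; i++ {
			prepare <- &block{id: i}
		}
	}()

	for b := range output {
		fmt.Println("block", b.id, "done, err:", b.err)
	}
}

Errors ride along on the block itself (as blockDec.err does in the patch), so a failing block still flows through to the output stage and ordering is preserved; cancellation in the real code comes from the context passed to startStreamDecoder rather than from closing the input channel.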
diff --git a/zstd/seqdec.go b/zstd/seqdec.go index bc731e4cb6..28dc8d5ecf 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -20,6 +20,10 @@ type seq struct { llCode, mlCode, ofCode uint8 } +type seqVals struct { + ll, ml, mo int +} + func (s seq) String() string { if s.offset <= 3 { if s.offset == 0 { @@ -61,10 +65,11 @@ type sequenceDecs struct { offsets sequenceDec matchLengths sequenceDec prevOffset [3]int - hist []byte dict []byte literals []byte out []byte + seq []seqVals + seqSize int windowSize int maxBits uint8 } @@ -81,7 +86,6 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out [] return errors.New("matchLengths:" + err.Error()) } s.literals = literals - s.hist = hist.b s.prevOffset = hist.recentOffsets s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits s.windowSize = hist.windowSize @@ -94,7 +98,225 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out [] } // decode sequences from the stream with the provided history. -func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { +func (s *sequenceDecs) decode(seqs int, br *bitReader) error { + // Grab full sizes tables, to avoid bounds checks. + llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] + llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state + if cap(s.seq) < seqs { + s.seq = make([]seqVals, 0, seqs) + } + s.seq = s.seq[:seqs] + s.seqSize = 0 + litRemain := len(s.literals) + + for i := range s.seq { + var ll, mo, ml int + if br.off > 4+((maxOffsetBits+16+16)>>3) { + // inlined function: + // ll, mo, ml = s.nextFast(br, llState, mlState, ofState) + + // Final will not read from stream. + var llB, mlB, moB uint8 + ll, llB = llState.final() + ml, mlB = mlState.final() + mo, moB = ofState.final() + + // extra bits are stored in reverse order. + br.fillFast() + mo += br.getBits(moB) + if s.maxBits > 32 { + br.fillFast() + } + ml += br.getBits(mlB) + ll += br.getBits(llB) + + if moB > 1 { + s.prevOffset[2] = s.prevOffset[1] + s.prevOffset[1] = s.prevOffset[0] + s.prevOffset[0] = mo + } else { + // mo = s.adjustOffset(mo, ll, moB) + // Inlined for rather big speedup + if ll == 0 { + // There is an exception though, when current sequence's literals_length = 0. + // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2, + // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte. + mo++ + } + + if mo == 0 { + mo = s.prevOffset[0] + } else { + var temp int + if mo == 3 { + temp = s.prevOffset[0] - 1 + } else { + temp = s.prevOffset[mo] + } + + if temp == 0 { + // 0 is not valid; input is corrupted; force offset to 1 + println("temp was 0") + temp = 1 + } + + if mo != 1 { + s.prevOffset[2] = s.prevOffset[1] + } + s.prevOffset[1] = s.prevOffset[0] + s.prevOffset[0] = temp + mo = temp + } + } + br.fillFast() + } else { + if br.overread() { + printf("reading sequence %d, exceeded available data\n", i) + return io.ErrUnexpectedEOF + } + ll, mo, ml = s.next(br, llState, mlState, ofState) + br.fill() + } + + if debugSequences { + println("Seq", i, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml) + } + // Evaluate. + // We might be doing this async, so do it early. 
+ if mo == 0 && ml > 0 { + return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml) + } + if ml > maxMatchLen { + return fmt.Errorf("match len (%d) bigger than max allowed length", ml) + } + s.seqSize += ll + ml + if s.seqSize > maxBlockSize { + return fmt.Errorf("output (%d) bigger than max block size", s.seqSize) + } + litRemain -= ll + if litRemain < 0 { + return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain) + } + s.seq[i] = seqVals{ + ll: ll, + ml: ml, + mo: mo, + } + if i == len(s.seq)-1 { + // This is the last sequence, so we shouldn't update state. + break + } + + // Manually inlined, ~ 5-20% faster + // Update all 3 states at once. Approx 20% faster. + nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits() + if nBits == 0 { + llState = llTable[llState.newState()&maxTableMask] + mlState = mlTable[mlState.newState()&maxTableMask] + ofState = ofTable[ofState.newState()&maxTableMask] + } else { + bits := br.get32BitsFast(nBits) + lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31)) + llState = llTable[(llState.newState()+lowBits)&maxTableMask] + + lowBits = uint16(bits >> (ofState.nbBits() & 31)) + lowBits &= bitMask[mlState.nbBits()&15] + mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask] + + lowBits = uint16(bits) & bitMask[ofState.nbBits()&15] + ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask] + } + } + s.seqSize += litRemain + if s.seqSize > maxBlockSize { + return fmt.Errorf("output (%d) bigger than max block size", s.seqSize) + } + return nil +} + +// execute will execute the decoded sequence with the provided history. +func (s *sequenceDecs) execute(hist []byte) error { + if len(s.out)+s.seqSize > cap(s.out) { + addBytes := s.seqSize + len(s.out) + s.out = append(s.out, make([]byte, addBytes)...) + s.out = s.out[:len(s.out)-addBytes] + } + + for _, seq := range s.seq { + ll, ml, mo := seq.ll, seq.ml, seq.mo + + // Add literals + s.out = append(s.out, s.literals[:ll]...) + s.literals = s.literals[ll:] + out := s.out + + if mo > len(s.out)+len(hist) || mo > s.windowSize { + if len(s.dict) == 0 { + return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) + } + + // we may be in dictionary. + dictO := len(s.dict) - (mo - (len(s.out) + len(hist))) + if dictO < 0 || dictO >= len(s.dict) { + return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) + } + end := dictO + ml + if end > len(s.dict) { + out = append(out, s.dict[dictO:]...) + mo -= len(s.dict) - dictO + ml -= len(s.dict) - dictO + } else { + out = append(out, s.dict[dictO:end]...) + mo = 0 + ml = 0 + } + } + + // Copy from history. + // TODO: Blocks without history could be made to ignore this completely. + if v := mo - len(s.out); v > 0 { + // v is the start position in history from end. + start := len(hist) - v + if ml > v { + // Some goes into current block. + // Copy remainder of history + out = append(out, hist[start:]...) + mo -= v + ml -= v + } else { + out = append(out, hist[start:start+ml]...) + ml = 0 + } + } + // We must be in current buffer now + if ml > 0 { + start := len(s.out) - mo + if ml <= len(s.out)-start { + // No overlap + out = append(out, s.out[start:start+ml]...) + } else { + // Overlapping copy + // Extend destination slice and copy one byte at the time. + out = out[:len(out)+ml] + src := out[start : start+ml] + // Destination is the space we just added. 
+ dst := out[len(out)-ml:] + dst = dst[:len(src)] + for i := range src { + dst[i] = src[i] + } + } + } + s.out = out + } + // Add final literals + s.out = append(s.out, s.literals...) + + return nil +} + +// decode sequences from the stream with the provided history. +func (s *sequenceDecs) decodeSync(seqs int, br *bitReader, hist []byte) error { startSize := len(s.out) // Grab full sizes tables, to avoid bounds checks. llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] @@ -233,15 +455,15 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { // TODO: Blocks without history could be made to ignore this completely. if v := mo - len(s.out); v > 0 { // v is the start position in history from end. - start := len(s.hist) - v + start := len(hist) - v if ml > v { // Some goes into current block. // Copy remainder of history - out = append(out, s.hist[start:]...) + out = append(out, hist[start:]...) mo -= v ml -= v } else { - out = append(out, s.hist[start:start+ml]...) + out = append(out, hist[start:start+ml]...) ml = 0 } } From 984fc8f11b621606c3a6cd6a768587a7b8e93cab Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 11 Feb 2022 16:12:02 +0100 Subject: [PATCH 02/31] Make it compile --- zstd/blockdec.go | 89 +++++++++++++++++++++++++----------------------- zstd/decoder.go | 6 ++-- zstd/seqdec.go | 23 ++++++++++--- 3 files changed, 68 insertions(+), 50 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index a40284e48e..d930c0e8a2 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -237,10 +237,10 @@ func (b *blockDec) decodeBuf(hist *history) error { } } -func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain []byte, err error) { +func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err error) { // There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header if len(in) < 2 { - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } litType := literalsBlockType(in[0] & 3) @@ -248,6 +248,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] var litCompSize int sizeFormat := (in[0] >> 2) & 3 var fourStreams bool + var literals []byte switch litType { case literalsBlockRaw, literalsBlockRLE: switch sizeFormat { @@ -263,7 +264,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] // Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes. if len(in) < 3 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12) in = in[3:] @@ -274,7 +275,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] // Both Regenerated_Size and Compressed_Size use 10 bits (0-1023). 
if len(in) < 3 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) litRegenSize = int(n & 1023) @@ -285,7 +286,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] fourStreams = true if len(in) < 4 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) litRegenSize = int(n & 16383) @@ -295,7 +296,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] fourStreams = true if len(in) < 5 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28) litRegenSize = int(n & 262143) @@ -311,7 +312,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] case literalsBlockRaw: if len(in) < litRegenSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } literals = in[:litRegenSize] in = in[litRegenSize:] @@ -319,7 +320,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] case literalsBlockRLE: if len(in) < 1 { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } if cap(b.literalBuf) < litRegenSize { if b.lowMem { @@ -345,7 +346,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] case literalsBlockTreeless: if len(in) < litCompSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } // Store compressed literals, so we defer decoding until we get history. literals = in[:litCompSize] @@ -355,7 +356,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] } huff := hist.huffTree if huff == nil { - return nil, in, errors.New("literal block was treeless, but no history was defined") + return in, errors.New("literal block was treeless, but no history was defined") } // Ensure we have space to store it. 
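		// (With lowMem only litRegenSize is allocated; otherwise the buffer is
		// sized for the largest possible literal block so it can be reused by
		// later blocks without reallocating.)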
if cap(b.literalBuf) < litRegenSize { @@ -375,16 +376,16 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] // Make sure we don't leak our literals buffer if err != nil { println("decompressing literals:", err) - return nil, in, err + return in, err } if len(literals) != litRegenSize { - return nil, in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) + return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) } case literalsBlockCompressed: if len(in) < litCompSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize) - return nil, in, ErrBlockTooSmall + return in, ErrBlockTooSmall } literals = in[:litCompSize] in = in[litCompSize:] @@ -407,7 +408,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] huff, literals, err = huff0.ReadTable(literals, huff) if err != nil { println("reading huffman table:", err) - return nil, in, err + return in, err } hist.huffTree = huff // Use our out buffer. @@ -418,36 +419,46 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (literals, remain [] } if err != nil { println("decoding compressed literals:", err) - return nil, in, err + return in, err } // Make sure we don't leak our literals buffer if len(literals) != litRegenSize { - return nil, in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) + return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) } if debugDecoder { printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize) } } - return literals, in, nil + hist.decoders.literals = literals + return in, nil } // decodeCompressed will start decompressing a block. func (b *blockDec) decodeCompressed(hist *history) error { in := b.data - literals, in, err := b.decodeLiterals(in, hist) + in, err := b.decodeLiterals(in, hist) if err != nil { return err } + err = b.prepareSequences(in, hist) + if err != nil { + return err + } + err = hist.decoders.decodeSync(hist.b) + if err != nil { + return err + } + return b.updateHistory(hist) } -func (b *blockDec) decodeSequences(in, literals []byte, hist *history) error { +func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { // Decode Sequences // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section if len(in) < 1 { return ErrBlockTooSmall } + var nSeqs int seqHeader := in[0] - nSeqs := 0 switch { case seqHeader == 0: in = in[1:] @@ -469,6 +480,7 @@ func (b *blockDec) decodeSequences(in, literals []byte, hist *history) error { } var seqs = &hist.decoders + seqs.nSeqs = nSeqs if nSeqs > 0 { if len(in) < 1 { return ErrBlockTooSmall @@ -541,38 +553,27 @@ func (b *blockDec) decodeSequences(in, literals []byte, hist *history) error { } in = br.unread() } - if debugDecoder { - println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.") + println("Literals:", len(seqs.literals), "hash:", xxhash.Sum64(seqs.literals), "and", seqs.nSeqs, "sequences.") } if nSeqs == 0 { - // Decompressed content is defined entirely as Literals Section content. 
return nil } - br := &bitReader{} if err := br.init(in); err != nil { return err } - if err := seqs.initialize(br, hist, literals, b.dst); err != nil { + if err := seqs.initialize(br, hist, b.dst); err != nil { println("initializing sequences:", err) return err } + return nil +} - err := seqs.decode(nSeqs, br) - if err != nil { - return err - } - if !br.finished() { - return fmt.Errorf("%d extra bits on block, should be 0", br.remain()) - } - err = br.close() - if err != nil { - printf("Closing sequences: %v, %+v\n", err, *br) - } - return err +func (b *blockDec) decodeSequences(hist *history) error { + return hist.decoders.decode() } func (b *blockDec) executeSequences(hist *history) error { @@ -584,19 +585,22 @@ func (b *blockDec) executeSequences(hist *history) error { hist.dict.content = nil } } - err = seqs.execute(hist.b) + err := hist.decoders.execute(hist.b) if err != nil { return err } + return b.updateHistory(hist) +} +func (b *blockDec) updateHistory(hist *history) error { if len(b.data) > maxCompressedBlockSize { return fmt.Errorf("compressed block size too large (%d)", len(b.data)) } // Set output and release references. - b.dst = seqs.out - seqs.out, seqs.literals = nil, nil + b.dst = hist.decoders.out + hist.decoders.out, hist.decoders.literals = nil, nil - hist.recentOffsets = seqs.prevOffset + hist.recentOffsets = hist.decoders.prevOffset if b.Last { // if last block we don't care about history. println("Last block, no history returned") @@ -605,8 +609,7 @@ func (b *blockDec) executeSequences(hist *history) error { } hist.append(b.dst) if debugDecoder { - println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.") + println("Finished block with literals:", len(hist.decoders.literals), "and", len(hist.decoders.seq), "sequences.") } - return nil } diff --git a/zstd/decoder.go b/zstd/decoder.go index 044bb99db4..b4aeb98e49 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -489,10 +489,10 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch if block.async.newHist != nil { hist.huffTree = block.async.newHist.huffTree } - literals, remain, err := block.decodeLiterals(block.data, &hist) + remain, err := block.decodeLiterals(block.data, &hist) block.err = err if err == nil { - block.async.literals = literals + block.async.literals = hist.decoders.literals block.async.seqData = remain } seqDecode <- block @@ -513,7 +513,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch hist.recentOffsets = block.async.newHist.recentOffsets } hist.decoders.literals = block.async.literals - block.err = block.decodeSequences(block.async.seqData, block.async.literals, &hist) + block.err = block.decodeSequences(&hist) seqExecute <- block } close(seqExecute) diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 28dc8d5ecf..a773e4ddcb 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -69,13 +69,15 @@ type sequenceDecs struct { literals []byte out []byte seq []seqVals + nSeqs int + br *bitReader seqSize int windowSize int maxBits uint8 } // initialize all 3 decoders from the stream input. 
-func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error { +func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) error { if err := s.litLengths.init(br); err != nil { return errors.New("litLengths:" + err.Error()) } @@ -85,7 +87,7 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out [] if err := s.matchLengths.init(br); err != nil { return errors.New("matchLengths:" + err.Error()) } - s.literals = literals + s.br = br s.prevOffset = hist.recentOffsets s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits s.windowSize = hist.windowSize @@ -98,7 +100,10 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out [] } // decode sequences from the stream with the provided history. -func (s *sequenceDecs) decode(seqs int, br *bitReader) error { +func (s *sequenceDecs) decode() error { + seqs := s.nSeqs + br := s.br + // Grab full sizes tables, to avoid bounds checks. llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state @@ -231,10 +236,18 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader) error { if s.seqSize > maxBlockSize { return fmt.Errorf("output (%d) bigger than max block size", s.seqSize) } + if !br.finished() { + return fmt.Errorf("%d extra bits on block, should be 0", br.remain()) + } + err := br.close() + if err != nil { + printf("Closing sequences: %v, %+v\n", err, *br) + } return nil } // execute will execute the decoded sequence with the provided history. +// The sequence must be evaluated before being sent. func (s *sequenceDecs) execute(hist []byte) error { if len(s.out)+s.seqSize > cap(s.out) { addBytes := s.seqSize + len(s.out) @@ -316,7 +329,9 @@ func (s *sequenceDecs) execute(hist []byte) error { } // decode sequences from the stream with the provided history. -func (s *sequenceDecs) decodeSync(seqs int, br *bitReader, hist []byte) error { +func (s *sequenceDecs) decodeSync(hist []byte) error { + br := s.br + seqs := s.nSeqs startSize := len(s.out) // Grab full sizes tables, to avoid bounds checks. llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] From dca58e6078cc25818929522ea62bea9efd91f155 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 11 Feb 2022 17:51:59 +0100 Subject: [PATCH 03/31] Fix up single decodes. --- zstd/blockdec.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index d930c0e8a2..af79f4ab6b 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -444,11 +444,17 @@ func (b *blockDec) decodeCompressed(hist *history) error { if err != nil { return err } + if hist.decoders.nSeqs == 0 { + b.dst = append(b.dst, hist.decoders.literals...) + return nil + } err = hist.decoders.decodeSync(hist.b) if err != nil { return err } - return b.updateHistory(hist) + b.dst = hist.decoders.out + hist.recentOffsets = hist.decoders.prevOffset + return nil } func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { From 3c54049e86a4657fb3a7aa392c3d4b8023715316 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 16 Feb 2022 19:49:44 +0100 Subject: [PATCH 04/31] Almost working now... 
--- zstd/blockdec.go | 46 ++++++---- zstd/decoder.go | 196 ++++++++++++++++++++++++++++++++++--------- zstd/decoder_test.go | 14 ++-- zstd/framedec.go | 31 ++----- zstd/history.go | 21 ++--- zstd/seqdec.go | 7 +- zstd/zstd.go | 2 +- 7 files changed, 219 insertions(+), 98 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index af79f4ab6b..0819cda331 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -78,6 +78,9 @@ type blockDec struct { err error + // Check against this crc + checkCRC []byte + // Frame to use for singlethreaded decoding. // Should not be used by the decoder itself since parent may be another frame. localFrame *frameDec @@ -87,6 +90,7 @@ type blockDec struct { literals []byte seqData []byte sequence []seqVals + seqSize int // Size of uncompressed sequences } // Block is RLE, this is the size. @@ -515,6 +519,9 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } switch mode { case compModePredefined: + if seq.fse != nil && !seq.fse.preDefined { + fseDecoderPool.Put(seq.fse) + } seq.fse = &fsePredef[i] case compModeRLE: if br.remain() < 1 { @@ -522,34 +529,36 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } v := br.Uint8() br.advance(1) - dec := fseDecoderPool.Get().(*fseDecoder) + if seq.fse == nil || seq.fse.preDefined { + seq.fse = fseDecoderPool.Get().(*fseDecoder) + } symb, err := decSymbolValue(v, symbolTableX[i]) if err != nil { printf("RLE Transform table (%v) error: %v", tableIndex(i), err) return err } - dec.setRLE(symb) - seq.fse = dec + seq.fse.setRLE(symb) if debugDecoder { printf("RLE set to %+v, code: %v", symb, v) } case compModeFSE: println("Reading table for", tableIndex(i)) - dec := fseDecoderPool.Get().(*fseDecoder) - err := dec.readNCount(&br, uint16(maxTableSymbol[i])) + if seq.fse == nil || seq.fse.preDefined { + seq.fse = fseDecoderPool.Get().(*fseDecoder) + } + err := seq.fse.readNCount(&br, uint16(maxTableSymbol[i])) if err != nil { println("Read table error:", err) return err } - err = dec.transform(symbolTableX[i]) + err = seq.fse.transform(symbolTableX[i]) if err != nil { println("Transform table error:", err) return err } if debugDecoder { - println("Read table ok", "symbolLen:", dec.symbolLen) + println("Read table ok", "symbolLen:", seq.fse.symbolLen) } - seq.fse = dec case compModeRepeat: seq.repeat = true } @@ -579,7 +588,13 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } func (b *blockDec) decodeSequences(hist *history) error { - return hist.decoders.decode() + if hist.decoders.nSeqs == 0 { + return nil + } + hist.decoders.prevOffset = hist.recentOffsets + err := hist.decoders.decode() + hist.recentOffsets = hist.decoders.prevOffset + return err } func (b *blockDec) executeSequences(hist *history) error { @@ -591,7 +606,8 @@ func (b *blockDec) executeSequences(hist *history) error { hist.dict.content = nil } } - err := hist.decoders.execute(hist.b) + hist.decoders.windowSize = hist.windowSize + err := hist.decoders.execute(hbytes) if err != nil { return err } @@ -604,18 +620,20 @@ func (b *blockDec) updateHistory(hist *history) error { } // Set output and release references. b.dst = hist.decoders.out - hist.decoders.out, hist.decoders.literals = nil, nil - hist.recentOffsets = hist.decoders.prevOffset + if b.Last { // if last block we don't care about history. 
println("Last block, no history returned") hist.b = hist.b[:0] return nil + } else { + hist.append(b.dst) } - hist.append(b.dst) if debugDecoder { - println("Finished block with literals:", len(hist.decoders.literals), "and", len(hist.decoders.seq), "sequences.") + println("Finished block with ", len(hist.decoders.seq), "sequences.") } + hist.decoders.out, hist.decoders.literals = nil, nil + return nil } diff --git a/zstd/decoder.go b/zstd/decoder.go index b4aeb98e49..c9c703ef3a 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -5,10 +5,12 @@ package zstd import ( + "bytes" "context" - "errors" "io" "sync" + + "github.com/klauspost/compress/zstd/internal/xxhash" ) // Decoder provides decoding of zstandard streams. @@ -46,6 +48,9 @@ type decoderState struct { // cancel remaining output. cancel context.CancelFunc + // crc of current frame + crc *xxhash.Digest + flushed bool } @@ -79,7 +84,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { return nil, err } } - d.current.output = make(chan decodeOutput, d.o.concurrent) + d.current.crc = xxhash.New() d.current.flushed = true if r == nil { @@ -194,6 +199,7 @@ func (d *Decoder) Reset(r io.Reader) error { return nil } + d.current.output = make(chan decodeOutput, d.o.concurrent) ctx, cancel := context.WithCancel(context.Background()) d.streamWg.Add(1) go d.startStreamDecoder(ctx, r, d.current.output) @@ -234,12 +240,9 @@ func (d *Decoder) drainOutput() { } d.decoders <- v.d } - if v.err == errEndOfStream { - println("current flushed") - d.current.flushed = true - return - } } + d.current.output = nil + d.current.flushed = true } // WriteTo writes data to w until there's no more data to write or when an error occurs. @@ -380,6 +383,38 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { return false } } + next := d.current.decodeOutput + if next.d != nil && next.d.async.newHist != nil { + d.current.crc.Reset() + } + if len(next.b) > 0 { + n, err := d.current.crc.Write(next.b) + if err == nil { + if n != len(next.b) { + d.current.err = io.ErrShortWrite + } + } + } + if next.err == nil && next.d != nil && len(next.d.checkCRC) != 0 { + got := d.current.crc.Sum64() + var tmp [4]byte + // Flip to match file order. + tmp[0] = byte(got >> 0) + tmp[1] = byte(got >> 8) + tmp[2] = byte(got >> 16) + tmp[3] = byte(got >> 24) + + if !bytes.Equal(tmp[:], next.d.checkCRC) { + if debugDecoder { + println("CRC Check Failed:", tmp[:], "!=", next.d.checkCRC) + } + d.current.err = ErrCRCMismatch + } else { + if debugDecoder { + println("CRC ok", tmp[:]) + } + } + } if debugDecoder { println("got", len(d.current.b), "bytes, error:", d.current.err) } @@ -457,12 +492,10 @@ type decodeStream struct { cancel chan struct{} } -// errEndOfStream indicates that everything from the stream was read. -var errEndOfStream = errors.New("end-of-stream") - // Create Decoder: // ASYNC: -// Spawn 3 go routines. +// Spawn 4 go routines. +// 0: Read frames and decode blocks. // 1: Decode block and literals. Receives hufftree and seqdecs, returns seqdecs and huff tree. // 2: Wait for recentOffsets if needed. Decode sequences, send recentOffsets. // 3: Wait for stream history, execute sequences, send stream history. @@ -472,9 +505,6 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch br := readerWrapper{r: r} - // TODO: Needed? 
- frame.initAsync() - var seqPrepare = make(chan *blockDec, d.o.concurrent) var seqDecode = make(chan *blockDec, d.o.concurrent) var seqExecute = make(chan *blockDec, d.o.concurrent) @@ -482,13 +512,20 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch go func() { var hist history for block := range seqPrepare { - if block.err != nil { + if block.async.newHist != nil { + if debugDecoder { + println("Async 1: new history") + } + hist.reset() + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } + } + if block.err != nil || block.Type != blockTypeCompressed { seqDecode <- block continue } - if block.async.newHist != nil { - hist.huffTree = block.async.newHist.huffTree - } + remain, err := block.decodeLiterals(block.data, &hist) block.err = err if err == nil { @@ -504,16 +541,34 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch var hist history for block := range seqDecode { - if block.err != nil { - seqExecute <- block - continue - } if block.async.newHist != nil { + if debugDecoder { + println("Async 2: new history, recent:", block.async.newHist.recentOffsets) + } hist.decoders = block.async.newHist.decoders hist.recentOffsets = block.async.newHist.recentOffsets + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } } + if block.err != nil || block.Type != blockTypeCompressed { + seqExecute <- block + continue + } + hist.decoders.literals = block.async.literals - block.err = block.decodeSequences(&hist) + block.err = block.prepareSequences(block.async.seqData, &hist) + if debugDecoder && block.err != nil { + println("prepareSequences returned:", block.err) + } + if block.err == nil { + block.err = block.decodeSequences(&hist) + if debugDecoder && block.err != nil { + println("decodeSequences returned:", block.err) + } + block.async.sequence = hist.decoders.seq[:hist.decoders.nSeqs] + block.async.seqSize = hist.decoders.seqSize + } seqExecute <- block } close(seqExecute) @@ -529,22 +584,71 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch continue } if block.async.newHist != nil { - hist.dict = block.async.newHist.dict - hist.b = append(hist.b[:0], block.async.newHist.b...) 
+ if debugDecoder { + println("Async 3: new history") + } + hist.windowSize = block.async.newHist.windowSize + hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } + if cap(hist.b) < hist.allocFrameBuffer { + hist.b = make([]byte, 0, hist.allocFrameBuffer) + println("Alloc history sized", hist.allocFrameBuffer) + } + hist.b = hist.b[:0] } - + do := decodeOutput{err: block.err, d: block} + switch block.Type { + case blockTypeRLE: + if cap(block.dst) < int(block.RLESize) { + if block.lowMem { + block.dst = make([]byte, block.RLESize) + } else { + block.dst = make([]byte, maxBlockSize) + } + } + block.dst = block.dst[:block.RLESize] + v := block.data[0] + for i := range block.dst { + block.dst[i] = v + } + hist.append(block.dst) + do.b = block.dst + case blockTypeRaw: + hist.append(block.data) + do.b = block.data + case blockTypeCompressed: + if debugDecoder { + println("execute with history length:", len(hist.b), "window:", hist.windowSize) + } + hist.decoders.seq = block.async.sequence + hist.decoders.nSeqs = len(block.async.sequence) + hist.decoders.seqSize = block.async.seqSize + hist.decoders.literals = block.async.literals + do.err = block.executeSequences(&hist) + if debugDecoder && block.err != nil { + println("executeSequences returned:", block.err) + } + do.b = block.dst + } + output <- do } close(output) + if debugDecoder { + println("decoder goroutines finished") + } }() decodeStream: for { + if debugDecoder { + println("New frame...") + } var historySent bool frame.history.reset() err := frame.reset(&br) switch err { - case io.EOF: - break decodeStream default: dec := <-d.decoders dec.sendErr(err) @@ -577,27 +681,43 @@ decodeStream: case <-ctx.Done(): dec.sendErr(ctx.Err()) seqPrepare <- dec - break decodeFrame + break decodeStream default: } err := frame.next(dec) if !historySent { h := frame.history + if debugDecoder { + println("Alloc History:", h.allocFrameBuffer) + } dec.async.newHist = &h historySent = true + } else { + dec.async.newHist = nil } - seqPrepare <- dec - switch err { - case io.EOF: - // End of current frame, no error - println("EOF on next block") + dec.err = err + dec.checkCRC = nil + if dec.Last && frame.HasCheckSum && err == nil { + dec.checkCRC, err = frame.rawInput.readSmall(4) + if err != nil { + println("CRC missing?", err) + dec.err = err + } + if debugDecoder { + println("found crc to check:", dec.checkCRC) + } + } + select { + case <-ctx.Done(): break decodeFrame - case nil: - continue - default: - println("block decoder returned", err) + case seqPrepare <- dec: + } + if err != nil { break decodeStream } + if dec.Last { + break + } } } close(seqPrepare) diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 7faaf49219..ea7c5789de 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -202,11 +202,13 @@ func TestErrorWriter(t *testing.T) { func TestNewDecoder(t *testing.T) { defer timeout(60 * time.Second)() testDecoderFile(t, "testdata/decoder.zip") - dec, err := NewReader(nil) - if err != nil { - t.Fatal(err) + if true { + dec, err := NewReader(nil) + if err != nil { + t.Fatal(err) + } + testDecoderDecodeAll(t, "testdata/decoder.zip", dec) } - testDecoderDecodeAll(t, "testdata/decoder.zip", dec) } func TestNewDecoderMemory(t *testing.T) { @@ -967,7 +969,7 @@ func testDecoderFile(t *testing.T, fn string) { want[tt.Name+".zst"], _ = ioutil.ReadAll(r) } - dec, err := NewReader(nil) + dec, err := NewReader(nil, WithDecoderConcurrency(1)) if err != 
nil { t.Error(err) return @@ -1430,7 +1432,7 @@ func testDecoderDecodeAll(t *testing.T, fn string, dec *Decoder) { wg.Add(1) t.Run("DecodeAll-"+tt.Name, func(t *testing.T) { defer wg.Done() - t.Parallel() + //t.Parallel() r, err := tt.Open() if err != nil { t.Fatal(err) diff --git a/zstd/framedec.go b/zstd/framedec.go index 133dea147a..f828d43916 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -8,15 +8,15 @@ import ( "bytes" "encoding/hex" "errors" - "hash" "io" "github.com/klauspost/compress/zstd/internal/xxhash" ) type frameDec struct { - o decoderOptions - crc hash.Hash64 + o decoderOptions + crc *xxhash.Digest + offset int64 WindowSize uint64 @@ -220,7 +220,7 @@ func (d *frameDec) reset(br byteBuffer) error { d.FrameContentSize = uint64(d1) | (uint64(d2) << 32) } if debugDecoder { - println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize) + println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum) } } // Move this to shared. @@ -255,9 +255,10 @@ func (d *frameDec) reset(br byteBuffer) error { } d.history.windowSize = int(d.WindowSize) if d.o.lowMem && d.history.windowSize < maxBlockSize { - d.history.maxSize = d.history.windowSize * 2 + d.history.allocFrameBuffer = d.history.windowSize * 2 + // TODO: Maybe use FrameContent size } else { - d.history.maxSize = d.history.windowSize + maxBlockSize + d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize } // history contains input - maybe we do something d.rawInput = br @@ -312,24 +313,6 @@ func (d *frameDec) checkCRC() error { return nil } -func (d *frameDec) initAsync() { - if !d.o.lowMem && !d.SingleSegment { - // set max extra size history to 2MB. - d.history.maxSize = d.history.windowSize + maxBlockSize - } - // re-alloc if more than one extra block size. - if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize { - d.history.b = make([]byte, 0, d.history.maxSize) - } - if cap(d.history.b) < d.history.maxSize { - d.history.b = make([]byte, 0, d.history.maxSize) - } - if debugDecoder { - h := d.history - printf("history init. len: %d, cap: %d", len(h.b), cap(h.b)) - } -} - /* // startDecoder will start decoding blocks and write them to the writer. // The decoder will stop as soon as an error occurs or at end of frame. diff --git a/zstd/history.go b/zstd/history.go index 44f79fa9e2..aa08bd3da0 100644 --- a/zstd/history.go +++ b/zstd/history.go @@ -10,23 +10,20 @@ import ( // history contains the information transferred between blocks. type history struct { - // Needed first, if needed. + // Literal decompression huffTree *huff0.Scratch - // Needed second, if needed... - decoders sequenceDecs - - // Maybe needed... + // Sequence decompression + decoders sequenceDecs recentOffsets [3]int - // Needed last... - b []byte - bCh chan []byte + // History buffer... + b []byte - windowSize int - maxSize int - error bool - dict *dict + windowSize int + allocFrameBuffer int // needed? + error bool + dict *dict } // reset will reset the history to initial state of a frame. 
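The stream decoder introduced above replaces the old per-block goroutine with a fixed pipeline: stage 0 reads frames and splits out blocks, then literal decoding, sequence decoding and sequence execution each run in their own goroutine, passing *blockDec values (plus an optional fresh history) down channels. A rough, self-contained sketch of that shape, using placeholder types rather than the package's real ones:

package main

import "fmt"

// work stands in for *blockDec: each stage fills in a little more of it.
type work struct {
	id  int
	out []byte
	err error
}

func main() {
	const queue = 4 // mirrors the in-flight block limit
	literals := make(chan *work, queue)
	sequences := make(chan *work, queue)
	execute := make(chan *work, queue)
	results := make(chan *work, queue)

	go func() { // stage 1: decode literals (huff0 in the real decoder)
		for w := range literals {
			sequences <- w
		}
		close(sequences)
	}()
	go func() { // stage 2: decode sequences (FSE tables, recent offsets)
		for w := range sequences {
			execute <- w
		}
		close(execute)
	}()
	go func() { // stage 3: execute sequences against the history buffer
		for w := range execute {
			w.out = []byte(fmt.Sprintf("block %d", w.id)) // placeholder "decode"
			results <- w
		}
		close(results)
	}()

	go func() { // stage 0: read frames and feed blocks in order
		for i := 0; i < 3; i++ {
			literals <- &work{id: i}
		}
		close(literals)
	}()

	// Channels are FIFO, so block order is preserved end to end.
	for w := range results {
		fmt.Printf("%s err=%v\n", w.out, w.err)
	}
}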
diff --git a/zstd/seqdec.go b/zstd/seqdec.go index a773e4ddcb..b39d6ef41d 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -161,7 +161,7 @@ func (s *sequenceDecs) decode() error { if temp == 0 { // 0 is not valid; input is corrupted; force offset to 1 - println("temp was 0") + println("WARNING: temp was 0") temp = 1 } @@ -249,6 +249,7 @@ func (s *sequenceDecs) decode() error { // execute will execute the decoded sequence with the provided history. // The sequence must be evaluated before being sent. func (s *sequenceDecs) execute(hist []byte) error { + // Ensure we have enough output size... if len(s.out)+s.seqSize > cap(s.out) { addBytes := s.seqSize + len(s.out) s.out = append(s.out, make([]byte, addBytes)...) @@ -271,7 +272,7 @@ func (s *sequenceDecs) execute(hist []byte) error { // we may be in dictionary. dictO := len(s.dict) - (mo - (len(s.out) + len(hist))) if dictO < 0 || dictO >= len(s.dict) { - return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) + return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", mo, len(s.out)+len(hist)+len(s.dict)) } end := dictO + ml if end > len(s.dict) { @@ -388,7 +389,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error { if temp == 0 { // 0 is not valid; input is corrupted; force offset to 1 - println("temp was 0") + println("WARNING: temp was 0") temp = 1 } diff --git a/zstd/zstd.go b/zstd/zstd.go index ef1d49a009..dbfad7813b 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -13,7 +13,7 @@ import ( ) // enable debug printing -const debug = false +const debug = true // enable encoding debug printing const debugEncoder = debug From 8bf62b778f997eb3c561db8056e90d818a9fc827 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Feb 2022 11:18:19 +0100 Subject: [PATCH 05/31] Tests pass. --- zstd/blockdec.go | 26 +++++++++++----- zstd/decoder.go | 66 ++++++++++++++++++++++++++++++++--------- zstd/decoder_options.go | 5 +++- zstd/decoder_test.go | 4 +-- zstd/encoder_test.go | 4 +-- zstd/framedec.go | 11 ++++--- zstd/seqdec.go | 33 ++++++++++----------- zstd/zstd.go | 4 +-- 8 files changed, 104 insertions(+), 49 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index 0819cda331..e2b22fd7b5 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -85,11 +85,12 @@ type blockDec struct { // Should not be used by the decoder itself since parent may be another frame. 
localFrame *frameDec + sequence []seqVals + async struct { newHist *history literals []byte seqData []byte - sequence []seqVals seqSize int // Size of uncompressed sequences } @@ -402,7 +403,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err } } huff := hist.huffTree - if huff == nil { + if huff == nil || (hist.dict != nil && huff == hist.dict.litEnc) { huff = huffDecoderPool.Get().(*huff0.Scratch) if huff == nil { huff = &huff0.Scratch{} @@ -573,6 +574,9 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } if nSeqs == 0 { + if len(b.sequence) > 0 { + b.sequence = b.sequence[:0] + } return nil } br := &bitReader{} @@ -588,11 +592,19 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } func (b *blockDec) decodeSequences(hist *history) error { + if cap(b.sequence) < hist.decoders.nSeqs { + if b.lowMem { + b.sequence = make([]seqVals, 0, hist.decoders.nSeqs) + } else { + b.sequence = make([]seqVals, 0, 0x7F00+0xffff) + } + } + b.sequence = b.sequence[:hist.decoders.nSeqs] if hist.decoders.nSeqs == 0 { return nil } hist.decoders.prevOffset = hist.recentOffsets - err := hist.decoders.decode() + err := hist.decoders.decode(b.sequence) hist.recentOffsets = hist.decoders.prevOffset return err } @@ -607,7 +619,7 @@ func (b *blockDec) executeSequences(hist *history) error { } } hist.decoders.windowSize = hist.windowSize - err := hist.decoders.execute(hbytes) + err := hist.decoders.execute(b.sequence, hbytes) if err != nil { return err } @@ -629,9 +641,9 @@ func (b *blockDec) updateHistory(hist *history) error { return nil } else { hist.append(b.dst) - } - if debugDecoder { - println("Finished block with ", len(hist.decoders.seq), "sequences.") + if debugDecoder { + println("Finished block with ", len(b.sequence), "sequences. Added", len(b.dst), "to history, now length", len(hist.b)) + } } hist.decoders.out, hist.decoders.literals = nil, nil diff --git a/zstd/decoder.go b/zstd/decoder.go index c9c703ef3a..f39d953650 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -7,6 +7,7 @@ package zstd import ( "bytes" "context" + "encoding/binary" "io" "sync" @@ -312,6 +313,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { if !ok { return nil, ErrUnknownDictionary } + if debugDecoder { + println("setting dict", frame.DictionaryID) + } frame.history.setDict(&dict) } if err != nil { @@ -375,18 +379,29 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { } if blocking { - d.current.decodeOutput = <-d.current.output + d.current.decodeOutput, ok = <-d.current.output } else { select { - case d.current.decodeOutput = <-d.current.output: + case d.current.decodeOutput, ok = <-d.current.output: default: return false } } + if !ok { + // This should not happen, so signal error state... 
+ d.current.err = io.ErrUnexpectedEOF + return false + } next := d.current.decodeOutput if next.d != nil && next.d.async.newHist != nil { d.current.crc.Reset() } + if debugDecoder { + var tmp [4]byte + binary.LittleEndian.PutUint32(tmp[:], uint32(xxhash.Sum64(next.b))) + println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp) + } + if len(next.b) > 0 { n, err := d.current.crc.Write(next.b) if err == nil { @@ -406,7 +421,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { if !bytes.Equal(tmp[:], next.d.checkCRC) { if debugDecoder { - println("CRC Check Failed:", tmp[:], "!=", next.d.checkCRC) + println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)") } d.current.err = ErrCRCMismatch } else { @@ -415,9 +430,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { } } } - if debugDecoder { - println("got", len(d.current.b), "bytes, error:", d.current.err) - } + return true } @@ -511,7 +524,11 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch // Async 1: Prepare blocks... go func() { var hist history + var hasErr bool for block := range seqPrepare { + if hasErr { + continue + } if block.async.newHist != nil { if debugDecoder { println("Async 1: new history") @@ -522,15 +539,19 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } } if block.err != nil || block.Type != blockTypeCompressed { + hasErr = block.err != nil seqDecode <- block continue } remain, err := block.decodeLiterals(block.data, &hist) block.err = err + hasErr = block.err != nil if err == nil { block.async.literals = hist.decoders.literals block.async.seqData = remain + } else if debugDecoder { + println("decodeLiterals error:", err) } seqDecode <- block } @@ -539,8 +560,12 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch // Async 2: Decode sequences... go func() { var hist history + var hasErr bool for block := range seqDecode { + if hasErr { + continue + } if block.async.newHist != nil { if debugDecoder { println("Async 2: new history, recent:", block.async.newHist.recentOffsets) @@ -552,6 +577,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } } if block.err != nil || block.Type != blockTypeCompressed { + hasErr = block.err != nil seqExecute <- block continue } @@ -561,12 +587,14 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch if debugDecoder && block.err != nil { println("prepareSequences returned:", block.err) } + hasErr = block.err != nil if block.err == nil { block.err = block.decodeSequences(&hist) if debugDecoder && block.err != nil { println("decodeSequences returned:", block.err) } - block.async.sequence = hist.decoders.seq[:hist.decoders.nSeqs] + hasErr = block.err != nil + // block.async.sequence = hist.decoders.seq[:hist.decoders.nSeqs] block.async.seqSize = hist.decoders.seqSize } seqExecute <- block @@ -576,10 +604,11 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch // Async 3: Execute sequences... 
go func() { var hist history - + var hasErr bool for block := range seqExecute { out := decodeOutput{err: block.err, d: block} - if block.err != nil { + if block.err != nil || hasErr { + hasErr = true output <- out continue } @@ -601,6 +630,10 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch do := decodeOutput{err: block.err, d: block} switch block.Type { case blockTypeRLE: + if debugDecoder { + println("add rle block length:", block.RLESize) + } + if cap(block.dst) < int(block.RLESize) { if block.lowMem { block.dst = make([]byte, block.RLESize) @@ -616,19 +649,21 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch hist.append(block.dst) do.b = block.dst case blockTypeRaw: + if debugDecoder { + println("add raw block length:", len(block.data)) + } hist.append(block.data) do.b = block.data case blockTypeCompressed: if debugDecoder { println("execute with history length:", len(hist.b), "window:", hist.windowSize) } - hist.decoders.seq = block.async.sequence - hist.decoders.nSeqs = len(block.async.sequence) hist.decoders.seqSize = block.async.seqSize hist.decoders.literals = block.async.literals do.err = block.executeSequences(&hist) - if debugDecoder && block.err != nil { - println("executeSequences returned:", block.err) + hasErr = do.err != nil + if debugDecoder && hasErr { + println("executeSequences returned:", do.err) } do.b = block.dst } @@ -698,11 +733,14 @@ decodeStream: dec.err = err dec.checkCRC = nil if dec.Last && frame.HasCheckSum && err == nil { - dec.checkCRC, err = frame.rawInput.readSmall(4) + crc, err := frame.rawInput.readSmall(4) if err != nil { println("CRC missing?", err) dec.err = err } + var tmp [4]byte + copy(tmp[:], crc) + dec.checkCRC = tmp[:] if debugDecoder { println("found crc to check:", dec.checkCRC) } diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index 95cc9b8b81..3d9e1e178e 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -28,6 +28,9 @@ func (o *decoderOptions) setDefault() { concurrent: runtime.GOMAXPROCS(0), maxWindowSize: MaxWindowSize, } + if o.concurrent > 4 { + o.concurrent = 4 + } o.maxDecodedSize = 1 << 63 } @@ -40,7 +43,7 @@ func WithDecoderLowmem(b bool) DOption { // WithDecoderConcurrency will set the concurrency, // meaning the maximum number of decoders to run concurrently. // The value supplied must be at least 1. -// By default this will be set to GOMAXPROCS. +// By default this will be set to 4 or GOMAXPROCS, whatever is lower. 
func WithDecoderConcurrency(n int) DOption { return func(o *decoderOptions) error { if n <= 0 { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index ea7c5789de..4ca0ee1ae4 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -459,7 +459,7 @@ func TestNewDecoderBigFile(t *testing.T) { } defer f.Close() start := time.Now() - dec, err := NewReader(f) + dec, err := NewReader(f, WithDecoderConcurrency(4)) if err != nil { t.Fatal(err) } @@ -969,7 +969,7 @@ func testDecoderFile(t *testing.T, fn string) { want[tt.Name+".zst"], _ = ioutil.ReadAll(r) } - dec, err := NewReader(nil, WithDecoderConcurrency(1)) + dec, err := NewReader(nil) if err != nil { t.Error(err) return diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go index b437225e4f..12bf207a15 100644 --- a/zstd/encoder_test.go +++ b/zstd/encoder_test.go @@ -497,8 +497,8 @@ func TestEncoder_EncoderHTML(t *testing.T) { } func TestEncoder_EncoderEnwik9(t *testing.T) { - testEncoderRoundtrip(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12}) - testEncoderRoundtripWriter(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12}) + //testEncoderRoundtrip(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12}) + //testEncoderRoundtripWriter(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12}) } // test roundtrip using io.ReaderFrom interface. diff --git a/zstd/framedec.go b/zstd/framedec.go index f828d43916..82df364b7a 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -219,10 +219,8 @@ func (d *frameDec) reset(br byteBuffer) error { d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24) d.FrameContentSize = uint64(d1) | (uint64(d2) << 32) } - if debugDecoder { - println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum) - } } + // Move this to shared. d.HasCheckSum = fhd&(1<<2) != 0 if d.HasCheckSum { @@ -260,6 +258,11 @@ func (d *frameDec) reset(br byteBuffer) error { } else { d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize } + + if debugDecoder { + println("Frame: Dict:", d.DictionaryID, "FrameContentSize:", d.FrameContentSize, "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum) + } + // history contains input - maybe we do something d.rawInput = br return nil @@ -268,7 +271,7 @@ func (d *frameDec) reset(br byteBuffer) error { // next will start decoding the next block from stream. func (d *frameDec) next(block *blockDec) error { if debugDecoder { - printf("decoding new block %p:%p", block, block.data) + println("decoding new block") } err := block.reset(d.rawInput, d.WindowSize) if err != nil { diff --git a/zstd/seqdec.go b/zstd/seqdec.go index b39d6ef41d..3eab353589 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -68,12 +68,12 @@ type sequenceDecs struct { dict []byte literals []byte out []byte - seq []seqVals - nSeqs int - br *bitReader - seqSize int - windowSize int - maxBits uint8 + //seq []seqVals + nSeqs int + br *bitReader + seqSize int + windowSize int + maxBits uint8 } // initialize all 3 decoders from the stream input. @@ -100,21 +100,16 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) erro } // decode sequences from the stream with the provided history. 
-func (s *sequenceDecs) decode() error { - seqs := s.nSeqs +func (s *sequenceDecs) decode(seqs []seqVals) error { br := s.br // Grab full sizes tables, to avoid bounds checks. llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state - if cap(s.seq) < seqs { - s.seq = make([]seqVals, 0, seqs) - } - s.seq = s.seq[:seqs] s.seqSize = 0 litRemain := len(s.literals) - for i := range s.seq { + for i := range seqs { var ll, mo, ml int if br.off > 4+((maxOffsetBits+16+16)>>3) { // inlined function: @@ -202,12 +197,12 @@ func (s *sequenceDecs) decode() error { if litRemain < 0 { return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain) } - s.seq[i] = seqVals{ + seqs[i] = seqVals{ ll: ll, ml: ml, mo: mo, } - if i == len(s.seq)-1 { + if i == len(seqs)-1 { // This is the last sequence, so we shouldn't update state. break } @@ -248,7 +243,7 @@ func (s *sequenceDecs) decode() error { // execute will execute the decoded sequence with the provided history. // The sequence must be evaluated before being sent. -func (s *sequenceDecs) execute(hist []byte) error { +func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { // Ensure we have enough output size... if len(s.out)+s.seqSize > cap(s.out) { addBytes := s.seqSize + len(s.out) @@ -256,7 +251,11 @@ func (s *sequenceDecs) execute(hist []byte) error { s.out = s.out[:len(s.out)-addBytes] } - for _, seq := range s.seq { + if debugDecoder { + printf("Execute %d seqs with hist %d, dict %d, literals: %d bytes\n", len(seqs), len(hist), len(s.dict), len(s.literals)) + } + + for _, seq := range seqs { ll, ml, mo := seq.ll, seq.ml, seq.mo // Add literals diff --git a/zstd/zstd.go b/zstd/zstd.go index dbfad7813b..0b63559a3f 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -13,13 +13,13 @@ import ( ) // enable debug printing -const debug = true +const debug = false // enable encoding debug printing const debugEncoder = debug // enable decoding debug printing -const debugDecoder = debug +const debugDecoder = false // Enable extra assertions. 
const debugAsserts = debug || false From ac20ec2bc144e928094d974d6fc1531cc962ab8e Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Feb 2022 11:50:45 +0100 Subject: [PATCH 06/31] Avoid a few allocs --- zstd/blockdec.go | 6 +++++- zstd/decoder_test.go | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index e2b22fd7b5..eaf767309c 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -579,7 +579,10 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } return nil } - br := &bitReader{} + br := seqs.br + if br == nil { + br = &bitReader{} + } if err := br.init(in); err != nil { return err } @@ -619,6 +622,7 @@ func (b *blockDec) executeSequences(hist *history) error { } } hist.decoders.windowSize = hist.windowSize + hist.decoders.out = b.dst[:0] err := hist.decoders.execute(b.sequence, hbytes) if err != nil { return err diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 4ca0ee1ae4..0320313f33 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -459,10 +459,11 @@ func TestNewDecoderBigFile(t *testing.T) { } defer f.Close() start := time.Now() - dec, err := NewReader(f, WithDecoderConcurrency(4)) + dec, err := NewReader(f, WithDecoderConcurrency(4), WithDecoderLowmem(true)) if err != nil { t.Fatal(err) } + defer dec.Close() n, err := io.Copy(ioutil.Discard, dec) if err != nil { t.Fatal(err) From 9a9ac6b88fa46be7efc828ee418e71488fd8763a Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Feb 2022 15:38:15 +0100 Subject: [PATCH 07/31] Add stream decompression with no goroutines. --- zstd/blockdec.go | 1 + zstd/decoder.go | 125 +++++++++++++++++++++++++++++++--------- zstd/decoder_options.go | 9 ++- zstd/decoder_test.go | 54 +++++++++++------ zstd/history.go | 18 ++++++ zstd/seqdec.go | 14 ++--- zstd/zstd.go | 2 +- 7 files changed, 167 insertions(+), 56 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index eaf767309c..4d2fdac8c0 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -225,6 +225,7 @@ func (b *blockDec) decodeBuf(hist *history) error { return nil case blockTypeCompressed: saved := b.dst + // Append directly to history b.dst = hist.b hist.b = nil err := b.decodeCompressed(hist) diff --git a/zstd/decoder.go b/zstd/decoder.go index f39d953650..1c5b272977 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -29,6 +29,15 @@ type Decoder struct { // Current read position used for Reader functionality. current decoderState + // sync stream decoding + syncStream struct { + br readerWrapper + enabled bool + inFrame bool + } + + frame *frameDec + // Custom dictionaries. // Always uses copies. dicts map[uint32]dict @@ -134,7 +143,7 @@ func (d *Decoder) Read(p []byte) (int, error) { break } if !d.nextBlock(n == 0) { - return n, nil + return n, d.current.err } } } @@ -199,19 +208,26 @@ func (d *Decoder) Reset(r io.Reader) error { } return nil } - - d.current.output = make(chan decodeOutput, d.o.concurrent) - ctx, cancel := context.WithCancel(context.Background()) - d.streamWg.Add(1) - go d.startStreamDecoder(ctx, r, d.current.output) - + if d.frame == nil { + d.frame = newFrameDec(d.o) + } // Remove current block. 
+ d.stashDecoder() d.current.decodeOutput = decodeOutput{} d.current.err = nil - d.current.cancel = cancel d.current.flushed = false d.current.d = nil + if d.o.concurrent == 1 { + return d.startSyncDecoder(r) + } + + d.current.output = make(chan decodeOutput, d.o.concurrent) + d.streamWg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + d.current.cancel = cancel + go d.startStreamDecoder(ctx, r, d.current.output) + return nil } @@ -366,6 +382,64 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { // If non-blocking mode is used the returned boolean will be false // if no data was available without blocking. func (d *Decoder) nextBlock(blocking bool) (ok bool) { + if d.current.err != nil { + // Keep error state. + return false + } + d.current.b = d.current.b[:0] + if d.syncStream.enabled { + if !blocking { + return false + } + if d.current.d == nil { + d.current.d = <-d.decoders + } + for len(d.current.b) == 0 { + if !d.syncStream.inFrame { + d.current.err = d.frame.reset(&d.syncStream.br) + if d.current.err != nil { + d.stashDecoder() + return false + } + d.syncStream.inFrame = true + } + d.current.err = d.frame.next(d.current.d) + if d.current.err != nil { + d.stashDecoder() + return false + } + d.frame.history.ensureBlock() + if debugDecoder { + println("history trimmed:", len(d.frame.history.b)) + } + histBefore := len(d.frame.history.b) + d.current.err = d.current.d.decodeBuf(&d.frame.history) + if d.current.err != nil { + d.stashDecoder() + return false + } + d.current.b = d.frame.history.b[histBefore:] + if debugDecoder { + println("history after:", len(d.frame.history.b)) + } + + // Update/Check CRC + if d.frame.HasCheckSum { + d.frame.crc.Write(d.current.b) + if d.current.d.Last { + d.current.err = d.frame.checkCRC() + if d.current.err != nil { + println("CRC error:", d.current.err) + d.stashDecoder() + return false + } + } + } + d.syncStream.inFrame = !d.current.d.Last + } + return true + } + if d.current.d != nil { if debugDecoder { printf("re-adding current decoder %p", d.current.d) @@ -373,10 +447,6 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { d.decoders <- d.current.d d.current.d = nil } - if d.current.err != nil { - // Keep error state. - return blocking - } if blocking { d.current.decodeOutput, ok = <-d.current.output @@ -413,12 +483,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { if next.err == nil && next.d != nil && len(next.d.checkCRC) != 0 { got := d.current.crc.Sum64() var tmp [4]byte - // Flip to match file order. - tmp[0] = byte(got >> 0) - tmp[1] = byte(got >> 8) - tmp[2] = byte(got >> 16) - tmp[3] = byte(got >> 24) - + binary.LittleEndian.PutUint32(tmp[:], uint32(got)) if !bytes.Equal(tmp[:], next.d.checkCRC) { if debugDecoder { println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)") @@ -434,6 +499,13 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { return true } +func (d *Decoder) stashDecoder() { + if d.current.d != nil { + d.decoders <- d.current.d + d.current.d = nil + } +} + // Close will release all resources. // It is NOT possible to reuse the decoder after this. func (d *Decoder) Close() { @@ -495,14 +567,12 @@ type decodeOutput struct { err error } -type decodeStream struct { - r io.Reader - - // Blocks ready to be written to output. 
- output chan decodeOutput - - // cancel reading from the input - cancel chan struct{} +func (d *Decoder) startSyncDecoder(r io.Reader) error { + d.frame.history.reset() + d.syncStream.br = readerWrapper{r: r} + d.syncStream.inFrame = false + d.syncStream.enabled = true + return nil } // Create Decoder: @@ -514,8 +584,6 @@ type decodeStream struct { // 3: Wait for stream history, execute sequences, send stream history. func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) { defer d.streamWg.Done() - frame := newFrameDec(d.o) - br := readerWrapper{r: r} var seqPrepare = make(chan *blockDec, d.o.concurrent) @@ -677,6 +745,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch decodeStream: for { + frame := d.frame if debugDecoder { println("New frame...") } diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index 3d9e1e178e..014b1d3de7 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -40,8 +40,13 @@ func WithDecoderLowmem(b bool) DOption { return func(o *decoderOptions) error { o.lowMem = b; return nil } } -// WithDecoderConcurrency will set the concurrency, -// meaning the maximum number of decoders to run concurrently. +// WithDecoderConcurrency will set the concurrency. +// When decoding block with DecodeAll, this will limit the number +// of possible concurrently running decodes. +// When decoding streams, this will limit the number of +// inflight blocks. +// When decoding streams and setting maximum to 1, +// no async decoding will be done. // The value supplied must be at least 1. // By default this will be set to 4 or GOMAXPROCS, whatever is lower. func WithDecoderConcurrency(n int) DOption { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 0320313f33..b233b9b724 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -200,14 +200,19 @@ func TestErrorWriter(t *testing.T) { } func TestNewDecoder(t *testing.T) { - defer timeout(60 * time.Second)() - testDecoderFile(t, "testdata/decoder.zip") - if true { - dec, err := NewReader(nil) - if err != nil { - t.Fatal(err) - } - testDecoderDecodeAll(t, "testdata/decoder.zip", dec) + for _, n := range []int{1, 4} { + t.Run(fmt.Sprintf("cpu-%d", n), func(t *testing.T) { + defer timeout(60 * time.Second)() + newFn := func() (*Decoder, error) { + return NewReader(nil, WithDecoderConcurrency(n)) + } + testDecoderFile(t, "testdata/decoder.zip", newFn) + dec, err := newFn() + if err != nil { + t.Fatal(err) + } + testDecoderDecodeAll(t, "testdata/decoder.zip", dec) + }) } } @@ -386,13 +391,20 @@ func TestNewDecoderFrameSize(t *testing.T) { } func TestNewDecoderGood(t *testing.T) { - defer timeout(30 * time.Second)() - testDecoderFile(t, "testdata/good.zip") - dec, err := NewReader(nil) - if err != nil { - t.Fatal(err) + for _, n := range []int{1, 4} { + t.Run(fmt.Sprintf("cpu-%d", n), func(t *testing.T) { + defer timeout(30 * time.Second)() + newFn := func() (*Decoder, error) { + return NewReader(nil, WithDecoderConcurrency(n)) + } + testDecoderFile(t, "testdata/good.zip", newFn) + dec, err := newFn() + if err != nil { + t.Fatal(err) + } + testDecoderDecodeAll(t, "testdata/good.zip", dec) + }) } - testDecoderDecodeAll(t, "testdata/good.zip", dec) } func TestNewDecoderBad(t *testing.T) { @@ -405,7 +417,10 @@ func TestNewDecoderBad(t *testing.T) { } func TestNewDecoderLarge(t *testing.T) { - testDecoderFile(t, "testdata/large.zip") + newFn := func() (*Decoder, error) { + return NewReader(nil) + } + testDecoderFile(t, 
"testdata/large.zip", newFn) dec, err := NewReader(nil) if err != nil { t.Fatal(err) @@ -435,7 +450,10 @@ func TestNewDecoderBig(t *testing.T) { t.Skip("To run extended tests, download https://files.klauspost.com/compress/zstd-10kfiles.zip \n" + "and place it in " + file + "\n" + "Running it requires about 5GB of RAM") } - testDecoderFile(t, file) + newFn := func() (*Decoder, error) { + return NewReader(nil) + } + testDecoderFile(t, file, newFn) dec, err := NewReader(nil) if err != nil { t.Fatal(err) @@ -948,7 +966,7 @@ func TestDecoderMultiFrameReset(t *testing.T) { } } -func testDecoderFile(t *testing.T, fn string) { +func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) { data, err := ioutil.ReadFile(fn) if err != nil { t.Fatal(err) @@ -970,7 +988,7 @@ func testDecoderFile(t *testing.T, fn string) { want[tt.Name+".zst"], _ = ioutil.ReadAll(r) } - dec, err := NewReader(nil) + dec, err := newDec() if err != nil { t.Error(err) return diff --git a/zstd/history.go b/zstd/history.go index aa08bd3da0..3a29308310 100644 --- a/zstd/history.go +++ b/zstd/history.go @@ -89,6 +89,24 @@ func (h *history) append(b []byte) { copy(h.b[h.windowSize-len(b):], b) } +// ensureBlock will ensure there is space for at least one block... +func (h *history) ensureBlock() { + if cap(h.b) < h.allocFrameBuffer { + h.b = make([]byte, 0, h.allocFrameBuffer) + return + } + + avail := cap(h.b) - len(h.b) + if avail >= h.windowSize || avail > maxCompressedBlockSize { + return + } + // Move data down so we only have window size left. + // We know we have less than window size in b at this point. + discard := len(h.b) - h.windowSize + copy(h.b, h.b[discard:]) + h.b = h.b[:h.windowSize] +} + // append bytes to history without ever discarding anything. func (h *history) appendKeep(b []byte) { h.b = append(h.b, b...) diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 3eab353589..e7b9354940 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -263,6 +263,7 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { s.literals = s.literals[ll:] out := s.out + // Copy form dictionary... if mo > len(s.out)+len(hist) || mo > s.windowSize { if len(s.dict) == 0 { return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) @@ -279,14 +280,12 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { mo -= len(s.dict) - dictO ml -= len(s.dict) - dictO } else { - out = append(out, s.dict[dictO:end]...) - mo = 0 - ml = 0 + s.out = append(out, s.dict[dictO:end]...) + continue } } // Copy from history. - // TODO: Blocks without history could be made to ignore this completely. if v := mo - len(s.out); v > 0 { // v is the start position in history from end. start := len(hist) - v @@ -297,8 +296,8 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { mo -= v ml -= v } else { - out = append(out, hist[start:start+ml]...) - ml = 0 + s.out = append(out, hist[start:start+ml]...) + continue } } // We must be in current buffer now @@ -306,7 +305,8 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { start := len(s.out) - mo if ml <= len(s.out)-start { // No overlap - out = append(out, s.out[start:start+ml]...) + s.out = append(out, s.out[start:start+ml]...) + continue } else { // Overlapping copy // Extend destination slice and copy one byte at the time. 
diff --git a/zstd/zstd.go b/zstd/zstd.go index 0b63559a3f..ef1d49a009 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -19,7 +19,7 @@ const debug = false const debugEncoder = debug // enable decoding debug printing -const debugDecoder = false +const debugDecoder = debug // Enable extra assertions. const debugAsserts = debug || false From b2951ef8a40e4b16715afdd3e5e26ee9aff51b67 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Feb 2022 16:13:48 +0100 Subject: [PATCH 08/31] Check FrameContentSize and max decoded size. --- zstd/blockdec.go | 1 + zstd/decoder.go | 45 ++++++++++++++++++++-- zstd/framedec.go | 97 ------------------------------------------------ 3 files changed, 43 insertions(+), 100 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index 4d2fdac8c0..4618b90029 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -92,6 +92,7 @@ type blockDec struct { literals []byte seqData []byte seqSize int // Size of uncompressed sequences + fcs uint64 } // Block is RLE, this is the size. diff --git a/zstd/decoder.go b/zstd/decoder.go index 1c5b272977..f5755989ea 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -31,9 +31,11 @@ type Decoder struct { // sync stream decoding syncStream struct { - br readerWrapper - enabled bool - inFrame bool + decoded uint64 + decodedFrame uint64 + br readerWrapper + enabled bool + inFrame bool } frame *frameDec @@ -401,6 +403,16 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { d.stashDecoder() return false } + if d.frame.DictionaryID != nil { + dict, ok := d.dicts[*d.frame.DictionaryID] + if !ok { + d.current.err = ErrUnknownDictionary + return false + } else { + d.frame.history.setDict(&dict) + } + } + d.syncStream.decodedFrame = 0 d.syncStream.inFrame = true } d.current.err = d.frame.next(d.current.d) @@ -422,6 +434,16 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { if debugDecoder { println("history after:", len(d.frame.history.b)) } + d.syncStream.decoded += uint64(len(d.current.b)) + d.syncStream.decodedFrame += uint64(len(d.current.b)) + if d.syncStream.decoded > d.o.maxDecodedSize { + d.current.err = ErrDecoderSizeExceeded + return false + } + if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize { + d.current.err = ErrFrameSizeExceeded + return false + } // Update/Check CRC if d.frame.HasCheckSum { @@ -572,6 +594,8 @@ func (d *Decoder) startSyncDecoder(r io.Reader) error { d.syncStream.br = readerWrapper{r: r} d.syncStream.inFrame = false d.syncStream.enabled = true + d.syncStream.decoded = 0 + d.syncStream.decodedFrame = 0 return nil } @@ -672,6 +696,8 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch // Async 3: Execute sequences... 
go func() { var hist history + var decoded, decodedFrame uint64 + var fcs uint64 var hasErr bool for block := range seqExecute { out := decodeOutput{err: block.err, d: block} @@ -694,6 +720,8 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch println("Alloc history sized", hist.allocFrameBuffer) } hist.b = hist.b[:0] + fcs = block.async.fcs + decodedFrame = 0 } do := decodeOutput{err: block.err, d: block} switch block.Type { @@ -735,6 +763,16 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } do.b = block.dst } + if !hasErr { + decoded += uint64(len(do.b)) + decodedFrame += uint64(len(do.b)) + if decoded > d.o.maxDecodedSize { + do.err = ErrDecoderSizeExceeded + } + if fcs > 0 && decodedFrame > fcs { + d.current.err = ErrFrameSizeExceeded + } + } output <- do } close(output) @@ -795,6 +833,7 @@ decodeStream: println("Alloc History:", h.allocFrameBuffer) } dec.async.newHist = &h + dec.async.fcs = frame.FrameContentSize historySent = true } else { dec.async.newHist = nil diff --git a/zstd/framedec.go b/zstd/framedec.go index 82df364b7a..c9a06d3185 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -316,103 +316,6 @@ func (d *frameDec) checkCRC() error { return nil } -/* -// startDecoder will start decoding blocks and write them to the writer. -// The decoder will stop as soon as an error occurs or at end of frame. -// When the frame has finished decoding the *bufio.Reader -// containing the remaining input will be sent on frameDec.frameDone. -func (d *frameDec) startDecoder(output chan decodeOutput) { - written := int64(0) - - defer func() { - d.asyncRunningMu.Lock() - d.asyncRunning = false - d.asyncRunningMu.Unlock() - - // Drain the currently decoding. - d.history.error = true - flushdone: - for { - select { - case b := <-d.decoding: - b.history <- &d.history - output <- <-b.result - default: - break flushdone - } - } - println("frame decoder done, signalling done") - d.frameDone.Done() - }() - // Get decoder for first block. - block := <-d.decoding - block.history <- &d.history - for { - var next *blockDec - // Get result - r := <-block.result - if r.err != nil { - println("Result contained error", r.err) - output <- r - return - } - if debugDecoder { - println("got result, from ", d.offset, "to", d.offset+int64(len(r.b))) - d.offset += int64(len(r.b)) - } - if !block.Last { - // Send history to next block - select { - case next = <-d.decoding: - if debugDecoder { - println("Sending ", len(d.history.b), "bytes as history") - } - next.history <- &d.history - default: - // Wait until we have sent the block, so - // other decoders can potentially get the decoder. - next = nil - } - } - - // Add checksum, async to decoding. - if d.HasCheckSum { - n, err := d.crc.Write(r.b) - if err != nil { - r.err = err - if n != len(r.b) { - r.err = io.ErrShortWrite - } - output <- r - return - } - } - written += int64(len(r.b)) - if d.SingleSegment && uint64(written) > d.FrameContentSize { - println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize) - r.err = ErrFrameSizeExceeded - output <- r - return - } - if block.Last { - r.err = d.checkCRC() - output <- r - return - } - output <- r - if next == nil { - // There was no decoder available, we wait for one now that we have sent to the writer. 
- if debugDecoder { - println("Sending ", len(d.history.b), " bytes as history") - } - next = <-d.decoding - next.history <- &d.history - } - block = next - } -} -*/ - // runDecoder will create a sync decoder that will decode a block of data. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { saved := d.history.b From 9d525e5e96841810a09672d17dd306ff9c069f96 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Feb 2022 16:48:27 +0100 Subject: [PATCH 09/31] Remove unused var+func. --- zstd/framedec.go | 2 -- zstd/seqdec.go | 33 --------------------------------- 2 files changed, 35 deletions(-) diff --git a/zstd/framedec.go b/zstd/framedec.go index c9a06d3185..7e20a8fb49 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -17,8 +17,6 @@ type frameDec struct { o decoderOptions crc *xxhash.Digest - offset int64 - WindowSize uint64 // Frame history passed between blocks diff --git a/zstd/seqdec.go b/zstd/seqdec.go index e7b9354940..e82a35f939 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -694,36 +694,3 @@ func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int { s.prevOffset[0] = temp return temp } - -// mergeHistory will merge history. -func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) { - for i := uint(0); i < 3; i++ { - var sNew, sHist *sequenceDec - switch i { - default: - // same as "case 0": - sNew = &s.litLengths - sHist = &hist.litLengths - case 1: - sNew = &s.offsets - sHist = &hist.offsets - case 2: - sNew = &s.matchLengths - sHist = &hist.matchLengths - } - if sNew.repeat { - if sHist.fse == nil { - return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i) - } - continue - } - if sNew.fse == nil { - return nil, fmt.Errorf("sequence stream %d, no fse found", i) - } - if sHist.fse != nil && !sHist.fse.preDefined { - fseDecoderPool.Put(sHist.fse) - } - sHist.fse = sNew.fse - } - return hist, nil -} From b74ecce15ccd6014eedb21a0d87feab182b59ff6 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Feb 2022 17:42:10 +0100 Subject: [PATCH 10/31] Tweaks and cleanup --- zstd/decoder.go | 3 + zstd/decoder_test.go | 148 +------------------------------------------ zstd/history.go | 2 +- zstd/seqdec.go | 55 ++++++++-------- 4 files changed, 32 insertions(+), 176 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index f5755989ea..aa73bb987c 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -313,6 +313,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } frame.rawInput = nil frame.bBuf = nil + if frame.history.decoders.br != nil { + frame.history.decoders.br.in = nil + } d.decoders <- block }() frame.bBuf = input diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index b233b9b724..a32e0e34a6 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -477,7 +477,7 @@ func TestNewDecoderBigFile(t *testing.T) { } defer f.Close() start := time.Now() - dec, err := NewReader(f, WithDecoderConcurrency(4), WithDecoderLowmem(true)) + dec, err := NewReader(f) if err != nil { t.Fatal(err) } @@ -1159,7 +1159,7 @@ func BenchmarkDecoder_DecodeAllParallel(b *testing.B) { if err != nil { b.Fatal(err) } - dec, err := NewReader(nil) + dec, err := NewReader(nil, WithDecoderConcurrency(runtime.GOMAXPROCS(0))) if err != nil { b.Fatal(err) return @@ -1199,150 +1199,6 @@ func BenchmarkDecoder_DecodeAllParallel(b *testing.B) { } } -/* -func BenchmarkDecoder_DecodeAllCgo(b *testing.B) { - fn := "testdata/benchdecoder.zip" - data, err := ioutil.ReadFile(fn) - if err != 
nil { - b.Fatal(err) - } - zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) - if err != nil { - b.Fatal(err) - } - for _, tt := range zr.File { - if !strings.HasSuffix(tt.Name, ".zst") { - continue - } - b.Run(tt.Name, func(b *testing.B) { - tt := tt - r, err := tt.Open() - if err != nil { - b.Fatal(err) - } - defer r.Close() - in, err := ioutil.ReadAll(r) - if err != nil { - b.Fatal(err) - } - got, err := zstd.Decompress(nil, in) - if err != nil { - b.Fatal(err) - } - b.SetBytes(int64(len(got))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - got, err = zstd.Decompress(got, in) - if err != nil { - b.Fatal(err) - } - } - }) - } -} - -func BenchmarkDecoder_DecodeAllParallelCgo(b *testing.B) { - fn := "testdata/benchdecoder.zip" - data, err := ioutil.ReadFile(fn) - if err != nil { - b.Fatal(err) - } - zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) - if err != nil { - b.Fatal(err) - } - for _, tt := range zr.File { - if !strings.HasSuffix(tt.Name, ".zst") { - continue - } - b.Run(tt.Name, func(b *testing.B) { - r, err := tt.Open() - if err != nil { - b.Fatal(err) - } - defer r.Close() - in, err := ioutil.ReadAll(r) - if err != nil { - b.Fatal(err) - } - got, err := zstd.Decompress(nil, in) - if err != nil { - b.Fatal(err) - } - b.SetBytes(int64(len(got))) - b.ReportAllocs() - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - got := make([]byte, len(got)) - for pb.Next() { - got, err = zstd.Decompress(got, in) - if err != nil { - b.Fatal(err) - } - } - }) - }) - } -} - -func BenchmarkDecoderSilesiaCgo(b *testing.B) { - fn := "testdata/silesia.tar.zst" - data, err := ioutil.ReadFile(fn) - if err != nil { - if os.IsNotExist(err) { - b.Skip("Missing testdata/silesia.tar.zst") - return - } - b.Fatal(err) - } - dec := zstd.NewReader(bytes.NewBuffer(data)) - n, err := io.Copy(ioutil.Discard, dec) - if err != nil { - b.Fatal(err) - } - - b.SetBytes(n) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - dec := zstd.NewReader(bytes.NewBuffer(data)) - _, err := io.CopyN(ioutil.Discard, dec, n) - if err != nil { - b.Fatal(err) - } - } -} -func BenchmarkDecoderEnwik9Cgo(b *testing.B) { - fn := "testdata/enwik9-1.zst" - data, err := ioutil.ReadFile(fn) - if err != nil { - if os.IsNotExist(err) { - b.Skip("Missing " + fn) - return - } - b.Fatal(err) - } - dec := zstd.NewReader(bytes.NewBuffer(data)) - n, err := io.Copy(ioutil.Discard, dec) - if err != nil { - b.Fatal(err) - } - - b.SetBytes(n) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - dec := zstd.NewReader(bytes.NewBuffer(data)) - _, err := io.CopyN(ioutil.Discard, dec, n) - if err != nil { - b.Fatal(err) - } - } -} - -*/ - func BenchmarkDecoderSilesia(b *testing.B) { fn := "testdata/silesia.tar.zst" data, err := ioutil.ReadFile(fn) diff --git a/zstd/history.go b/zstd/history.go index 3a29308310..e27177f23a 100644 --- a/zstd/history.go +++ b/zstd/history.go @@ -41,7 +41,7 @@ func (h *history) reset() { if f := h.decoders.matchLengths.fse; f != nil && !f.preDefined { fseDecoderPool.Put(f) } - h.decoders = sequenceDecs{} + h.decoders = sequenceDecs{br: h.decoders.br} if h.huffTree != nil { if h.dict == nil || h.dict.litEnc != h.huffTree { huffDecoderPool.Put(h.huffTree) diff --git a/zstd/seqdec.go b/zstd/seqdec.go index e82a35f939..037ba06a15 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -68,12 +68,11 @@ type sequenceDecs struct { dict []byte literals []byte out []byte - //seq []seqVals - nSeqs int - br *bitReader - seqSize int - windowSize int - maxBits 
uint8 + nSeqs int + br *bitReader + seqSize int + windowSize int + maxBits uint8 } // initialize all 3 decoders from the stream input. @@ -256,29 +255,27 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { } for _, seq := range seqs { - ll, ml, mo := seq.ll, seq.ml, seq.mo - // Add literals - s.out = append(s.out, s.literals[:ll]...) - s.literals = s.literals[ll:] + s.out = append(s.out, s.literals[:seq.ll]...) + s.literals = s.literals[seq.ll:] out := s.out // Copy form dictionary... - if mo > len(s.out)+len(hist) || mo > s.windowSize { + if seq.mo > len(s.out)+len(hist) || seq.mo > s.windowSize { if len(s.dict) == 0 { - return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) + return fmt.Errorf("match offset (%d) bigger than current history (%d)", seq.mo, len(s.out)+len(hist)) } // we may be in dictionary. - dictO := len(s.dict) - (mo - (len(s.out) + len(hist))) + dictO := len(s.dict) - (seq.mo - (len(s.out) + len(hist))) if dictO < 0 || dictO >= len(s.dict) { - return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", mo, len(s.out)+len(hist)+len(s.dict)) + return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", seq.mo, len(s.out)+len(hist)+len(s.dict)) } - end := dictO + ml + end := dictO + seq.ml if end > len(s.dict) { out = append(out, s.dict[dictO:]...) - mo -= len(s.dict) - dictO - ml -= len(s.dict) - dictO + seq.mo -= len(s.dict) - dictO + seq.ml -= len(s.dict) - dictO } else { s.out = append(out, s.dict[dictO:end]...) continue @@ -286,34 +283,34 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { } // Copy from history. - if v := mo - len(s.out); v > 0 { + if v := seq.mo - len(s.out); v > 0 { // v is the start position in history from end. start := len(hist) - v - if ml > v { + if seq.ml > v { // Some goes into current block. // Copy remainder of history out = append(out, hist[start:]...) - mo -= v - ml -= v + seq.mo -= v + seq.ml -= v } else { - s.out = append(out, hist[start:start+ml]...) + s.out = append(out, hist[start:start+seq.ml]...) continue } } // We must be in current buffer now - if ml > 0 { - start := len(s.out) - mo - if ml <= len(s.out)-start { + if seq.ml > 0 { + start := len(s.out) - seq.mo + if seq.ml <= len(s.out)-start { // No overlap - s.out = append(out, s.out[start:start+ml]...) + s.out = append(out, s.out[start:start+seq.ml]...) continue } else { // Overlapping copy // Extend destination slice and copy one byte at the time. - out = out[:len(out)+ml] - src := out[start : start+ml] + out = out[:len(out)+seq.ml] + src := out[start : start+seq.ml] // Destination is the space we just added. - dst := out[len(out)-ml:] + dst := out[len(out)-seq.ml:] dst = dst[:len(src)] for i := range src { dst[i] = src[i] From e217e78eec871c9abc91b9ab6416020729ddf0ee Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sat, 19 Feb 2022 16:12:21 +0100 Subject: [PATCH 11/31] Use maxsize as documented. 
--- zstd/decoder.go | 22 ++++++++++------------ zstd/decoder_options.go | 12 ++++++++---- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index aa73bb987c..9c626ff9bd 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -31,7 +31,6 @@ type Decoder struct { // sync stream decoding syncStream struct { - decoded uint64 decodedFrame uint64 br readerWrapper enabled bool @@ -415,6 +414,11 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { d.frame.history.setDict(&dict) } } + if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { + d.current.err = ErrDecoderSizeExceeded + return false + } + d.syncStream.decodedFrame = 0 d.syncStream.inFrame = true } @@ -437,12 +441,8 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { if debugDecoder { println("history after:", len(d.frame.history.b)) } - d.syncStream.decoded += uint64(len(d.current.b)) + d.syncStream.decodedFrame += uint64(len(d.current.b)) - if d.syncStream.decoded > d.o.maxDecodedSize { - d.current.err = ErrDecoderSizeExceeded - return false - } if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize { d.current.err = ErrFrameSizeExceeded return false @@ -597,7 +597,6 @@ func (d *Decoder) startSyncDecoder(r io.Reader) error { d.syncStream.br = readerWrapper{r: r} d.syncStream.inFrame = false d.syncStream.enabled = true - d.syncStream.decoded = 0 d.syncStream.decodedFrame = 0 return nil } @@ -699,7 +698,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch // Async 3: Execute sequences... go func() { var hist history - var decoded, decodedFrame uint64 + var decodedFrame uint64 var fcs uint64 var hasErr bool for block := range seqExecute { @@ -767,11 +766,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch do.b = block.dst } if !hasErr { - decoded += uint64(len(do.b)) decodedFrame += uint64(len(do.b)) - if decoded > d.o.maxDecodedSize { - do.err = ErrDecoderSizeExceeded - } if fcs > 0 && decodedFrame > fcs { d.current.err = ErrFrameSizeExceeded } @@ -812,6 +807,9 @@ decodeStream: frame.history.setDict(&dict) } } + if d.frame.FrameContentSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { + err = ErrDecoderSizeExceeded + } if err != nil { output <- decodeOutput{ err: err, diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index 014b1d3de7..fd05c9bb01 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -40,21 +40,25 @@ func WithDecoderLowmem(b bool) DOption { return func(o *decoderOptions) error { o.lowMem = b; return nil } } -// WithDecoderConcurrency will set the concurrency. +// WithDecoderConcurrency sets the number of created decoders. // When decoding block with DecodeAll, this will limit the number // of possible concurrently running decodes. // When decoding streams, this will limit the number of // inflight blocks. // When decoding streams and setting maximum to 1, // no async decoding will be done. -// The value supplied must be at least 1. +// When a value of 0 is provided GOMAXPROCS will be used. // By default this will be set to 4 or GOMAXPROCS, whatever is lower. 
func WithDecoderConcurrency(n int) DOption { return func(o *decoderOptions) error { - if n <= 0 { + if n < 0 { return errors.New("concurrency must be at least 1") } - o.concurrent = n + if n == 0 { + o.concurrent = runtime.GOMAXPROCS(0) + } else { + o.concurrent = n + } return nil } } From 561b94cd88c33cabc5ee04d9b17b238993df0cfd Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 21 Feb 2022 11:56:21 +0100 Subject: [PATCH 12/31] Ensure history from frames cannot overlap. --- zstd/blockdec.go | 18 +++++++--- zstd/bytebuf.go | 3 ++ zstd/decoder.go | 16 +++++---- zstd/decoder_test.go | 75 +++++++++++++++++++++++++++++++++++++++--- zstd/framedec.go | 1 + zstd/history.go | 5 +++ zstd/seqdec.go | 3 +- zstd/testdata/bad.zip | Bin 3241 -> 3610 bytes 8 files changed, 104 insertions(+), 17 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index 4618b90029..c1d9757a2f 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -227,14 +227,22 @@ func (b *blockDec) decodeBuf(hist *history) error { case blockTypeCompressed: saved := b.dst // Append directly to history - b.dst = hist.b - hist.b = nil + if hist.ignoreBuffer == 0 { + b.dst = hist.b + hist.b = nil + } else { + b.dst = b.dst[:0] + } err := b.decodeCompressed(hist) if debugDecoder { println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err) } - hist.b = b.dst - b.dst = saved + if hist.ignoreBuffer == 0 { + hist.b = b.dst + b.dst = saved + } else { + hist.appendKeep(b.dst) + } return err case blockTypeReserved: // Used for returning errors. @@ -455,7 +463,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { b.dst = append(b.dst, hist.decoders.literals...) return nil } - err = hist.decoders.decodeSync(hist.b) + err = hist.decoders.decodeSync(hist) if err != nil { return err } diff --git a/zstd/bytebuf.go b/zstd/bytebuf.go index aab71c6cf8..b80191e4b1 100644 --- a/zstd/bytebuf.go +++ b/zstd/bytebuf.go @@ -113,6 +113,9 @@ func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) { func (r *readerWrapper) readByte() (byte, error) { n2, err := r.r.Read(r.tmp[:1]) if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return 0, err } if n2 != 1 { diff --git a/zstd/decoder.go b/zstd/decoder.go index 9c626ff9bd..65cde590b9 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -322,11 +322,14 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { for { frame.history.reset() err := frame.reset(&frame.bBuf) - if err == io.EOF { - if debugDecoder { - println("frame reset return EOF") + if err != nil { + if err == io.EOF { + if debugDecoder { + println("frame reset return EOF") + } + return dst, nil } - return dst, nil + return dst, err } if frame.DictionaryID != nil { dict, ok := d.dicts[*frame.DictionaryID] @@ -338,9 +341,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } frame.history.setDict(&dict) } - if err != nil { - return dst, err - } + if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { return dst, ErrDecoderSizeExceeded } @@ -400,6 +401,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { } for len(d.current.b) == 0 { if !d.syncStream.inFrame { + d.frame.history.reset() d.current.err = d.frame.reset(&d.syncStream.br) if d.current.err != nil { d.stashDecoder() diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index a32e0e34a6..acc7028722 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -408,6 +408,21 @@ func TestNewDecoderGood(t *testing.T) { } func TestNewDecoderBad(t *testing.T) { + if 
true { + t.Run("conc-4", func(t *testing.T) { + newFn := func() (*Decoder, error) { + return NewReader(nil, WithDecoderConcurrency(4)) + } + testDecoderFileBad(t, "testdata/bad.zip", newFn) + + }) + t.Run("conc-1", func(t *testing.T) { + newFn := func() (*Decoder, error) { + return NewReader(nil, WithDecoderConcurrency(1)) + } + testDecoderFileBad(t, "testdata/bad.zip", newFn) + }) + } defer timeout(10 * time.Second)() dec, err := NewReader(nil) if err != nil { @@ -1041,6 +1056,59 @@ func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) { } } +func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error)) { + data, err := ioutil.ReadFile(fn) + if err != nil { + t.Fatal(err) + } + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + t.Fatal(err) + } + var want = make(map[string][]byte) + for _, tt := range zr.File { + if strings.HasSuffix(tt.Name, ".zst") { + continue + } + r, err := tt.Open() + if err != nil { + t.Fatal(err) + return + } + want[tt.Name+".zst"], _ = ioutil.ReadAll(r) + } + + dec, err := newDec() + if err != nil { + t.Error(err) + return + } + defer dec.Close() + for i, tt := range zr.File { + if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) { + continue + } + t.Run("Reader-"+tt.Name, func(t *testing.T) { + r, err := tt.Open() + if err != nil { + t.Error(err) + return + } + defer r.Close() + err = dec.Reset(r) + if err != nil { + t.Error(err) + return + } + _, err = ioutil.ReadAll(dec) + if err == nil { + t.Error("Did not get expected error") + } + t.Log("get error", err) + }) + } +} + func BenchmarkDecoder_DecoderSmall(b *testing.B) { fn := "testdata/benchdecoder.zip" data, err := ioutil.ReadFile(fn) @@ -1372,7 +1440,6 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) { wg.Add(1) t.Run("DecodeAll-"+tt.Name, func(t *testing.T) { defer wg.Done() - t.Parallel() r, err := tt.Open() if err != nil { t.Fatal(err) @@ -1381,10 +1448,10 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) { if err != nil { t.Fatal(err) } - // make a buffer that is too small. - _, err = dec.DecodeAll(in, make([]byte, 0, 200)) + // make a buffer that is small. + got, err := dec.DecodeAll(in, make([]byte, 0, 20)) if err == nil { - t.Error("Did not get expected error") + t.Error("Did not get expected error, got", len(got), "bytes") } }) } diff --git a/zstd/framedec.go b/zstd/framedec.go index 7e20a8fb49..fc4c25d9c6 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -320,6 +320,7 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { // We use the history for output to avoid copying it. d.history.b = dst + d.history.ignoreBuffer = len(dst) // Store input length, so we only check new data. crcStart := len(dst) var err error diff --git a/zstd/history.go b/zstd/history.go index e27177f23a..c96687aada 100644 --- a/zstd/history.go +++ b/zstd/history.go @@ -20,6 +20,10 @@ type history struct { // History buffer... b []byte + // ignoreBuffer is meant to ignore a number of bytes + // when checking for matches in history + ignoreBuffer int + windowSize int allocFrameBuffer int // needed? error bool @@ -30,6 +34,7 @@ type history struct { // The history must already have been initialized to the desired size. 
func (h *history) reset() { h.b = h.b[:0] + h.ignoreBuffer = 0 h.error = false h.recentOffsets = [3]int{1, 4, 8} if f := h.decoders.litLengths.fse; f != nil && !f.preDefined { diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 037ba06a15..54b88c327c 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -326,13 +326,14 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { } // decode sequences from the stream with the provided history. -func (s *sequenceDecs) decodeSync(hist []byte) error { +func (s *sequenceDecs) decodeSync(history *history) error { br := s.br seqs := s.nSeqs startSize := len(s.out) // Grab full sizes tables, to avoid bounds checks. llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state + hist := history.b[history.ignoreBuffer:] for i := seqs - 1; i >= 0; i-- { if br.overread() { diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index 32ae49986601a8db1eec2bcc50da7a1d6b3545e1..b61619a3af3547e167e73499071485539f549482 100644 GIT binary patch delta 402 zcmZ1}IZI|kB3nH(iwFY~0|&!buh0+;myMCi3=9k=K&%7AhDK?o#^#18hK80FNk&FV z$p(gLsfGq=mMLkLW~LVADW+*=7AZ+adR4_Gmv5~4KmFi>>5Yzu8CYgWUYRLwpY$Xo z-Q=>vBt!O-5)uxn3<2IS%LGC~Lw;Ry4`XAnW?*1}+LmgTXqISbYGIU^W@M0PkZPK0 zXqskVZjfwdVwPl@l45FPnQUQXnFO;#AS)is0|0WI1d;#% From 3db2dcb2fce307d592206cfdfe35905d40a169b6 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 21 Feb 2022 12:16:57 +0100 Subject: [PATCH 13/31] Fix deadlock on error. --- zstd/decoder.go | 2 +- zstd/decoder_test.go | 20 +++++++++++--------- zstd/testdata/bad.zip | Bin 3610 -> 4142 bytes 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 65cde590b9..3203eb822e 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -816,7 +816,7 @@ decodeStream: output <- decodeOutput{ err: err, } - return + break decodeStream } decodeFrame: // Go through all blocks of the frame. 
diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index acc7028722..0d040f406f 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -409,26 +409,28 @@ func TestNewDecoderGood(t *testing.T) { func TestNewDecoderBad(t *testing.T) { if true { - t.Run("conc-4", func(t *testing.T) { + t.Run("Reader-4", func(t *testing.T) { newFn := func() (*Decoder, error) { return NewReader(nil, WithDecoderConcurrency(4)) } testDecoderFileBad(t, "testdata/bad.zip", newFn) }) - t.Run("conc-1", func(t *testing.T) { + t.Run("Reader-1", func(t *testing.T) { newFn := func() (*Decoder, error) { return NewReader(nil, WithDecoderConcurrency(1)) } testDecoderFileBad(t, "testdata/bad.zip", newFn) }) } - defer timeout(10 * time.Second)() - dec, err := NewReader(nil) - if err != nil { - t.Fatal(err) - } - testDecoderDecodeAllError(t, "testdata/bad.zip", dec) + t.Run("DecodeAll", func(t *testing.T) { + defer timeout(10 * time.Second)() + dec, err := NewReader(nil) + if err != nil { + t.Fatal(err) + } + testDecoderDecodeAllError(t, "testdata/bad.zip", dec) + }) } func TestNewDecoderLarge(t *testing.T) { @@ -1438,7 +1440,7 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) { continue } wg.Add(1) - t.Run("DecodeAll-"+tt.Name, func(t *testing.T) { + t.Run(tt.Name, func(t *testing.T) { defer wg.Done() r, err := tt.Open() if err != nil { diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index b61619a3af3547e167e73499071485539f549482..de9004186c8f8136553bab386e500c47e27d6f1d 100644 GIT binary patch delta 567 zcmbOwvrb`y9D6-8iwFY~0|�gwT+Fp;aa#3=9lHK&%7ACI%*n#)+mDX~{`R7D+ zX8I$b>E-=fLOFq^azRZ`Ha0Ldt~WJIHa9m&O*S<(H8)K)Gqo^COg1vJG%++auuQZt zHZV+t8*hGsfd$FT$3Qbpw%u>!2AanQH8Um2#L(Q(z#zrY#KOqb$iM>V(v%dS^OG%1 z42{h!k_}Rf(h`j=VJ;45J$m3a1Jh<{_SbCnj7%a7xB~-d0}vJfKDcsk1 Date: Mon, 21 Feb 2022 15:30:30 +0100 Subject: [PATCH 14/31] Stricter framecontent size checks and consistency. --- zstd/blockdec.go | 15 +++++++++++++++ zstd/decoder.go | 33 ++++++++++++++++++++++++++++++--- zstd/decoder_test.go | 39 ++++++++++++++++++++++++++++++--------- zstd/framedec.go | 20 ++++++++++++++++++-- zstd/seqdec.go | 4 ++-- zstd/testdata/bad.zip | Bin 4142 -> 4890 bytes zstd/testdata/good.zip | Bin 3914 -> 15437 bytes zstd/zstd.go | 4 ++++ 8 files changed, 99 insertions(+), 16 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index c1d9757a2f..c55b64c60a 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -141,6 +141,13 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { case blockTypeReserved: return ErrReservedBlockType case blockTypeRLE: + if cSize > maxCompressedBlockSize || cSize > int(b.WindowSize) { + if debugDecoder { + printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b) + } + // TODO: Likely enable: + //return ErrWindowSizeExceeded + } b.RLESize = uint32(cSize) if b.lowMem { maxSize = cSize @@ -162,6 +169,14 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { return ErrCompressedSizeTooBig } case blockTypeRaw: + if cSize > maxCompressedBlockSize || cSize > int(b.WindowSize) { + if debugDecoder { + printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b) + } + // TODO: Likely enable: + //return ErrWindowSizeExceeded + } + b.RLESize = 0 // We do not need a destination for raw blocks. 
maxSize = -1 diff --git a/zstd/decoder.go b/zstd/decoder.go index 3203eb822e..3158e516fb 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -392,6 +392,8 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { return false } d.current.b = d.current.b[:0] + + // SYNC: if d.syncStream.enabled { if !blocking { return false @@ -431,12 +433,14 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { } d.frame.history.ensureBlock() if debugDecoder { - println("history trimmed:", len(d.frame.history.b)) + println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame) } histBefore := len(d.frame.history.b) d.current.err = d.current.d.decodeBuf(&d.frame.history) + if d.current.err != nil { d.stashDecoder() + println("error after:", d.current.err) return false } d.current.b = d.frame.history.b[histBefore:] @@ -444,9 +448,24 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { println("history after:", len(d.frame.history.b)) } + // Check frame size (before CRC) d.syncStream.decodedFrame += uint64(len(d.current.b)) if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } d.current.err = ErrFrameSizeExceeded + d.stashDecoder() + return false + } + + // Check FCS + if d.current.d.Last && d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame != d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeMismatch + d.stashDecoder() return false } @@ -467,6 +486,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { return true } + //ASYNC: if d.current.d != nil { if debugDecoder { printf("re-adding current decoder %p", d.current.d) @@ -770,7 +790,14 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch if !hasErr { decodedFrame += uint64(len(do.b)) if fcs > 0 && decodedFrame > fcs { - d.current.err = ErrFrameSizeExceeded + println("fcs exceeded", block.Last, fcs, decodedFrame) + do.err = ErrFrameSizeExceeded + hasErr = true + } else if block.Last && fcs > 0 && decodedFrame != fcs { + do.err = ErrFrameSizeMismatch + hasErr = true + } else { + println("fcs ok", block.Last, fcs, decodedFrame) } } output <- do @@ -809,7 +836,7 @@ decodeStream: frame.history.setDict(&dict) } } - if d.frame.FrameContentSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { + if d.frame.WindowSize > d.o.maxWindowSize { err = ErrDecoderSizeExceeded } if err != nil { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 0d040f406f..884afc0cd4 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -408,19 +408,20 @@ func TestNewDecoderGood(t *testing.T) { } func TestNewDecoderBad(t *testing.T) { + var errMap = make(map[string]string) if true { t.Run("Reader-4", func(t *testing.T) { newFn := func() (*Decoder, error) { return NewReader(nil, WithDecoderConcurrency(4)) } - testDecoderFileBad(t, "testdata/bad.zip", newFn) + testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap) }) t.Run("Reader-1", func(t *testing.T) { newFn := func() (*Decoder, error) { return NewReader(nil, WithDecoderConcurrency(1)) } - testDecoderFileBad(t, "testdata/bad.zip", newFn) + testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap) }) } t.Run("DecodeAll", func(t *testing.T) { @@ -429,7 +430,7 @@ func TestNewDecoderBad(t *testing.T) { if err != 
nil { t.Fatal(err) } - testDecoderDecodeAllError(t, "testdata/bad.zip", dec) + testDecoderDecodeAllError(t, "testdata/bad.zip", dec, errMap) }) } @@ -1058,7 +1059,7 @@ func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) { } } -func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error)) { +func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error), errMap map[string]string) { data, err := ioutil.ReadFile(fn) if err != nil { t.Fatal(err) @@ -1090,7 +1091,7 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error) if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) { continue } - t.Run("Reader-"+tt.Name, func(t *testing.T) { + t.Run(tt.Name, func(t *testing.T) { r, err := tt.Open() if err != nil { t.Error(err) @@ -1102,11 +1103,21 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error) t.Error(err) return } - _, err = ioutil.ReadAll(dec) + got, err := ioutil.ReadAll(dec) if err == nil { - t.Error("Did not get expected error") + t.Error("Did not get expected error, got ", len(got), "bytes") + return } - t.Log("get error", err) + if errMap[tt.Name] == "" { + errMap[tt.Name] = err.Error() + } else { + want := errMap[tt.Name] + if want != err.Error() { + t.Errorf("error mismatch, prev run got %s, now got %s", want, err.Error()) + } + return + } + t.Log("got error", err) }) } } @@ -1423,7 +1434,7 @@ func testDecoderDecodeAll(t *testing.T, fn string, dec *Decoder) { }() } -func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) { +func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder, errMap map[string]string) { data, err := ioutil.ReadFile(fn) if err != nil { t.Fatal(err) @@ -1454,6 +1465,16 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) { got, err := dec.DecodeAll(in, make([]byte, 0, 20)) if err == nil { t.Error("Did not get expected error, got", len(got), "bytes") + return + } + if errMap[tt.Name] == "" { + t.Error("cannot check error") + } else { + want := errMap[tt.Name] + if want != err.Error() { + t.Errorf("error mismatch, prev run got %s, now got %s", want, err.Error()) + } + return } }) } diff --git a/zstd/framedec.go b/zstd/framedec.go index fc4c25d9c6..51a8773622 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -217,6 +217,9 @@ func (d *frameDec) reset(br byteBuffer) error { d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24) d.FrameContentSize = uint64(d1) | (uint64(d2) << 32) } + if debugDecoder { + println("Read FCS:", d.FrameContentSize) + } } // Move this to shared. 
@@ -333,7 +336,7 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { println("next block:", dec) } err = dec.decodeBuf(&d.history) - if err != nil || dec.Last { + if err != nil { break } if uint64(len(d.history.b)) > d.o.maxDecodedSize { @@ -345,10 +348,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { err = ErrFrameSizeExceeded break } + if d.FrameContentSize > 0 && uint64(len(d.history.b)-crcStart) > d.FrameContentSize { + println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize) + err = ErrFrameSizeExceeded + break + } + if dec.Last { + break + } + if debugDecoder && d.FrameContentSize > 0 { + println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize) + } } dst = d.history.b if err == nil { - if d.HasCheckSum { + if d.FrameContentSize > 0 && uint64(len(d.history.b)-crcStart) != d.FrameContentSize { + err = ErrFrameSizeMismatch + } else if d.HasCheckSum { var n int n, err = d.crc.Write(dst[crcStart:]) if err == nil { diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 54b88c327c..5dcb8dbedc 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -237,7 +237,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { if err != nil { printf("Closing sequences: %v, %+v\n", err, *br) } - return nil + return err } // execute will execute the decoded sequence with the provided history. @@ -528,7 +528,7 @@ func (s *sequenceDecs) decodeSync(history *history) error { // Add final literals s.out = append(s.out, s.literals...) - return nil + return br.close() } // update states, at least 27 bits must be available. diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index de9004186c8f8136553bab386e500c47e27d6f1d..43f82b82345808362d7ceb1a0612190eda711aa3 100644 GIT binary patch delta 784 zcmb8sF-ROi6oBDb_i~#Oww@MHpL1uS$c6K^#(wxN=wm6hQ)hmR0zzztOMz!)kiSZPC+Q7xENQ3MMAxX{v1c?yvUp{brO?~R?+ zYNW-Te4A8$Zl>$K=smu6g|IvbK`E9R>jzrO7y=G~@gtC^8ewAwpquRP4QJ6QW>btl zX{BGhPb%iwxgR1sc1>FaT+slei7X3MBtTH3H04oB6=x~8D#T;@k6v0BaC6^OwNC5D zy|0oB?}x6|R;$(N`uJ|QyFcOPT{rLLogOc@IB?|b3>N-$W==gIm2xh>Dzto&eKY_YjP!DJ^DM!cNh9}J5ej$uE`BO(nZL93YAL_@{QX= MxL)JkO&mi00&|bQuK)l5 delta 33 jcmbQGwoYM#3diOZoZr|*SlJkWKn@7Mu`@6zae#OLl)MH` diff --git a/zstd/testdata/good.zip b/zstd/testdata/good.zip index 1b4c16cf8ec44946ca9f02cae9674e00a8fdda64..177cb825ef1bda2933224c8fc6bedb5d4bd97e4a 100644 GIT binary patch literal 15437 zcmeHOZA@EL7(TZJgrZC~WH3=X#)Qo>ZhN7WX{0hSm1YV`vn<2mrM-npzeZsQ(Mi!Q z)9B)67A0ynvzP!`42fChe)vrc^5GUul>L|=L}L7i%#oQG1M@wc-g|pb@1px)HdLtyRM5F^lEl`MQg%UpaCHH6}RF3dip64m^D)U~yrt6pk(ZvF@@{6-y;WGYGxL zTWvyRmA%?gv!_=w@nW0T#G$6?t%2fC@?~NouH(g`kxf9@kV*{7IwaB8YOB)-EZ;n0 znVGI|Ssr@OeRg8+7Wd@Td~riACHh$9?}K*(FRntUB+)9jBD|r8Lq3B>c~^H| zPaqN%V?KX-#}l2w{wKYA+rl9e&!I!E%V|p$i(9qRlzhN#&Vo+spw?(RhQx4ROxhC@ zoQ9SIu7PYy>rmafM+*RfYpp&~QM&Figv`9jYBHdYj&rj!uKR|yGQM9}D9->A0g5{D zh~FP;wF>6KC`T!nyK8jk}Gz4@%DW=zy@Ck z+S--DB73E-pcoa7-xw`?B;&n|ua}$trbHFbw(OEc-3m-eHz}DymgYm;O9++WBjSzd zDdUMbE~}KY=k?B9kfo7fhb7y-G~mO9?+xMY5_|oTh)KJ(e{9ze*iBqFC2y6RmZYsm z2;<5Q=Tw5RSxpLaBbOIe2tZ9tht?(rsn4JeFB@Odcz>jXGAWGZ0PgLS{PmNAC+I zz+MWa&V={s&Q$E2+A+949*-Y>erRLPgx~P4a?vtTmfzF-^2MCN4Z!7svOrbyyhC#E z4pDM=Yk0p^6ofXP&+e-ft9Zfc@Cn|UO3BJQ^kB#=V18!!LjDN0h%M#-bAUO( z9AFMG2bcrQ0pFb9|e%mL;AbKrV+Ah7CC0A2;$jLz4d5}w1qm|=@Kz#L!>Fb9|e%mL;AbAUO( z9AFMG2bcrQf&bNk+95-un=|C0t4GDKJ`R}^GLup-BsiwRNpgZti4h!z-w%}n$Dsw4 zf&?n)JQ$`R3f|OGs0l8Wf~cTlUzh@Zo&fb1RFVpnLQe^u&{C#CQrYF<}} zAi+iVr}J9`XokmU$uPZ+WGnfL<6C5z;bDN_SR%|Nf02nuI?km+Hp&ofCvw(wq6@>w 
zY&!v_2Da4U2#$_+sqGGGaN3@!e&SRYUZm&e3U(#7r8Xr%I^jj2W_bLm1^!!CIl)A_ z@?bHU6vw#Kresn&@TE3w(crW_^}}5%_9S4hCF+%|N-xQUnaMPTkOTd%%#H4JoEn7; OY5~fDPn+!czkdK6Zb(J| delta 1156 zcmX?GaY}B(V&+s1hV+6tzIu(5qSzT27=(bB3y70)@{_aUa})C_i(44gB^Ve2yqQ@< z7?^-kAm#4UIGaF9L6{Fkxn5OqiN;p_zbeLF3=EFm3=B?}-fhd6Y``itc`I{h48*q1 z1Et0+Ky92*46&^!ClzQJnoar_8D@cO0AXGrhATx_1hm9)e#V2w%}gv?80+=cP50Zr z|5RfiP;n;^YXGrPVxnbAN@9wEk!4zcMVUm%NfvI7lagv3Jp-Gy7p=lap zCqvo)qb*mnTt(!g@-|67nm&2Z*tQhtnFTn7#sq6ypMh`2V1I2{T||FfcG+VW7;LIDKhDqBns_3+Piu z28POj{|h#2n&~l5wl`->Ku?cA|1vN#C;-XX%&) zE%g{LPkv=-#Q1Hptd$v)7VG3tAQj0vx!1~yX%_3`lUC+T$5|&cSz9wPvQ4(PHmlcV zV_--tO3Y1-FV4;^07WOla!A@hPw)uKYk_)+P5$VnLQ`BMJjo%94Fnbz#2O26Fe}g@ P3=A?r_yx#v-~jOeqn}`L diff --git a/zstd/zstd.go b/zstd/zstd.go index ef1d49a009..448c1c6ca7 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -75,6 +75,10 @@ var ( // This is only returned if SingleSegment is specified on the frame. ErrFrameSizeExceeded = errors.New("frame size exceeded") + // ErrFrameSizeMismatch is returned if the stated frame size does not match the expected size. + // This is only returned if SingleSegment is specified on the frame. + ErrFrameSizeMismatch = errors.New("frame size does not match on stream size") + // ErrCRCMismatch is returned if CRC mismatches. ErrCRCMismatch = errors.New("CRC check failed") From d6e790a95588b568b0d02eb29b059fec5a8051ff Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 21 Feb 2022 16:21:57 +0100 Subject: [PATCH 15/31] Fix short test. --- zstd/decoder_test.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 884afc0cd4..1796d56bd2 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1087,10 +1087,7 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error) return } defer dec.Close() - for i, tt := range zr.File { - if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) { - continue - } + for _, tt := range zr.File { t.Run(tt.Name, func(t *testing.T) { r, err := tt.Open() if err != nil { From 8b43a92759fb867db9bac6f9a4064ad68b93ab32 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 10:14:52 +0100 Subject: [PATCH 16/31] Add bench --- zstd/decoder_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 1796d56bd2..7aa21a18ec 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io" + "io/fs" "io/ioutil" "log" "math/rand" @@ -1227,6 +1228,51 @@ func BenchmarkDecoder_DecodeAll(b *testing.B) { } } +func BenchmarkDecoder_DecodeAllFiles(b *testing.B) { + filepath.Walk("../testdata/", func(path string, info fs.FileInfo, err error) error { + if info.IsDir() || info.Size() < 100 { + return nil + } + b.Run(filepath.Base(path), func(b *testing.B) { + raw, err := ioutil.ReadFile(path) + if err != nil { + b.Error(err) + } + for i := SpeedFastest; i <= SpeedBestCompression; i++ { + b.Run(i.String(), func(b *testing.B) { + enc, err := NewWriter(nil, WithEncoderLevel(i), WithSingleSegment(true)) + if err != nil { + b.Error(err) + } + encoded := enc.EncodeAll(raw, nil) + if err != nil { + b.Error(err) + } + dec, err := NewReader(nil, WithDecoderConcurrency(1)) + if err != nil { + b.Error(err) + } + decoded, err := dec.DecodeAll(encoded, nil) + if err != nil { + b.Error(err) + } + b.SetBytes(int64(len(raw))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decoded, err = 
dec.DecodeAll(encoded, decoded[:0]) + if err != nil { + b.Error(err) + } + } + b.ReportMetric(100*float64(len(encoded))/float64(len(raw)), "pct") + }) + } + }) + return nil + }) +} + func BenchmarkDecoder_DecodeAllParallel(b *testing.B) { fn := "testdata/benchdecoder.zip" data, err := ioutil.ReadFile(fn) From 50a135c257802168c34c727ffd2315416f977b41 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 13:51:55 +0100 Subject: [PATCH 17/31] Reject big RLE/RAW blocks as per https://github.com/facebook/zstd/issues/3072 Clean up error returns to cancel when context is cancelled. --- zstd/blockdec.go | 6 ++---- zstd/decoder.go | 31 +++++++++++++++---------------- zstd/testdata/bad.zip | Bin 4890 -> 5089 bytes zstd/testdata/good.zip | Bin 15437 -> 2848 bytes 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/zstd/blockdec.go b/zstd/blockdec.go index c55b64c60a..f4bde47de1 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -145,8 +145,7 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { if debugDecoder { printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b) } - // TODO: Likely enable: - //return ErrWindowSizeExceeded + return ErrWindowSizeExceeded } b.RLESize = uint32(cSize) if b.lowMem { @@ -173,8 +172,7 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { if debugDecoder { printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b) } - // TODO: Likely enable: - //return ErrWindowSizeExceeded + return ErrWindowSizeExceeded } b.RLESize = 0 diff --git a/zstd/decoder.go b/zstd/decoder.go index 3158e516fb..54daab551c 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -817,18 +817,10 @@ decodeStream: var historySent bool frame.history.reset() err := frame.reset(&br) - switch err { - default: - dec := <-d.decoders - dec.sendErr(err) - seqPrepare <- dec - break decodeStream - case nil: - } if debugDecoder && err != nil { println("Frame decoder returned", err) } - if frame.DictionaryID != nil { + if err == nil && frame.DictionaryID != nil { dict, ok := d.dicts[*frame.DictionaryID] if !ok { err = ErrUnknownDictionary @@ -836,25 +828,29 @@ decodeStream: frame.history.setDict(&dict) } } - if d.frame.WindowSize > d.o.maxWindowSize { + if err == nil && d.frame.WindowSize > d.o.maxWindowSize { err = ErrDecoderSizeExceeded } if err != nil { - output <- decodeOutput{ - err: err, + select { + case <-ctx.Done(): + case dec := <-d.decoders: + dec.sendErr(err) + select { + case seqPrepare <- dec: + case <-ctx.Done(): + } } break decodeStream } decodeFrame: // Go through all blocks of the frame. 
for { - dec := <-d.decoders + var dec *blockDec select { case <-ctx.Done(): - dec.sendErr(ctx.Err()) - seqPrepare <- dec break decodeStream - default: + case dec = <-d.decoders: } err := frame.next(dec) if !historySent { @@ -868,6 +864,9 @@ decodeStream: } else { dec.async.newHist = nil } + if debugDecoder && err != nil { + println("next block returned error:", err) + } dec.err = err dec.checkCRC = nil if dec.Last && frame.HasCheckSum && err == nil { diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index 43f82b82345808362d7ceb1a0612190eda711aa3..e8ffc3d76ad2baede9ab8484d14ffd5aede39571 100644 GIT binary patch delta 251 zcmbQG_E3Gp3eI|F77+#}1`YqXbEe+BP z6OE0NQc}!POcG5Ej14VPjFT-*QVk6(^s0(WF5g)7e|lr%>%af@#yPBEJre1l!+Ipp zA&2!ysKc2C@eF1|h5&CyCJ_dZbqoyY1#^7$YFSb?@8a~~%!7$zcLq>75GVlgGEUMQ U72wUv2GXkxggbzGg}FdH03zQ#4FCWD delta 37 ocmaE;K1*%G3eL$BxuiDhbGdMeu(B}#ff5kLaWXJ017ZdS0KXRpBme*a diff --git a/zstd/testdata/good.zip b/zstd/testdata/good.zip index 177cb825ef1bda2933224c8fc6bedb5d4bd97e4a..f6c230bc5d5b3e9d2d4ff61de33f30f411c2bc1a 100644 GIT binary patch delta 33 kcmX?Gu|RCYI@ZnFW^t?{tZWQGAO?hStPBimfS7>+0GI*>H2?qr literal 15437 zcmeHOZA@EL7(TZJgrZC~WH3=X#)Qo>ZhN7WX{0hSm1YV`vn<2mrM-npzeZsQ(Mi!Q z)9B)67A0ynvzP!`42fChe)vrc^5GUul>L|=L}L7i%#oQG1M@wc-g|pb@1px)HdLtyRM5F^lEl`MQg%UpaCHH6}RF3dip64m^D)U~yrt6pk(ZvF@@{6-y;WGYGxL zTWvyRmA%?gv!_=w@nW0T#G$6?t%2fC@?~NouH(g`kxf9@kV*{7IwaB8YOB)-EZ;n0 znVGI|Ssr@OeRg8+7Wd@Td~riACHh$9?}K*(FRntUB+)9jBD|r8Lq3B>c~^H| zPaqN%V?KX-#}l2w{wKYA+rl9e&!I!E%V|p$i(9qRlzhN#&Vo+spw?(RhQx4ROxhC@ zoQ9SIu7PYy>rmafM+*RfYpp&~QM&Figv`9jYBHdYj&rj!uKR|yGQM9}D9->A0g5{D zh~FP;wF>6KC`T!nyK8jk}Gz4@%DW=zy@Ck z+S--DB73E-pcoa7-xw`?B;&n|ua}$trbHFbw(OEc-3m-eHz}DymgYm;O9++WBjSzd zDdUMbE~}KY=k?B9kfo7fhb7y-G~mO9?+xMY5_|oTh)KJ(e{9ze*iBqFC2y6RmZYsm z2;<5Q=Tw5RSxpLaBbOIe2tZ9tht?(rsn4JeFB@Odcz>jXGAWGZ0PgLS{PmNAC+I zz+MWa&V={s&Q$E2+A+949*-Y>erRLPgx~P4a?vtTmfzF-^2MCN4Z!7svOrbyyhC#E z4pDM=Yk0p^6ofXP&+e-ft9Zfc@Cn|UO3BJQ^kB#=V18!!LjDN0h%M#-bAUO( z9AFMG2bcrQ0pFb9|e%mL;AbKrV+Ah7CC0A2;$jLz4d5}w1qm|=@Kz#L!>Fb9|e%mL;AbAUO( z9AFMG2bcrQf&bNk+95-un=|C0t4GDKJ`R}^GLup-BsiwRNpgZti4h!z-w%}n$Dsw4 zf&?n)JQ$`R3f|OGs0l8Wf~cTlUzh@Zo&fb1RFVpnLQe^u&{C#CQrYF<}} zAi+iVr}J9`XokmU$uPZ+WGnfL<6C5z;bDN_SR%|Nf02nuI?km+Hp&ofCvw(wq6@>w zY&!v_2Da4U2#$_+sqGGGaN3@!e&SRYUZm&e3U(#7r8Xr%I^jj2W_bLm1^!!CIl)A_ z@?bHU6vw#Kresn&@TE3w(crW_^}}5%_9S4hCF+%|N-xQUnaMPTkOTd%%#H4JoEn7; OY5~fDPn+!czkdK6Zb(J| From d73139640c70bbc9c0d99494f08eed67f022c435 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 13:54:12 +0100 Subject: [PATCH 18/31] Use os.FileInfo for Go 1.15. 
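
The benchmark introduced two patches back used io/fs.FileInfo in the filepath.Walk callback; the io/fs package only exists from Go 1.16, so the build broke on Go 1.15. filepath.Walk has always taken an os.FileInfo (which is an alias of fs.FileInfo on newer Go), so changing the parameter type keeps the benchmark unchanged while restoring Go 1.15 support. A minimal sketch of the compatible callback shape (the error and size filtering here is illustrative):

    filepath.Walk("../testdata/", func(path string, info os.FileInfo, err error) error {
        if err != nil || info.IsDir() || info.Size() < 100 {
            return err
        }
        // benchmark encoding/decoding of this file...
        return nil
    })
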
--- zstd/decoder_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 7aa21a18ec..61614321ad 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -12,7 +12,6 @@ import ( "errors" "fmt" "io" - "io/fs" "io/ioutil" "log" "math/rand" @@ -1229,7 +1228,7 @@ func BenchmarkDecoder_DecodeAll(b *testing.B) { } func BenchmarkDecoder_DecodeAllFiles(b *testing.B) { - filepath.Walk("../testdata/", func(path string, info fs.FileInfo, err error) error { + filepath.Walk("../testdata/", func(path string, info os.FileInfo, err error) error { if info.IsDir() || info.Size() < 100 { return nil } From 03b301af11042d71164a1fd620f60ca0eec6b8e1 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 14:28:02 +0100 Subject: [PATCH 19/31] Break all on errors --- zstd/decoder.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 54daab551c..4bec887bd0 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -843,7 +843,7 @@ decodeStream: } break decodeStream } - decodeFrame: + // Go through all blocks of the frame. for { var dec *blockDec @@ -884,10 +884,10 @@ decodeStream: } select { case <-ctx.Done(): - break decodeFrame + break decodeStream case seqPrepare <- dec: } - if err != nil { + if dec.err != nil { break decodeStream } if dec.Last { From 4c5a306fffb79ea69748d00afea72d3c3aaec5a4 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 14:29:06 +0100 Subject: [PATCH 20/31] Move sync code to separate method. --- zstd/decoder.go | 172 ++++++++++++++++++++++--------------------- zstd/decoder_test.go | 6 +- 2 files changed, 90 insertions(+), 88 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 4bec887bd0..52313d2b4d 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -398,92 +398,11 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { if !blocking { return false } - if d.current.d == nil { - d.current.d = <-d.decoders + ok = d.nextBlockSync() + if !ok { + d.stashDecoder() } - for len(d.current.b) == 0 { - if !d.syncStream.inFrame { - d.frame.history.reset() - d.current.err = d.frame.reset(&d.syncStream.br) - if d.current.err != nil { - d.stashDecoder() - return false - } - if d.frame.DictionaryID != nil { - dict, ok := d.dicts[*d.frame.DictionaryID] - if !ok { - d.current.err = ErrUnknownDictionary - return false - } else { - d.frame.history.setDict(&dict) - } - } - if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { - d.current.err = ErrDecoderSizeExceeded - return false - } - - d.syncStream.decodedFrame = 0 - d.syncStream.inFrame = true - } - d.current.err = d.frame.next(d.current.d) - if d.current.err != nil { - d.stashDecoder() - return false - } - d.frame.history.ensureBlock() - if debugDecoder { - println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame) - } - histBefore := len(d.frame.history.b) - d.current.err = d.current.d.decodeBuf(&d.frame.history) - - if d.current.err != nil { - d.stashDecoder() - println("error after:", d.current.err) - return false - } - d.current.b = d.frame.history.b[histBefore:] - if debugDecoder { - println("history after:", len(d.frame.history.b)) - } - - // Check frame size (before CRC) - d.syncStream.decodedFrame += uint64(len(d.current.b)) - if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize { - if debugDecoder { - printf("DecodedFrame (%d) > FrameContentSize (%d)\n", 
d.syncStream.decodedFrame, d.frame.FrameContentSize) - } - d.current.err = ErrFrameSizeExceeded - d.stashDecoder() - return false - } - - // Check FCS - if d.current.d.Last && d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame != d.frame.FrameContentSize { - if debugDecoder { - printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) - } - d.current.err = ErrFrameSizeMismatch - d.stashDecoder() - return false - } - - // Update/Check CRC - if d.frame.HasCheckSum { - d.frame.crc.Write(d.current.b) - if d.current.d.Last { - d.current.err = d.frame.checkCRC() - if d.current.err != nil { - println("CRC error:", d.current.err) - d.stashDecoder() - return false - } - } - } - d.syncStream.inFrame = !d.current.d.Last - } - return true + return ok } //ASYNC: @@ -546,6 +465,89 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { return true } +func (d *Decoder) nextBlockSync() (ok bool) { + if d.current.d == nil { + d.current.d = <-d.decoders + } + for len(d.current.b) == 0 { + if !d.syncStream.inFrame { + d.frame.history.reset() + d.current.err = d.frame.reset(&d.syncStream.br) + if d.current.err != nil { + return false + } + if d.frame.DictionaryID != nil { + dict, ok := d.dicts[*d.frame.DictionaryID] + if !ok { + d.current.err = ErrUnknownDictionary + return false + } else { + d.frame.history.setDict(&dict) + } + } + if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { + d.current.err = ErrDecoderSizeExceeded + return false + } + + d.syncStream.decodedFrame = 0 + d.syncStream.inFrame = true + } + d.current.err = d.frame.next(d.current.d) + if d.current.err != nil { + return false + } + d.frame.history.ensureBlock() + if debugDecoder { + println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame) + } + histBefore := len(d.frame.history.b) + d.current.err = d.current.d.decodeBuf(&d.frame.history) + + if d.current.err != nil { + println("error after:", d.current.err) + return false + } + d.current.b = d.frame.history.b[histBefore:] + if debugDecoder { + println("history after:", len(d.frame.history.b)) + } + + // Check frame size (before CRC) + d.syncStream.decodedFrame += uint64(len(d.current.b)) + if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeExceeded + return false + } + + // Check FCS + if d.current.d.Last && d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame != d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeMismatch + return false + } + + // Update/Check CRC + if d.frame.HasCheckSum { + d.frame.crc.Write(d.current.b) + if d.current.d.Last { + d.current.err = d.frame.checkCRC() + if d.current.err != nil { + println("CRC error:", d.current.err) + return false + } + } + } + d.syncStream.inFrame = !d.current.d.Last + } + return true +} + func (d *Decoder) stashDecoder() { if d.current.d != nil { d.decoders <- d.current.d diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 61614321ad..6f377bc34f 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -460,7 +460,7 @@ func TestNewReaderRead(t *testing.T) { } func TestNewDecoderBig(t *testing.T) { - if testing.Short() { + if testing.Short() || 
isRaceTest { t.SkipNow() } file := "testdata/zstd-10kfiles.zip" @@ -480,7 +480,7 @@ func TestNewDecoderBig(t *testing.T) { } func TestNewDecoderBigFile(t *testing.T) { - if testing.Short() { + if testing.Short() || isRaceTest { t.SkipNow() } file := "testdata/enwik9.zst" @@ -1430,7 +1430,7 @@ func testDecoderDecodeAll(t *testing.T, fn string, dec *Decoder) { wg.Add(1) t.Run("DecodeAll-"+tt.Name, func(t *testing.T) { defer wg.Done() - //t.Parallel() + t.Parallel() r, err := tt.Open() if err != nil { t.Fatal(err) From d3974d36af45956df4ec3df5a0b449290bc8aa26 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 14:51:07 +0100 Subject: [PATCH 21/31] Don't read sent error. --- zstd/decoder.go | 3 ++- zstd/decoder_test.go | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 52313d2b4d..693f38c390 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -884,12 +884,13 @@ decodeStream: println("found crc to check:", dec.checkCRC) } } + err = dec.err select { case <-ctx.Done(): break decodeStream case seqPrepare <- dec: } - if dec.err != nil { + if err != nil { break decodeStream } if dec.Last { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 6f377bc34f..98b84f197c 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1683,6 +1683,9 @@ func TestResetNil(t *testing.T) { } func timeout(after time.Duration) (cancel func()) { + if isRaceTest { + return func() {} + } c := time.After(after) cc := make(chan struct{}) go func() { From 133d52c758c35ec9bdbb4030dfd4b4832dfc5fb8 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 17:30:31 +0100 Subject: [PATCH 22/31] Fix consistent error reporting and dict inits. --- zstd/bitreader.go | 4 +++ zstd/decoder.go | 12 ++------ zstd/decoder_test.go | 55 +++++++++++++++++++++++++++++++---- zstd/dict_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++ zstd/history.go | 1 + zstd/seqdec.go | 3 -- zstd/testdata/bad.zip | Bin 5089 -> 5277 bytes 7 files changed, 123 insertions(+), 18 deletions(-) diff --git a/zstd/bitreader.go b/zstd/bitreader.go index 753d17df63..d7cd15ba29 100644 --- a/zstd/bitreader.go +++ b/zstd/bitreader.go @@ -7,6 +7,7 @@ package zstd import ( "encoding/binary" "errors" + "fmt" "io" "math/bits" ) @@ -132,6 +133,9 @@ func (b *bitReader) remain() uint { func (b *bitReader) close() error { // Release reference. b.in = nil + if !b.finished() { + return fmt.Errorf("%d extra bits on block, should be 0", b.remain()) + } if b.bitsRead > 64 { return io.ErrUnexpectedEOF } diff --git a/zstd/decoder.go b/zstd/decoder.go index 693f38c390..4444a43cca 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -838,10 +838,7 @@ decodeStream: case <-ctx.Done(): case dec := <-d.decoders: dec.sendErr(err) - select { - case seqPrepare <- dec: - case <-ctx.Done(): - } + seqPrepare <- dec } break decodeStream } @@ -853,6 +850,7 @@ decodeStream: case <-ctx.Done(): break decodeStream case dec = <-d.decoders: + // Once we have a decoder, we MUST return it. 
} err := frame.next(dec) if !historySent { @@ -885,11 +883,7 @@ decodeStream: } } err = dec.err - select { - case <-ctx.Done(): - break decodeStream - case seqPrepare <- dec: - } + seqPrepare <- dec if err != nil { break decodeStream } diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 98b84f197c..d4ddd212ca 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -202,7 +202,6 @@ func TestErrorWriter(t *testing.T) { func TestNewDecoder(t *testing.T) { for _, n := range []int{1, 4} { t.Run(fmt.Sprintf("cpu-%d", n), func(t *testing.T) { - defer timeout(60 * time.Second)() newFn := func() (*Decoder, error) { return NewReader(nil, WithDecoderConcurrency(n)) } @@ -393,7 +392,6 @@ func TestNewDecoderFrameSize(t *testing.T) { func TestNewDecoderGood(t *testing.T) { for _, n := range []int{1, 4} { t.Run(fmt.Sprintf("cpu-%d", n), func(t *testing.T) { - defer timeout(30 * time.Second)() newFn := func() (*Decoder, error) { return NewReader(nil, WithDecoderConcurrency(n)) } @@ -1017,24 +1015,48 @@ func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) { continue } t.Run("Reader-"+tt.Name, func(t *testing.T) { + defer timeout(10 * time.Second)() r, err := tt.Open() if err != nil { t.Error(err) return } - defer r.Close() - err = dec.Reset(r) + data, err := ioutil.ReadAll(r) + r.Close() if err != nil { t.Error(err) return } - got, err := ioutil.ReadAll(dec) + err = dec.Reset(ioutil.NopCloser(bytes.NewBuffer(data))) + if err != nil { + t.Error(err) + return + } + var got []byte + var gotError error + var wg sync.WaitGroup + wg.Add(1) + go func() { + got, gotError = ioutil.ReadAll(dec) + wg.Done() + }() + + // This decode should not interfere with the stream... + gotDecAll, err := dec.DecodeAll(data, nil) if err != nil { t.Error(err) if err != ErrCRCMismatch { return } } + wg.Wait() + if gotError != nil { + t.Error(err) + if err != ErrCRCMismatch { + return + } + } + wantB := want[tt.Name] if !bytes.Equal(wantB, got) { if len(wantB)+len(got) < 1000 { @@ -1054,6 +1076,23 @@ func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) { t.Error("Output mismatch") return } + if !bytes.Equal(wantB, gotDecAll) { + if len(wantB)+len(got) < 1000 { + t.Logf(" got: %v\nwant: %v", got, wantB) + } else { + fileName, _ := filepath.Abs(filepath.Join("testdata", t.Name()+"-want.bin")) + _ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm) + err := ioutil.WriteFile(fileName, wantB, os.ModePerm) + t.Log("Wrote file", fileName, err) + + fileName, _ = filepath.Abs(filepath.Join("testdata", t.Name()+"-got.bin")) + _ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm) + err = ioutil.WriteFile(fileName, got, os.ModePerm) + t.Log("Wrote file", fileName, err) + } + t.Logf("Length, want: %d, got: %d", len(wantB), len(got)) + t.Error("DecodeAll Output mismatch") + } t.Log(len(got), "bytes returned, matches input, ok!") }) } @@ -1102,7 +1141,11 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error) } got, err := ioutil.ReadAll(dec) if err == nil { - t.Error("Did not get expected error, got ", len(got), "bytes") + want := errMap[tt.Name] + if want == "" { + want = "" + } + t.Error("Did not get expected error", want, "- got ", len(got), "bytes") return } if errMap[tt.Name] == "" { diff --git a/zstd/dict_test.go b/zstd/dict_test.go index bbc6b8d343..424107052a 100644 --- a/zstd/dict_test.go +++ b/zstd/dict_test.go @@ -411,3 +411,69 @@ func TestDecoder_MoreDicts(t *testing.T) { }) } } + +func TestDecoder_MoreDicts2(t *testing.T) { + // All files 
have CRC + // https://files.klauspost.com/compress/zstd-dict-tests.zip + fn := "testdata/zstd-dict-tests.zip" + data, err := ioutil.ReadFile(fn) + if err != nil { + t.Skip("extended dict test not found.") + } + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + t.Fatal(err) + } + + var dicts [][]byte + for _, tt := range zr.File { + if !strings.HasSuffix(tt.Name, ".dict") { + continue + } + func() { + r, err := tt.Open() + if err != nil { + t.Fatal(err) + } + defer r.Close() + in, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + dicts = append(dicts, in) + }() + } + dec, err := NewReader(nil, WithDecoderConcurrency(2), WithDecoderDicts(dicts...)) + if err != nil { + t.Fatal(err) + return + } + defer dec.Close() + for i, tt := range zr.File { + if !strings.HasSuffix(tt.Name, ".zst") { + continue + } + if testing.Short() && i > 50 { + continue + } + t.Run("decodeall-"+tt.Name, func(t *testing.T) { + r, err := tt.Open() + if err != nil { + t.Fatal(err) + } + defer r.Close() + in, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + got, err := dec.DecodeAll(in, nil) + if err != nil { + t.Fatal(err) + } + _, err = dec.DecodeAll(in, got[:0]) + if err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/zstd/history.go b/zstd/history.go index c96687aada..28b40153cc 100644 --- a/zstd/history.go +++ b/zstd/history.go @@ -65,6 +65,7 @@ func (h *history) setDict(dict *dict) { h.decoders.litLengths = dict.llDec h.decoders.offsets = dict.ofDec h.decoders.matchLengths = dict.mlDec + h.decoders.dict = dict.content h.recentOffsets = dict.offsets h.huffTree = dict.litEnc } diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 5dcb8dbedc..9c0fa7eddc 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -230,9 +230,6 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { if s.seqSize > maxBlockSize { return fmt.Errorf("output (%d) bigger than max block size", s.seqSize) } - if !br.finished() { - return fmt.Errorf("%d extra bits on block, should be 0", br.remain()) - } err := br.close() if err != nil { printf("Closing sequences: %v, %+v\n", err, *br) diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index e8ffc3d76ad2baede9ab8484d14ffd5aede39571..fdefd6d2140ce4db371fa34464bcdee168bc5662 100644 GIT binary patch delta 224 zcmaE;K38*tC|3eAiwFY~0|$dxM_5P*0|Rr*RErd23j=c#i&V2@12ao=y{h7p%Qsg2pMG$`^hU?S3@kGw8;&S@`<1#G xGfeL1lG&`!HIFlokx7IBw_QN%fItC=g}F$xIl!Bh4Ww5E2!()pOM#ey0RUC3Fx~(F delta 37 ocmbQM`A~g>DA(lATvD4Sa?RruVP#_g0%aiF!O6fN%mv~B0Ll~xfB*mh From 2b89e67f8036f324b4165e7152b9588079a227c1 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Feb 2022 17:50:22 +0100 Subject: [PATCH 23/31] Fix error message --- zstd/seqdec.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 9c0fa7eddc..e367281465 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -194,7 +194,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { } litRemain -= ll if litRemain < 0 { - return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain) + return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain+ll) } seqs[i] = seqVals{ ll: ll, From 606373a741be6c241de520b364bf85e84be64dc4 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 23 Feb 2022 13:15:05 +0100 Subject: [PATCH 24/31] Check if huff0 X4 blocks match size exactly. 
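
A four-stream literal block splits the regenerated output into four nearly equal slices, so every stream must decode to exactly its share of the destination; a stream that ends early or runs long is corruption even when the total looks plausible. A minimal sketch of the per-stream tail check added here (remainBytes, dstEvery and the error text as in Decompress4X):

    remainBytes := dstEvery - (decoded / 4)
    for i := range br {
        offset := dstEvery * i
        endsAt := offset + remainBytes
        if endsAt > len(out) {
            endsAt = len(out)
        }
        // ... decode the remaining symbols of stream i into out[offset:endsAt] ...
        if offset != endsAt {
            return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
        }
        // close() now also rejects streams with leftover bits.
        if err := br[i].close(); err != nil {
            return nil, err
        }
    }
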
--- huff0/bitreader.go | 121 ++++++------------------------------------ huff0/decompress.go | 63 +++++++++++++--------- zstd/blockdec.go | 7 ++- zstd/decoder.go | 2 +- zstd/framedec.go | 2 +- zstd/fuzz.go | 11 ++++ zstd/fuzz_none.go | 11 ++++ zstd/testdata/bad.zip | Bin 5277 -> 6130 bytes 8 files changed, 83 insertions(+), 134 deletions(-) create mode 100644 zstd/fuzz.go create mode 100644 zstd/fuzz_none.go diff --git a/huff0/bitreader.go b/huff0/bitreader.go index a4979e8868..03562db16f 100644 --- a/huff0/bitreader.go +++ b/huff0/bitreader.go @@ -8,115 +8,10 @@ package huff0 import ( "encoding/binary" "errors" + "fmt" "io" ) -// bitReader reads a bitstream in reverse. -// The last set bit indicates the start of the stream and is used -// for aligning the input. -type bitReader struct { - in []byte - off uint // next byte to read is at in[off - 1] - value uint64 - bitsRead uint8 -} - -// init initializes and resets the bit reader. -func (b *bitReader) init(in []byte) error { - if len(in) < 1 { - return errors.New("corrupt stream: too short") - } - b.in = in - b.off = uint(len(in)) - // The highest bit of the last byte indicates where to start - v := in[len(in)-1] - if v == 0 { - return errors.New("corrupt stream, did not find end of stream") - } - b.bitsRead = 64 - b.value = 0 - if len(in) >= 8 { - b.fillFastStart() - } else { - b.fill() - b.fill() - } - b.bitsRead += 8 - uint8(highBit32(uint32(v))) - return nil -} - -// peekBitsFast requires that at least one bit is requested every time. -// There are no checks if the buffer is filled. -func (b *bitReader) peekBitsFast(n uint8) uint16 { - const regMask = 64 - 1 - v := uint16((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask)) - return v -} - -// fillFast() will make sure at least 32 bits are available. -// There must be at least 4 bytes available. -func (b *bitReader) fillFast() { - if b.bitsRead < 32 { - return - } - - // 2 bounds checks. - v := b.in[b.off-4 : b.off] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - b.value = (b.value << 32) | uint64(low) - b.bitsRead -= 32 - b.off -= 4 -} - -func (b *bitReader) advance(n uint8) { - b.bitsRead += n -} - -// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read. -func (b *bitReader) fillFastStart() { - // Do single re-slice to avoid bounds checks. - b.value = binary.LittleEndian.Uint64(b.in[b.off-8:]) - b.bitsRead = 0 - b.off -= 8 -} - -// fill() will make sure at least 32 bits are available. -func (b *bitReader) fill() { - if b.bitsRead < 32 { - return - } - if b.off > 4 { - v := b.in[b.off-4:] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - b.value = (b.value << 32) | uint64(low) - b.bitsRead -= 32 - b.off -= 4 - return - } - for b.off > 0 { - b.value = (b.value << 8) | uint64(b.in[b.off-1]) - b.bitsRead -= 8 - b.off-- - } -} - -// finished returns true if all bits have been read from the bit stream. -func (b *bitReader) finished() bool { - return b.off == 0 && b.bitsRead >= 64 -} - -// close the bitstream and returns an error if out-of-buffer reads occurred. -func (b *bitReader) close() error { - // Release reference. - b.in = nil - if b.bitsRead > 64 { - return io.ErrUnexpectedEOF - } - return nil -} - // bitReader reads a bitstream in reverse. // The last set bit indicates the start of the stream and is used // for aligning the input. 
@@ -213,10 +108,17 @@ func (b *bitReaderBytes) finished() bool { return b.off == 0 && b.bitsRead >= 64 } +func (b *bitReaderBytes) remaining() uint { + return b.off*8 + uint(64-b.bitsRead) +} + // close the bitstream and returns an error if out-of-buffer reads occurred. func (b *bitReaderBytes) close() error { // Release reference. b.in = nil + if b.remaining() > 0 { + return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining()) + } if b.bitsRead > 64 { return io.ErrUnexpectedEOF } @@ -318,10 +220,17 @@ func (b *bitReaderShifted) finished() bool { return b.off == 0 && b.bitsRead >= 64 } +func (b *bitReaderShifted) remaining() uint { + return b.off*8 + uint(64-b.bitsRead) +} + // close the bitstream and returns an error if out-of-buffer reads occurred. func (b *bitReaderShifted) close() error { // Release reference. b.in = nil + if b.remaining() > 0 { + return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining()) + } if b.bitsRead > 64 { return io.ErrUnexpectedEOF } diff --git a/huff0/decompress.go b/huff0/decompress.go index 2668b64d37..3ae7d46771 100644 --- a/huff0/decompress.go +++ b/huff0/decompress.go @@ -741,6 +741,7 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { } var br [4]bitReaderShifted + // Decode "jump table" start := 6 for i := 0; i < 3; i++ { length := int(src[i*2]) | (int(src[i*2+1]) << 8) @@ -865,30 +866,18 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { } // Decode remaining. + remainBytes := dstEvery - (decoded / 4) for i := range br { offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } br := &br[i] - bitsLeft := br.off*8 + uint(64-br.bitsRead) + bitsLeft := br.remaining() for bitsLeft > 0 { br.fill() - if false && br.bitsRead >= 32 { - if br.off >= 4 { - v := br.in[br.off-4:] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - br.value = (br.value << 32) | uint64(low) - br.bitsRead -= 32 - br.off -= 4 - } else { - for br.off > 0 { - br.value = (br.value << 8) | uint64(br.in[br.off-1]) - br.bitsRead -= 8 - br.off-- - } - } - } - // end inline... - if offset >= len(out) { + if offset >= endsAt { d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 4") } @@ -902,6 +891,10 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { out[offset] = uint8(v >> 8) offset++ } + if offset != endsAt { + d.bufs.Put(buf) + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } decoded += offset - dstEvery*i err = br.close() if err != nil { @@ -1091,10 +1084,16 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) { } // Decode remaining. + // Decode remaining. + remainBytes := dstEvery - (decoded / 4) for i := range br { offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } br := &br[i] - bitsLeft := int(br.off*8) + int(64-br.bitsRead) + bitsLeft := br.remaining() for bitsLeft > 0 { if br.finished() { d.bufs.Put(buf) @@ -1117,7 +1116,7 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) { } } // end inline... 
- if offset >= len(out) { + if offset >= endsAt { d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 4") } @@ -1126,10 +1125,14 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) { v := single[uint8(br.value>>shift)].entry nBits := uint8(v) br.advance(nBits) - bitsLeft -= int(nBits) + bitsLeft -= uint(nBits) out[offset] = uint8(v >> 8) offset++ } + if offset != endsAt { + d.bufs.Put(buf) + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } decoded += offset - dstEvery*i err = br.close() if err != nil { @@ -1315,10 +1318,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { } // Decode remaining. + remainBytes := dstEvery - (decoded / 4) for i := range br { offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } br := &br[i] - bitsLeft := int(br.off*8) + int(64-br.bitsRead) + bitsLeft := br.remaining() for bitsLeft > 0 { if br.finished() { d.bufs.Put(buf) @@ -1341,7 +1349,7 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { } } // end inline... - if offset >= len(out) { + if offset >= endsAt { d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 4") } @@ -1350,10 +1358,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { v := single[br.peekByteFast()].entry nBits := uint8(v) br.advance(nBits) - bitsLeft -= int(nBits) + bitsLeft -= uint(nBits) out[offset] = uint8(v >> 8) offset++ } + if offset != endsAt { + d.bufs.Put(buf) + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } + decoded += offset - dstEvery*i err = br.close() if err != nil { diff --git a/zstd/blockdec.go b/zstd/blockdec.go index f4bde47de1..e5a38d1408 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -335,6 +335,9 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err if debugDecoder { println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams) } + if litRegenSize > int(b.WindowSize) || litRegenSize > maxCompressedBlockSize { + return in, ErrWindowSizeExceeded + } switch litType { case literalsBlockRaw: @@ -396,6 +399,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err } var err error // Use our out buffer. + huff.MaxDecodedSize = maxCompressedBlockSize if fourStreams { literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) } else { @@ -422,7 +426,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err if b.lowMem { b.literalBuf = make([]byte, 0, litRegenSize) } else { - b.literalBuf = make([]byte, 0, maxCompressedLiteralSize) + b.literalBuf = make([]byte, 0, maxCompressedBlockSize) } } huff := hist.huffTree @@ -439,6 +443,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err return in, err } hist.huffTree = huff + huff.MaxDecodedSize = maxCompressedBlockSize // Use our out buffer. 
if fourStreams { literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) diff --git a/zstd/decoder.go b/zstd/decoder.go index 4444a43cca..9cbab94d8b 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -450,7 +450,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { got := d.current.crc.Sum64() var tmp [4]byte binary.LittleEndian.PutUint32(tmp[:], uint32(got)) - if !bytes.Equal(tmp[:], next.d.checkCRC) { + if !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC { if debugDecoder { println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)") } diff --git a/zstd/framedec.go b/zstd/framedec.go index 51a8773622..29c3176b05 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -305,7 +305,7 @@ func (d *frameDec) checkCRC() error { return err } - if !bytes.Equal(tmp[:], want) { + if !bytes.Equal(tmp[:], want) && !ignoreCRC { if debugDecoder { println("CRC Check Failed:", tmp[:], "!=", want) } diff --git a/zstd/fuzz.go b/zstd/fuzz.go new file mode 100644 index 0000000000..fda8a74228 --- /dev/null +++ b/zstd/fuzz.go @@ -0,0 +1,11 @@ +//go:build gofuzz +// +build gofuzz + +// Copyright 2019+ Klaus Post. All rights reserved. +// License information can be found in the LICENSE file. +// Based on work by Yann Collet, released under BSD License. + +package zstd + +// ignoreCRC can be used for fuzz testing to ignore CRC values... +const ignoreCRC = true diff --git a/zstd/fuzz_none.go b/zstd/fuzz_none.go new file mode 100644 index 0000000000..0515b201cc --- /dev/null +++ b/zstd/fuzz_none.go @@ -0,0 +1,11 @@ +//go:build !gofuzz +// +build !gofuzz + +// Copyright 2019+ Klaus Post. All rights reserved. +// License information can be found in the LICENSE file. +// Based on work by Yann Collet, released under BSD License. + +package zstd + +// ignoreCRC can be used for fuzz testing to ignore CRC values... +const ignoreCRC = false diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index fdefd6d2140ce4db371fa34464bcdee168bc5662..5d96cdd98fc8b3dd4fcaeab3f313fa2dde96dba9 100644 GIT binary patch delta 895 zcmbQM`AL66IafV1iwFY~0|!I?maq_&obPuxGBGf$0b(5n1_ska149$@loYdM6Jv`+ z3k$=P#8k_~Br`*U#6)v*izHLiG!vkRp5XhqWiBpU6Jx zH^+0up;OCS~0g_y7N+oeo^$O4Og8 z8>w4s{)lhW$NxL__B8-)j^+!K)E2!L*qLxpGNzYp@}h+f$ICTbbdSDw%zo&)$hkL^ zq5W0C+@e#`o;mdfD=w;}2u+=slxm@H=EM}M1^km#@2Tdk_;T{egP5q*1&v%@uMSHz zYc87c?Za$6)qUbGWiB6i$u&>$Sp>J}thwfTT|1jMzDaHqmYl&SUi+v1&LN{kORO9Y zUX{MX!obhRAj2TZAO(b*f6e}qAMUT7Csg%8D@a$WJ^PQ)`xK{oS10W?=S++x%4Y8W zddYeXXYqdLR56+Ned-ERnZ%QpyI!7~bEtRw-0=1_3oH+3tbXR8*Qnjib~{ll;p@4T z!c~gTiz7ZH@U?y7oqRy~hZXbL3tL6FIvSdU+&6}b+>)EX$10-#!B#oukyjMU?u!L~ zuVvdrR2`V8BjE6WnW5NpO1*sEEA>6|EamKedtb=ux#6w5o9;Cxg%=-o{UP?*a)!F(Y{TclpAXsT zuH%}2!AeY_%KEj2wNBWrSxqy)@Ri1oECU1(vk#(3qOdqxL*Il3=7F12~~r>y+O zyj0EU%S*Pu^+~%H)HnN4NYUd1`m&6?3mTY&eN4;>Gj}ozC^!l&sTJ0_biwsStGR2g zWAeqa3yn_{1%-tgr{%LIi@p(FDaxZ2F!>Xg^yZ0NKRELknM8n2!Ic$2DVl*n0f Date: Wed, 23 Feb 2022 14:51:07 +0100 Subject: [PATCH 25/31] Fix decoder leakage. 
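
The async pipeline owns a fixed pool of blockDec objects handed around via the d.decoders channel. When a stage hits an error it keeps draining its input channel, but the blocks read after the error were dropped, so each failing stream could leak decoders from the pool and eventually stall later decodes waiting on <-d.decoders. The drain path now hands every block back, a pattern like this sketch (channel names as in startStreamDecoder):

    for block := range seqPrepare {
        if hasErr {
            if block != nil {
                d.decoders <- block // return it to the pool instead of dropping it
            }
            continue
        }
        // ... normal prepare work ...
    }
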
--- zstd/decoder.go | 18 ++++++++++-------- zstd/decoder_test.go | 1 + 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 9cbab94d8b..00f6ddd646 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -406,14 +406,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) { } //ASYNC: - if d.current.d != nil { - if debugDecoder { - printf("re-adding current decoder %p", d.current.d) - } - d.decoders <- d.current.d - d.current.d = nil - } - + d.stashDecoder() if blocking { d.current.decodeOutput, ok = <-d.current.output } else { @@ -550,6 +543,9 @@ func (d *Decoder) nextBlockSync() (ok bool) { func (d *Decoder) stashDecoder() { if d.current.d != nil { + if debugDecoder { + printf("re-adding current decoder %p", d.current.d) + } d.decoders <- d.current.d d.current.d = nil } @@ -645,6 +641,9 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch var hasErr bool for block := range seqPrepare { if hasErr { + if block != nil { + d.decoders <- block + } continue } if block.async.newHist != nil { @@ -682,6 +681,9 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch for block := range seqDecode { if hasErr { + if block != nil { + d.decoders <- block + } continue } if block.async.newHist != nil { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index d4ddd212ca..946e31ec35 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1128,6 +1128,7 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error) defer dec.Close() for _, tt := range zr.File { t.Run(tt.Name, func(t *testing.T) { + defer timeout(10 * time.Second)() r, err := tt.Open() if err != nil { t.Error(err) From d1f91e2e0ff39a6673f06c0bdfb8c5fb014d92d8 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 23 Feb 2022 14:54:40 +0100 Subject: [PATCH 26/31] Forward blocks, so we don't risk run-away decoding. --- zstd/decoder.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 00f6ddd646..d164ead4aa 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -642,7 +642,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch for block := range seqPrepare { if hasErr { if block != nil { - d.decoders <- block + seqDecode <- block } continue } @@ -682,7 +682,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch for block := range seqDecode { if hasErr { if block != nil { - d.decoders <- block + seqExecute <- block } continue } From e7906eb2df0c6457df8b61702a744497a9001f48 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 23 Feb 2022 15:28:25 +0100 Subject: [PATCH 27/31] Fix test race --- zstd/decoder.go | 2 +- zstd/decoder_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index d164ead4aa..9d1ff3d592 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -299,7 +299,7 @@ func (d *Decoder) WriteTo(w io.Writer) (int64, error) { // DecodeAll can be used concurrently. // The Decoder concurrency limits will be respected. 
func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { - if d.current.err == ErrDecoderClosed { + if d.decoders == nil { return dst, ErrDecoderClosed } diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 946e31ec35..9bba2102a7 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1046,6 +1046,7 @@ func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) { if err != nil { t.Error(err) if err != ErrCRCMismatch { + wg.Wait() return } } From 7a09f3db5880c8d679903ab058e935b4b24f2059 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 23 Feb 2022 15:46:44 +0100 Subject: [PATCH 28/31] Save last before sending. --- zstd/decoder.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 9d1ff3d592..427ca0a01c 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -885,11 +885,12 @@ decodeStream: } } err = dec.err + last := dec.Last seqPrepare <- dec if err != nil { break decodeStream } - if dec.Last { + if last { break } } From f5e0961a0b8a3f562d8d790b7612c92ae4e6fa8c Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 23 Feb 2022 17:13:31 +0100 Subject: [PATCH 29/31] Reuse history between async calls. --- zstd/decoder.go | 20 ++++++++++++++-- zstd/decoder_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 427ca0a01c..e167bc720d 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -635,6 +635,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch var seqPrepare = make(chan *blockDec, d.o.concurrent) var seqDecode = make(chan *blockDec, d.o.concurrent) var seqExecute = make(chan *blockDec, d.o.concurrent) + // Async 1: Prepare blocks... go func() { var hist history @@ -674,6 +675,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } close(seqDecode) }() + // Async 2: Decode sequences... go func() { var hist history @@ -721,7 +723,12 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } close(seqExecute) }() + + var wg sync.WaitGroup + wg.Add(1) + // Async 3: Execute sequences... 
+ frameHistCache := d.frame.history.b go func() { var hist history var decodedFrame uint64 @@ -743,9 +750,14 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch if block.async.newHist.dict != nil { hist.setDict(block.async.newHist.dict) } + if cap(hist.b) < hist.allocFrameBuffer { - hist.b = make([]byte, 0, hist.allocFrameBuffer) - println("Alloc history sized", hist.allocFrameBuffer) + if cap(frameHistCache) >= hist.allocFrameBuffer { + hist.b = frameHistCache + } else { + hist.b = make([]byte, 0, hist.allocFrameBuffer) + println("Alloc history sized", hist.allocFrameBuffer) + } } hist.b = hist.b[:0] fcs = block.async.fcs @@ -807,6 +819,8 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch output <- do } close(output) + frameHistCache = hist.b + wg.Done() if debugDecoder { println("decoder goroutines finished") } @@ -896,4 +910,6 @@ decodeStream: } } close(seqPrepare) + wg.Wait() + d.frame.history.b = frameHistCache } diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 9bba2102a7..5eeebada6c 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1283,6 +1283,9 @@ func BenchmarkDecoder_DecodeAllFiles(b *testing.B) { b.Error(err) } for i := SpeedFastest; i <= SpeedBestCompression; i++ { + if testing.Short() && i > SpeedFastest { + break + } b.Run(i.String(), func(b *testing.B) { enc, err := NewWriter(nil, WithEncoderLevel(i), WithSingleSegment(true)) if err != nil { @@ -1317,6 +1320,59 @@ func BenchmarkDecoder_DecodeAllFiles(b *testing.B) { }) } +func BenchmarkDecoder_DecodeAllFilesP(b *testing.B) { + filepath.Walk("../testdata/", func(path string, info os.FileInfo, err error) error { + if info.IsDir() || info.Size() < 100 { + return nil + } + b.Run(filepath.Base(path), func(b *testing.B) { + raw, err := ioutil.ReadFile(path) + if err != nil { + b.Error(err) + } + for i := SpeedFastest; i <= SpeedBestCompression; i++ { + if testing.Short() && i > SpeedFastest { + break + } + b.Run(i.String(), func(b *testing.B) { + enc, err := NewWriter(nil, WithEncoderLevel(i), WithSingleSegment(true)) + if err != nil { + b.Error(err) + } + encoded := enc.EncodeAll(raw, nil) + if err != nil { + b.Error(err) + } + dec, err := NewReader(nil, WithDecoderConcurrency(0)) + if err != nil { + b.Error(err) + } + _, err = dec.DecodeAll(encoded, nil) + if err != nil { + b.Error(err) + } + + b.SetBytes(int64(len(raw))) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + buf := make([]byte, len(raw)) + var err error + for pb.Next() { + buf, err = dec.DecodeAll(encoded, buf[:0]) + if err != nil { + b.Error(err) + } + } + }) + b.ReportMetric(100*float64(len(encoded))/float64(len(raw)), "pct") + }) + } + }) + return nil + }) +} + func BenchmarkDecoder_DecodeAllParallel(b *testing.B) { fn := "testdata/benchdecoder.zip" data, err := ioutil.ReadFile(fn) From 6cabc282032a273730b6bb34f4d1bbc5830c0857 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 23 Feb 2022 17:20:11 +0100 Subject: [PATCH 30/31] Protect local frame. --- zstd/decoder.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index e167bc720d..310edbec25 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -209,9 +209,6 @@ func (d *Decoder) Reset(r io.Reader) error { } return nil } - if d.frame == nil { - d.frame = newFrameDec(d.o) - } // Remove current block. 
d.stashDecoder() d.current.decodeOutput = decodeOutput{} @@ -219,14 +216,20 @@ func (d *Decoder) Reset(r io.Reader) error { d.current.flushed = false d.current.d = nil + // Ensure no-one else is still running... + d.streamWg.Wait() + if d.frame == nil { + d.frame = newFrameDec(d.o) + } + if d.o.concurrent == 1 { return d.startSyncDecoder(r) } d.current.output = make(chan decodeOutput, d.o.concurrent) - d.streamWg.Add(1) ctx, cancel := context.WithCancel(context.Background()) d.current.cancel = cancel + d.streamWg.Add(1) go d.startStreamDecoder(ctx, r, d.current.output) return nil From f20d5639ccc918953d8bf0a586d437ecd610be13 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 24 Feb 2022 11:46:26 +0100 Subject: [PATCH 31/31] Clarify error msg --- zstd/testdata/bad.zip | Bin 6130 -> 6116 bytes zstd/zstd.go | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index 5d96cdd98fc8b3dd4fcaeab3f313fa2dde96dba9..2c283f197e6a1a9baf975fd49e65d32b1ca1b567 100644 GIT binary patch delta 497 zcmeyQ|3rTS8yhzRgZR4Xe%bbKW{OT0WHUF^*sA|G&VYfzz<`5+VJ#yAgQHGjRDd@# ziwFY~0|&!buh0+;myMB13=9k=K&-!a!-r&@Ewz-O(l@aJZ z(Uj1TscPxRxF>gUm~(^F88R3s8yFaDKF6^hq#%DwScpo__q%H*cW{|+KEgGD1I1VI z!ja+#=K|df1`0qjcJg^)OGPAC0>xN>SOUmDPjpU0PbF(m=F-dR55^6 ziL+0(VAq?xNkogWc=BD5NTy@#lV1V_(nPfwr6*4m1=$1l2~aoCiD)JiOnxmI$#j%s z@+%I#$t7Z1jPjG`i$yZdo%~ZQl8K9JvObrdPJlNflL!MSDv-PlWp={ delta 488 zcmaE&|4DxX8yhzVgZR4Xe%bbKW{OW1WHZ;gd}Gyr^K%W22M;tfHzw*TGdC~xOA=uS z@MdNaVPIn5VEF148lvH{F;aPQ0h_?b8898!^#2!w=&{JP{G#y0s5o4F*& zL~$h94+)_m{X(luL?&yon@fOXx4Rzuwa(|!!?ucJ(sQiNgVj!6B_cKXl8Daaz3jV1 zI2c4zLPMsir61#EU|`?|VjYIb_eAx0LAtL`IOuTmzySt^%||)bGXg!7za=b0CFlFy zjg#xS%vl-NFa`hJyp3xD2Z~<nB% z3<8V{3PAGAWN8sgpos|I#EXbB%1ka5iDayrd>%-iovbMu$)vPP!j>!)>^d_f?X)$U{o+uW{xNP!kASpcARUG6txEGli7}fwW T%xw@ON+-`3w`VH=g%tw;-MNV5 diff --git a/zstd/zstd.go b/zstd/zstd.go index 448c1c6ca7..0b0c2571dd 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -77,7 +77,7 @@ var ( // ErrFrameSizeMismatch is returned if the stated frame size does not match the expected size. // This is only returned if SingleSegment is specified on the frame. - ErrFrameSizeMismatch = errors.New("frame size does not match on stream size") + ErrFrameSizeMismatch = errors.New("frame size does not match size on stream") // ErrCRCMismatch is returned if CRC mismatches. ErrCRCMismatch = errors.New("CRC check failed")
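
For context on how the gofuzz-tagged files added earlier in this series are meant to be used: building with -tags gofuzz compiles in ignoreCRC = true, so a fuzzer can reach the code paths behind the checksum without every mutated input being rejected as a CRC mismatch. The harness below is only an illustrative sketch, not part of this patch set; the package name, Fuzz entry point, and chosen options are assumptions, and it relies solely on the public Decoder API.

    //go:build gofuzz
    // +build gofuzz

    // Package zstdfuzz is a hypothetical go-fuzz harness. go-fuzz-build passes the
    // gofuzz tag to the whole build, so the zstd package is compiled with
    // ignoreCRC = true and damaged checksums no longer abort decoding.
    package zstdfuzz

    import "github.com/klauspost/compress/zstd"

    var fuzzDec *zstd.Decoder

    func init() {
        var err error
        // Single-goroutine decoding keeps the run deterministic; the memory cap
        // limits allocations on hostile inputs.
        fuzzDec, err = zstd.NewReader(nil,
            zstd.WithDecoderConcurrency(1),
            zstd.WithDecoderMaxMemory(1<<26),
        )
        if err != nil {
            panic(err)
        }
    }

    // Fuzz decodes arbitrary bytes; invalid input must return an error rather
    // than panic or hang.
    func Fuzz(data []byte) int {
        if _, err := fuzzDec.DecodeAll(data, nil); err != nil {
            return 0
        }
        return 1
    }

The same decoder instance is reused across calls, which is how DecodeAll is intended to be used and also exercises the buffer-reuse paths touched by this series.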