diff --git a/zstd/blockdec.go b/zstd/blockdec.go index ef7aab25fd..50abf56569 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -525,6 +525,9 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { } if nSeqs == 0 && len(in) != 0 { // When no sequences, there should not be any more data... + if debugDecoder { + printf("prepareSequences: 0 sequences, but %d byte(s) left on stream\n", len(in)) + } return ErrUnexpectedBlockSize } @@ -645,6 +648,7 @@ func (b *blockDec) decodeSequences(hist *history) error { hist.decoders.seqSize = len(hist.decoders.literals) return nil } + hist.decoders.windowSize = hist.windowSize hist.decoders.prevOffset = hist.recentOffsets err := hist.decoders.decode(b.sequence) hist.recentOffsets = hist.decoders.prevOffset diff --git a/zstd/decoder.go b/zstd/decoder.go index 0791a88dcc..9fcdaac1dc 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -700,6 +700,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } hist.decoders = block.async.newHist.decoders hist.recentOffsets = block.async.newHist.recentOffsets + hist.windowSize = block.async.newHist.windowSize if block.async.newHist.dict != nil { hist.setDict(block.async.newHist.dict) } diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 213736ad77..819f1461b7 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -107,7 +107,10 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state s.seqSize = 0 litRemain := len(s.literals) - + maxBlockSize := maxCompressedBlockSize + if s.windowSize < maxBlockSize { + maxBlockSize = s.windowSize + } for i := range seqs { var ll, mo, ml int if br.off > 4+((maxOffsetBits+16+16)>>3) { @@ -192,7 +195,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { } s.seqSize += ll + ml if s.seqSize > maxBlockSize { - return fmt.Errorf("output (%d) bigger than max block size", s.seqSize) + return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize) } litRemain -= ll if litRemain < 0 { @@ -230,7 +233,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error { } s.seqSize += litRemain if s.seqSize > maxBlockSize { - return fmt.Errorf("output (%d) bigger than max block size", s.seqSize) + return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize) } err := br.close() if err != nil { @@ -347,6 +350,10 @@ func (s *sequenceDecs) decodeSync(history *history) error { llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state hist := history.b[history.ignoreBuffer:] out := s.out + maxBlockSize := maxCompressedBlockSize + if s.windowSize < maxBlockSize { + maxBlockSize = s.windowSize + } for i := seqs - 1; i >= 0; i-- { if br.overread() { @@ -426,7 +433,7 @@ func (s *sequenceDecs) decodeSync(history *history) error { } size := ll + ml + len(out) if size-startSize > maxBlockSize { - return fmt.Errorf("output (%d) bigger than max block size", size) + return fmt.Errorf("output (%d) bigger than max block size (%d)", size, maxBlockSize) } if size > cap(out) { // Not enough size, which can happen under high volume block streaming conditions @@ -535,6 +542,11 @@ func (s *sequenceDecs) decodeSync(history *history) error { } } + // Check if space for literals + if len(s.literals)+len(s.out)-startSize > maxBlockSize { + return fmt.Errorf("output (%d) bigger than max block size (%d)", len(s.out), maxBlockSize) + } + // Add final literals s.out = append(out, s.literals...) return br.close() diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index 1f34780735..15be3cd510 100644 Binary files a/zstd/testdata/bad.zip and b/zstd/testdata/bad.zip differ