diff --git a/zstd/decoder.go b/zstd/decoder.go index a93dfaf100..0791a88dcc 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -348,10 +348,10 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { frame.history.setDict(&dict) } - if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { + if frame.FrameContentSize != fcsUnknown && frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { return dst, ErrDecoderSizeExceeded } - if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 { + if frame.FrameContentSize < 1<<30 { // Never preallocate more than 1 GB up front. if cap(dst)-len(dst) < int(frame.FrameContentSize) { dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)) @@ -514,7 +514,7 @@ func (d *Decoder) nextBlockSync() (ok bool) { // Check frame size (before CRC) d.syncStream.decodedFrame += uint64(len(d.current.b)) - if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize { + if d.syncStream.decodedFrame > d.frame.FrameContentSize { if debugDecoder { printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) } @@ -523,7 +523,7 @@ func (d *Decoder) nextBlockSync() (ok bool) { } // Check FCS - if d.current.d.Last && d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame != d.frame.FrameContentSize { + if d.current.d.Last && d.frame.FrameContentSize != fcsUnknown && d.syncStream.decodedFrame != d.frame.FrameContentSize { if debugDecoder { printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) } @@ -811,11 +811,11 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch } if !hasErr { decodedFrame += uint64(len(do.b)) - if fcs > 0 && decodedFrame > fcs { + if decodedFrame > fcs { println("fcs exceeded", block.Last, fcs, decodedFrame) do.err = ErrFrameSizeExceeded hasErr = true - } else if block.Last && fcs > 0 && decodedFrame != fcs { + } else if block.Last && fcs != fcsUnknown && decodedFrame != fcs { do.err = ErrFrameSizeMismatch hasErr = true } else { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index b40d03cac0..e10d09ab28 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1142,6 +1142,10 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error) return } got, err := ioutil.ReadAll(dec) + if err == ErrCRCMismatch && !strings.Contains(tt.Name, "badsum") { + t.Error(err) + return + } if err == nil { want := errMap[tt.Name] if want == "" { diff --git a/zstd/framedec.go b/zstd/framedec.go index 29c3176b05..11089d2232 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -197,7 +197,7 @@ func (d *frameDec) reset(br byteBuffer) error { default: fcsSize = 1 << v } - d.FrameContentSize = 0 + d.FrameContentSize = fcsUnknown if fcsSize > 0 { b, err := br.readSmall(fcsSize) if err != nil { @@ -343,12 +343,7 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { err = ErrDecoderSizeExceeded break } - if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize { - println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize) - err = ErrFrameSizeExceeded - break - } - if d.FrameContentSize > 0 && uint64(len(d.history.b)-crcStart) > d.FrameContentSize { + if uint64(len(d.history.b)-crcStart) > d.FrameContentSize { println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize) err = ErrFrameSizeExceeded break @@ -356,13 +351,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { if dec.Last { break } - if debugDecoder && d.FrameContentSize > 0 { + if debugDecoder { println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize) } } dst = d.history.b if err == nil { - if d.FrameContentSize > 0 && uint64(len(d.history.b)-crcStart) != d.FrameContentSize { + if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize { err = ErrFrameSizeMismatch } else if d.HasCheckSum { var n int diff --git a/zstd/testdata/bad.zip b/zstd/testdata/bad.zip index 54f030f8ae..1339e8dcf7 100644 Binary files a/zstd/testdata/bad.zip and b/zstd/testdata/bad.zip differ diff --git a/zstd/zstd.go b/zstd/zstd.go index 52096538a1..c1c90b4a07 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -39,6 +39,9 @@ const zstdMinMatch = 3 // Reset the buffer offset when reaching this. const bufferReset = math.MaxInt32 - MaxWindowSize +// fcsUnknown is used for unknown frame content size. +const fcsUnknown = math.MaxUint64 + var ( // ErrReservedBlockType is returned when a reserved block type is found. // Typically this indicates wrong or corrupted input.