From 581d92d46d1bf279dfbd160188bc3cf6498c11fe Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Jul 2022 12:17:50 +0200 Subject: [PATCH] zstd: Add DecodeAllCapLimit WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, or any size set in WithDecoderMaxMemory. This can be used to limit decoding to a specific maximum output size. Disabled by default. Fixes #647 --- zstd/decoder.go | 31 ++++++++++++++++++++++++++++++- zstd/decoder_options.go | 12 ++++++++++++ zstd/decoder_test.go | 32 ++++++++++++++++++++++++++++++++ zstd/framedec.go | 19 +++++++++++++++++-- 4 files changed, 91 insertions(+), 3 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index d212f4737f..22419e1500 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -355,6 +355,15 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } if frame.FrameContentSize != fcsUnknown { if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { + if debugDecoder { + println("decoder size exceeded:", frame.FrameContentSize, ">", d.o.maxDecodedSize-uint64(len(dst))) + } + return dst, ErrDecoderSizeExceeded + } + if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) { + if debugDecoder { + println("decoder size exceeded:", frame.FrameContentSize, ">", cap(dst)-len(dst)) + } return dst, ErrDecoderSizeExceeded } if cap(dst)-len(dst) < int(frame.FrameContentSize) { @@ -364,7 +373,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } } - if cap(dst) == 0 { + if cap(dst) == 0 && !d.o.limitToCap { // Allocate len(input) * 2 by default if nothing is provided // and we didn't get frame content size. size := len(input) * 2 @@ -493,6 +502,22 @@ func (d *Decoder) nextBlockSync() (ok bool) { d.current.err = ErrDecoderSizeExceeded return false } + if d.frame.FrameContentSize != fcsUnknown { + if !d.o.limitToCap && d.frame.FrameContentSize > d.o.maxDecodedSize { + if debugDecoder { + println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> mds", d.o.maxDecodedSize) + } + d.current.err = ErrDecoderSizeExceeded + return false + } + if d.o.limitToCap && d.frame.FrameContentSize > uint64(cap(d.frame.history.b)-len(d.frame.history.b)) { + if debugDecoder { + println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> cap", cap(d.frame.history.b)-len(d.frame.history.b)) + } + d.current.err = ErrDecoderSizeExceeded + return false + } + } d.syncStream.decodedFrame = 0 d.syncStream.inFrame = true @@ -852,6 +877,10 @@ decodeStream: } } if err == nil && d.frame.WindowSize > d.o.maxWindowSize { + if debugDecoder { + println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize) + } + err = ErrDecoderSizeExceeded } if err != nil { diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index c70e6fa0f7..666c2715fe 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -20,6 +20,7 @@ type decoderOptions struct { maxWindowSize uint64 dicts []dict ignoreChecksum bool + limitToCap bool } func (o *decoderOptions) setDefault() { @@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption { } } +// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, +// or any size set in WithDecoderMaxMemory. +// This can be used to limit decoding to a specific maximum output size. +// Disabled by default. +func WithDecodeAllCapLimit(b bool) DOption { + return func(o *decoderOptions) error { + o.limitToCap = b + return nil + } +} + // IgnoreChecksum allows to forcibly ignore checksum checking. func IgnoreChecksum(b bool) DOption { return func(o *decoderOptions) error { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 30ec9bfee4..0c7e3a779e 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "reflect" "runtime" + "strconv" "strings" "sync" "testing" @@ -1900,3 +1901,34 @@ func timeout(after time.Duration) (cancel func()) { close(cc) } } + +func TestWithDecodeAllCapLimit(t *testing.T) { + enc, _ := NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10)) + dec, _ := NewReader(nil, WithDecodeAllCapLimit(true)) + for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 { + sz := sz + t.Run(strconv.Itoa(sz), func(t *testing.T) { + encoded := enc.EncodeAll(make([]byte, sz), nil) + for i := sz - 1; i < sz+1; i++ { + if i < 0 { + continue + } + const existinglen = 5 + got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen)) + if i < sz { + if err != ErrDecoderSizeExceeded { + t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err) + } + } else { + if err != nil { + t.Errorf("cap: %d, want %v, got %v", i, nil, err) + continue + } + if len(got) != existinglen+i { + t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got)) + } + } + } + }) + } +} diff --git a/zstd/framedec.go b/zstd/framedec.go index 9568a4ba31..0a93a2065a 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { // Store input length, so we only check new data. crcStart := len(dst) d.history.decoders.maxSyncLen = 0 + if d.o.limitToCap { + d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst)) + } if d.FrameContentSize != fcsUnknown { - d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen { + d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + } if d.history.decoders.maxSyncLen > d.o.maxDecodedSize { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize) + } return dst, ErrDecoderSizeExceeded } - if uint64(cap(dst)) < d.history.decoders.maxSyncLen { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen) + } + if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen { // Alloc for output dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc) copy(dst2, dst) @@ -382,6 +393,10 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { err = ErrDecoderSizeExceeded break } + if d.o.limitToCap && len(d.history.b) > cap(dst) { + err = ErrDecoderSizeExceeded + break + } if uint64(len(d.history.b)-crcStart) > d.FrameContentSize { println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize) err = ErrFrameSizeExceeded