From 5a3a4a965cc6dc28df4962646ef8fba7354cc12e Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 29 Jul 2022 03:14:46 -0700 Subject: [PATCH] zstd: Add DecodeAllCapLimit (#649) * zstd: Add DecodeAllCapLimit WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, or any size set in WithDecoderMaxMemory. This can be used to limit decoding to a specific maximum output size. Disabled by default. Fixes #647 --- zstd/decoder.go | 21 ++++++++++++++-- zstd/decoder_options.go | 12 +++++++++ zstd/decoder_test.go | 54 +++++++++++++++++++++++++++++++++++++++++ zstd/framedec.go | 23 +++++++++++++++--- 4 files changed, 105 insertions(+), 5 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index d212f4737f..6104eb7936 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -312,6 +312,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { // Grab a block decoder and frame decoder. block := <-d.decoders frame := block.localFrame + initialSize := len(dst) defer func() { if debugDecoder { printf("re-adding decoder: %p", block) @@ -354,7 +355,16 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { return dst, ErrWindowSizeExceeded } if frame.FrameContentSize != fcsUnknown { - if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { + if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) { + if debugDecoder { + println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst)) + } + return dst, ErrDecoderSizeExceeded + } + if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) { + if debugDecoder { + println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst)) + } return dst, ErrDecoderSizeExceeded } if cap(dst)-len(dst) < int(frame.FrameContentSize) { @@ -364,7 +374,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } } - if cap(dst) == 0 { + if cap(dst) == 0 && !d.o.limitToCap { // Allocate len(input) * 2 by default if nothing is provided // and we didn't get frame content size. size := len(input) * 2 @@ -382,6 +392,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { if err != nil { return dst, err } + if uint64(len(dst)-initialSize) > d.o.maxDecodedSize { + return dst, ErrDecoderSizeExceeded + } if len(frame.bBuf) == 0 { if debugDecoder { println("frame dbuf empty") @@ -852,6 +865,10 @@ decodeStream: } } if err == nil && d.frame.WindowSize > d.o.maxWindowSize { + if debugDecoder { + println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize) + } + err = ErrDecoderSizeExceeded } if err != nil { diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index c70e6fa0f7..666c2715fe 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -20,6 +20,7 @@ type decoderOptions struct { maxWindowSize uint64 dicts []dict ignoreChecksum bool + limitToCap bool } func (o *decoderOptions) setDefault() { @@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption { } } +// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, +// or any size set in WithDecoderMaxMemory. +// This can be used to limit decoding to a specific maximum output size. +// Disabled by default. +func WithDecodeAllCapLimit(b bool) DOption { + return func(o *decoderOptions) error { + o.limitToCap = b + return nil + } +} + // IgnoreChecksum allows to forcibly ignore checksum checking. func IgnoreChecksum(b bool) DOption { return func(o *decoderOptions) error { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 30ec9bfee4..971b5cdb42 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "reflect" "runtime" + "strconv" "strings" "sync" "testing" @@ -1900,3 +1901,56 @@ func timeout(after time.Duration) (cancel func()) { close(cc) } } + +func TestWithDecodeAllCapLimit(t *testing.T) { + var encs []*Encoder + var decs []*Decoder + addEnc := func(e *Encoder, _ error) { + encs = append(encs, e) + } + addDec := func(d *Decoder, _ error) { + decs = append(decs, d) + } + addEnc(NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10))) + addEnc(NewWriter(nil, WithEncoderConcurrency(1), WithWindowSize(4<<10))) + addEnc(NewWriter(nil, WithZeroFrames(false), WithWindowSize(4<<10))) + addEnc(NewWriter(nil, WithWindowSize(128<<10))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderConcurrency(1))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderLowmem(true))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxWindow(128<<10))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxMemory(1<<20))) + for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 { + sz := sz + t.Run(strconv.Itoa(sz), func(t *testing.T) { + t.Parallel() + for ei, enc := range encs { + for di, dec := range decs { + t.Run(fmt.Sprintf("e%d:d%d", ei, di), func(t *testing.T) { + encoded := enc.EncodeAll(make([]byte, sz), nil) + for i := sz - 1; i < sz+1; i++ { + if i < 0 { + continue + } + const existinglen = 5 + got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen)) + if i < sz { + if err != ErrDecoderSizeExceeded { + t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err) + } + } else { + if err != nil { + t.Errorf("cap: %d, want %v, got %v", i, nil, err) + continue + } + if len(got) != existinglen+i { + t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got)) + } + } + } + }) + } + } + }) + } +} diff --git a/zstd/framedec.go b/zstd/framedec.go index 9568a4ba31..1559a20386 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { // Store input length, so we only check new data. crcStart := len(dst) d.history.decoders.maxSyncLen = 0 + if d.o.limitToCap { + d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst)) + } if d.FrameContentSize != fcsUnknown { - d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen { + d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + } if d.history.decoders.maxSyncLen > d.o.maxDecodedSize { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize) + } return dst, ErrDecoderSizeExceeded } - if uint64(cap(dst)) < d.history.decoders.maxSyncLen { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen) + } + if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen { // Alloc for output dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc) copy(dst2, dst) @@ -378,7 +389,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { if err != nil { break } - if uint64(len(d.history.b)) > d.o.maxDecodedSize { + if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize { + println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize) + err = ErrDecoderSizeExceeded + break + } + if d.o.limitToCap && len(d.history.b) > cap(dst) { + println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst)) err = ErrDecoderSizeExceeded break }