Skip to content

Commit

Permalink
zstd: Add DecodeAllCapLimit
Browse files Browse the repository at this point in the history
WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, or any size set in WithDecoderMaxMemory.

This can be used to limit decoding to a specific maximum output size.

Disabled by default.

Fixes #647
  • Loading branch information
klauspost committed Jul 27, 2022
1 parent a3bc126 commit 581d92d
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 3 deletions.
31 changes: 30 additions & 1 deletion zstd/decoder.go
Expand Up @@ -355,6 +355,15 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
if frame.FrameContentSize != fcsUnknown {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
if debugDecoder {
println("decoder size exceeded:", frame.FrameContentSize, ">", d.o.maxDecodedSize-uint64(len(dst)))
}
return dst, ErrDecoderSizeExceeded
}
if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
if debugDecoder {
println("decoder size exceeded:", frame.FrameContentSize, ">", cap(dst)-len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if cap(dst)-len(dst) < int(frame.FrameContentSize) {
Expand All @@ -364,7 +373,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
}

if cap(dst) == 0 {
if cap(dst) == 0 && !d.o.limitToCap {
// Allocate len(input) * 2 by default if nothing is provided
// and we didn't get frame content size.
size := len(input) * 2
Expand Down Expand Up @@ -493,6 +502,22 @@ func (d *Decoder) nextBlockSync() (ok bool) {
d.current.err = ErrDecoderSizeExceeded
return false
}
if d.frame.FrameContentSize != fcsUnknown {
if !d.o.limitToCap && d.frame.FrameContentSize > d.o.maxDecodedSize {
if debugDecoder {
println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> mds", d.o.maxDecodedSize)
}
d.current.err = ErrDecoderSizeExceeded
return false
}
if d.o.limitToCap && d.frame.FrameContentSize > uint64(cap(d.frame.history.b)-len(d.frame.history.b)) {
if debugDecoder {
println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> cap", cap(d.frame.history.b)-len(d.frame.history.b))
}
d.current.err = ErrDecoderSizeExceeded
return false
}
}

d.syncStream.decodedFrame = 0
d.syncStream.inFrame = true
Expand Down Expand Up @@ -852,6 +877,10 @@ decodeStream:
}
}
if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
}

err = ErrDecoderSizeExceeded
}
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions zstd/decoder_options.go
Expand Up @@ -20,6 +20,7 @@ type decoderOptions struct {
maxWindowSize uint64
dicts []dict
ignoreChecksum bool
limitToCap bool
}

func (o *decoderOptions) setDefault() {
Expand Down Expand Up @@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption {
}
}

// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes,
// or any size set in WithDecoderMaxMemory.
// This can be used to limit decoding to a specific maximum output size.
// Disabled by default.
func WithDecodeAllCapLimit(b bool) DOption {
return func(o *decoderOptions) error {
o.limitToCap = b
return nil
}
}

// IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption {
return func(o *decoderOptions) error {
Expand Down
32 changes: 32 additions & 0 deletions zstd/decoder_test.go
Expand Up @@ -19,6 +19,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -1900,3 +1901,34 @@ func timeout(after time.Duration) (cancel func()) {
close(cc)
}
}

func TestWithDecodeAllCapLimit(t *testing.T) {
enc, _ := NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10))
dec, _ := NewReader(nil, WithDecodeAllCapLimit(true))
for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 {
sz := sz
t.Run(strconv.Itoa(sz), func(t *testing.T) {
encoded := enc.EncodeAll(make([]byte, sz), nil)
for i := sz - 1; i < sz+1; i++ {
if i < 0 {
continue
}
const existinglen = 5
got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen))
if i < sz {
if err != ErrDecoderSizeExceeded {
t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err)
}
} else {
if err != nil {
t.Errorf("cap: %d, want %v, got %v", i, nil, err)
continue
}
if len(got) != existinglen+i {
t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got))
}
}
}
})
}
}
19 changes: 17 additions & 2 deletions zstd/framedec.go
Expand Up @@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
// Store input length, so we only check new data.
crcStart := len(dst)
d.history.decoders.maxSyncLen = 0
if d.o.limitToCap {
d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
}
if d.FrameContentSize != fcsUnknown {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
}
if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
}
return dst, ErrDecoderSizeExceeded
}
if uint64(cap(dst)) < d.history.decoders.maxSyncLen {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen)
}
if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst)
Expand All @@ -382,6 +393,10 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
err = ErrDecoderSizeExceeded
break
}
if d.o.limitToCap && len(d.history.b) > cap(dst) {
err = ErrDecoderSizeExceeded
break
}
if uint64(len(d.history.b)-crcStart) > d.FrameContentSize {
println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize)
err = ErrFrameSizeExceeded
Expand Down

0 comments on commit 581d92d

Please sign in to comment.