Skip to content

Commit

Permalink
zstd: Add DecodeAllCapLimit (#649)
Browse files Browse the repository at this point in the history
* zstd: Add DecodeAllCapLimit

WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, or any size set in WithDecoderMaxMemory.

This can be used to limit decoding to a specific maximum output size.

Disabled by default.

Fixes #647
  • Loading branch information
klauspost committed Jul 29, 2022
1 parent 6234e33 commit 5a3a4a9
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 5 deletions.
21 changes: 19 additions & 2 deletions zstd/decoder.go
Expand Up @@ -312,6 +312,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
// Grab a block decoder and frame decoder.
block := <-d.decoders
frame := block.localFrame
initialSize := len(dst)
defer func() {
if debugDecoder {
printf("re-adding decoder: %p", block)
Expand Down Expand Up @@ -354,7 +355,16 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
return dst, ErrWindowSizeExceeded
}
if frame.FrameContentSize != fcsUnknown {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if cap(dst)-len(dst) < int(frame.FrameContentSize) {
Expand All @@ -364,7 +374,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
}

if cap(dst) == 0 {
if cap(dst) == 0 && !d.o.limitToCap {
// Allocate len(input) * 2 by default if nothing is provided
// and we didn't get frame content size.
size := len(input) * 2
Expand All @@ -382,6 +392,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
if err != nil {
return dst, err
}
if uint64(len(dst)-initialSize) > d.o.maxDecodedSize {
return dst, ErrDecoderSizeExceeded
}
if len(frame.bBuf) == 0 {
if debugDecoder {
println("frame dbuf empty")
Expand Down Expand Up @@ -852,6 +865,10 @@ decodeStream:
}
}
if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
}

err = ErrDecoderSizeExceeded
}
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions zstd/decoder_options.go
Expand Up @@ -20,6 +20,7 @@ type decoderOptions struct {
maxWindowSize uint64
dicts []dict
ignoreChecksum bool
limitToCap bool
}

func (o *decoderOptions) setDefault() {
Expand Down Expand Up @@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption {
}
}

// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes,
// or any size set in WithDecoderMaxMemory.
// This can be used to limit decoding to a specific maximum output size.
// Disabled by default.
func WithDecodeAllCapLimit(b bool) DOption {
return func(o *decoderOptions) error {
o.limitToCap = b
return nil
}
}

// IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption {
return func(o *decoderOptions) error {
Expand Down
54 changes: 54 additions & 0 deletions zstd/decoder_test.go
Expand Up @@ -19,6 +19,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -1900,3 +1901,56 @@ func timeout(after time.Duration) (cancel func()) {
close(cc)
}
}

func TestWithDecodeAllCapLimit(t *testing.T) {
var encs []*Encoder
var decs []*Decoder
addEnc := func(e *Encoder, _ error) {
encs = append(encs, e)
}
addDec := func(d *Decoder, _ error) {
decs = append(decs, d)
}
addEnc(NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10)))
addEnc(NewWriter(nil, WithEncoderConcurrency(1), WithWindowSize(4<<10)))
addEnc(NewWriter(nil, WithZeroFrames(false), WithWindowSize(4<<10)))
addEnc(NewWriter(nil, WithWindowSize(128<<10)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderConcurrency(1)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderLowmem(true)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxWindow(128<<10)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxMemory(1<<20)))
for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 {
sz := sz
t.Run(strconv.Itoa(sz), func(t *testing.T) {
t.Parallel()
for ei, enc := range encs {
for di, dec := range decs {
t.Run(fmt.Sprintf("e%d:d%d", ei, di), func(t *testing.T) {
encoded := enc.EncodeAll(make([]byte, sz), nil)
for i := sz - 1; i < sz+1; i++ {
if i < 0 {
continue
}
const existinglen = 5
got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen))
if i < sz {
if err != ErrDecoderSizeExceeded {
t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err)
}
} else {
if err != nil {
t.Errorf("cap: %d, want %v, got %v", i, nil, err)
continue
}
if len(got) != existinglen+i {
t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got))
}
}
}
})
}
}
})
}
}
23 changes: 20 additions & 3 deletions zstd/framedec.go
Expand Up @@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
// Store input length, so we only check new data.
crcStart := len(dst)
d.history.decoders.maxSyncLen = 0
if d.o.limitToCap {
d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
}
if d.FrameContentSize != fcsUnknown {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
}
if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
}
return dst, ErrDecoderSizeExceeded
}
if uint64(cap(dst)) < d.history.decoders.maxSyncLen {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen)
}
if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst)
Expand All @@ -378,7 +389,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
if err != nil {
break
}
if uint64(len(d.history.b)) > d.o.maxDecodedSize {
if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
err = ErrDecoderSizeExceeded
break
}
if d.o.limitToCap && len(d.history.b) > cap(dst) {
println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
err = ErrDecoderSizeExceeded
break
}
Expand Down

0 comments on commit 5a3a4a9

Please sign in to comment.