Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Add DecodeAllCapLimit #649

Merged
merged 3 commits into from Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 19 additions & 2 deletions zstd/decoder.go
Expand Up @@ -312,6 +312,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
// Grab a block decoder and frame decoder.
block := <-d.decoders
frame := block.localFrame
initialSize := len(dst)
defer func() {
if debugDecoder {
printf("re-adding decoder: %p", block)
Expand Down Expand Up @@ -354,7 +355,16 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
return dst, ErrWindowSizeExceeded
}
if frame.FrameContentSize != fcsUnknown {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if cap(dst)-len(dst) < int(frame.FrameContentSize) {
Expand All @@ -364,7 +374,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
}

if cap(dst) == 0 {
if cap(dst) == 0 && !d.o.limitToCap {
// Allocate len(input) * 2 by default if nothing is provided
// and we didn't get frame content size.
size := len(input) * 2
Expand All @@ -382,6 +392,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
if err != nil {
return dst, err
}
if uint64(len(dst)-initialSize) > d.o.maxDecodedSize {
return dst, ErrDecoderSizeExceeded
}
if len(frame.bBuf) == 0 {
if debugDecoder {
println("frame dbuf empty")
Expand Down Expand Up @@ -852,6 +865,10 @@ decodeStream:
}
}
if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
}

err = ErrDecoderSizeExceeded
}
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions zstd/decoder_options.go
Expand Up @@ -20,6 +20,7 @@ type decoderOptions struct {
maxWindowSize uint64
dicts []dict
ignoreChecksum bool
limitToCap bool
}

func (o *decoderOptions) setDefault() {
Expand Down Expand Up @@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption {
}
}

// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes,
// or any size set in WithDecoderMaxMemory.
// This can be used to limit decoding to a specific maximum output size.
// Disabled by default.
func WithDecodeAllCapLimit(b bool) DOption {
return func(o *decoderOptions) error {
o.limitToCap = b
return nil
}
}

// IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption {
return func(o *decoderOptions) error {
Expand Down
54 changes: 54 additions & 0 deletions zstd/decoder_test.go
Expand Up @@ -19,6 +19,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -1900,3 +1901,56 @@ func timeout(after time.Duration) (cancel func()) {
close(cc)
}
}

func TestWithDecodeAllCapLimit(t *testing.T) {
var encs []*Encoder
var decs []*Decoder
addEnc := func(e *Encoder, _ error) {
encs = append(encs, e)
}
addDec := func(d *Decoder, _ error) {
decs = append(decs, d)
}
addEnc(NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10)))
addEnc(NewWriter(nil, WithEncoderConcurrency(1), WithWindowSize(4<<10)))
addEnc(NewWriter(nil, WithZeroFrames(false), WithWindowSize(4<<10)))
addEnc(NewWriter(nil, WithWindowSize(128<<10)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderConcurrency(1)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderLowmem(true)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxWindow(128<<10)))
addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxMemory(1<<20)))
for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 {
sz := sz
t.Run(strconv.Itoa(sz), func(t *testing.T) {
t.Parallel()
for ei, enc := range encs {
for di, dec := range decs {
t.Run(fmt.Sprintf("e%d:d%d", ei, di), func(t *testing.T) {
encoded := enc.EncodeAll(make([]byte, sz), nil)
for i := sz - 1; i < sz+1; i++ {
if i < 0 {
continue
}
const existinglen = 5
got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen))
if i < sz {
if err != ErrDecoderSizeExceeded {
t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err)
}
} else {
if err != nil {
t.Errorf("cap: %d, want %v, got %v", i, nil, err)
continue
}
if len(got) != existinglen+i {
t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got))
}
}
}
})
}
}
})
}
}
23 changes: 20 additions & 3 deletions zstd/framedec.go
Expand Up @@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
// Store input length, so we only check new data.
crcStart := len(dst)
d.history.decoders.maxSyncLen = 0
if d.o.limitToCap {
d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
}
if d.FrameContentSize != fcsUnknown {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
}
if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
}
return dst, ErrDecoderSizeExceeded
}
if uint64(cap(dst)) < d.history.decoders.maxSyncLen {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen)
}
if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst)
Expand All @@ -378,7 +389,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
if err != nil {
break
}
if uint64(len(d.history.b)) > d.o.maxDecodedSize {
if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
err = ErrDecoderSizeExceeded
break
}
if d.o.limitToCap && len(d.history.b) > cap(dst) {
println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
err = ErrDecoderSizeExceeded
break
}
Expand Down