Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor zstd decoder #498

Merged
merged 35 commits into from Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1a14c30
Refactor zstd decoder
klauspost Jan 18, 2022
b5ab519
Merge branch 'master' into refactor-zstd-decoder
klauspost Feb 9, 2022
984fc8f
Make it compile
klauspost Feb 11, 2022
dca58e6
Fix up single decodes.
klauspost Feb 11, 2022
f01cd33
Merge branch 'master' into refactor-zstd-decoder
klauspost Feb 14, 2022
3c54049
Almost working now...
klauspost Feb 16, 2022
8bf62b7
Tests pass.
klauspost Feb 18, 2022
ac20ec2
Avoid a few allocs
klauspost Feb 18, 2022
9a9ac6b
Add stream decompression with no goroutines.
klauspost Feb 18, 2022
b2951ef
Check FrameContentSize and max decoded size.
klauspost Feb 18, 2022
9d525e5
Remove unused var+func.
klauspost Feb 18, 2022
b74ecce
Tweaks and cleanup
klauspost Feb 18, 2022
e217e78
Use maxsize as documented.
klauspost Feb 19, 2022
6598a07
Merge branch 'master' into refactor-zstd-decoder
klauspost Feb 21, 2022
561b94c
Ensure history from frames cannot overlap.
klauspost Feb 21, 2022
3db2dcb
Fix deadlock on error.
klauspost Feb 21, 2022
55e4c4b
Stricter framecontent size checks and consistency.
klauspost Feb 21, 2022
d6e790a
Fix short test.
klauspost Feb 21, 2022
8b43a92
Add bench
klauspost Feb 22, 2022
d0ce155
Merge branch 'master' into refactor-zstd-decoder
klauspost Feb 22, 2022
50a135c
Reject big RLE/RAW blocks as per https://github.com/facebook/zstd/iss…
klauspost Feb 22, 2022
d731396
Use os.FileInfo for Go 1.15.
klauspost Feb 22, 2022
03b301a
Break all on errors
klauspost Feb 22, 2022
4c5a306
Move sync code to separate method.
klauspost Feb 22, 2022
d3974d3
Don't read sent error.
klauspost Feb 22, 2022
133d52c
Fix consistent error reporting and dict inits.
klauspost Feb 22, 2022
2b89e67
Fix error message
klauspost Feb 22, 2022
606373a
Check if huff0 X4 blocks match size exactly.
klauspost Feb 23, 2022
a197f86
Fix decoder leakage.
klauspost Feb 23, 2022
d1f91e2
Forward blocks, so we don't risk run-away decoding.
klauspost Feb 23, 2022
e7906eb
Fix test race
klauspost Feb 23, 2022
7a09f3d
Save last before sending.
klauspost Feb 23, 2022
f5e0961
Reuse history between async calls.
klauspost Feb 23, 2022
6cabc28
Protect local frame.
klauspost Feb 23, 2022
f20d563
Clarify error msg
klauspost Feb 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 15 additions & 106 deletions huff0/bitreader.go
Expand Up @@ -8,115 +8,10 @@ package huff0
import (
"encoding/binary"
"errors"
"fmt"
"io"
)

// bitReader reads a bitstream in reverse.
// The last set bit indicates the start of the stream and is used
// for aligning the input.
//
// Bytes are consumed from the end of in towards the front and are shifted
// into the low end of value as they are loaded.
type bitReader struct {
// in is the buffer being decoded; close drops the reference.
in []byte
off uint // next byte to read is at in[off - 1]
// value holds the currently buffered bits; refills shift older bits up.
value uint64
// bitsRead counts the bits of value already consumed. init starts it at 64
// (nothing buffered), refills subtract the bits loaded, advance adds to it.
bitsRead uint8
}

// init points the reader at the end of in and preloads the bit buffer.
// It fails when in is empty or when the final byte is zero, since that
// byte must carry the bit marking where the stream starts.
func (b *bitReader) init(in []byte) error {
	size := len(in)
	if size < 1 {
		return errors.New("corrupt stream: too short")
	}
	b.in = in
	b.off = uint(size)
	// The top set bit of the final byte tells us where decoding begins.
	last := in[size-1]
	if last == 0 {
		return errors.New("corrupt stream, did not find end of stream")
	}
	// Start from an empty buffer, then load as much as the input allows.
	b.bitsRead = 64
	b.value = 0
	if size < 8 {
		b.fill()
		b.fill()
	} else {
		b.fillFastStart()
	}
	// Mark the bits above the start marker in the final byte as consumed.
	// NOTE(review): presumably highBit32 yields the index of the highest
	// set bit — confirm against its definition.
	b.bitsRead += 8 - uint8(highBit32(uint32(last)))
	return nil
}

// peekBitsFast returns the next n unread bits without consuming them.
// n must be at least 1: with n == 0 the right shift amount masks to zero
// and the result is not zero. No check is made that the buffer holds
// n valid bits.
func (b *bitReader) peekBitsFast(n uint8) uint16 {
	const shiftMask = 63
	// Drop the bits that were already consumed...
	discard := b.bitsRead & shiftMask
	// ...then bring the top n bits down to the bottom.
	down := (64 - n) & shiftMask
	return uint16((b.value << discard) >> down)
}

// fillFast tops the buffer up so that at least 32 bits are available.
// The caller must guarantee that at least 4 unread bytes remain; this
// variant has no slow path for shorter tails.
func (b *bitReader) fillFast() {
	if b.bitsRead >= 32 {
		// Load the 4 bytes just below off as a little-endian word and
		// push them into the low half of the buffer.
		word := binary.LittleEndian.Uint32(b.in[b.off-4 : b.off])
		b.value = b.value<<32 | uint64(word)
		b.bitsRead -= 32
		b.off -= 4
	}
}

// advance marks the next n buffered bits as consumed.
func (b *bitReader) advance(n uint8) {
	b.bitsRead = b.bitsRead + n
}

// fillFastStart loads a full 64-bit word into the buffer. It must only be
// used when the buffer is empty and at least 8 unread bytes remain.
func (b *bitReader) fillFastStart() {
	start := b.off - 8
	// One reslice; the 8-byte little-endian load is the only bounds check.
	b.value = binary.LittleEndian.Uint64(b.in[start:])
	b.off = start
	b.bitsRead = 0
}

// fill tops the buffer up so that at least 32 bits are available, as long
// as the input still has bytes to give. With more than 4 unread bytes it
// loads a 4-byte word; otherwise it drains the remaining bytes one by one.
func (b *bitReader) fill() {
	switch {
	case b.bitsRead < 32:
		// Enough bits buffered already.
		return
	case b.off > 4:
		word := binary.LittleEndian.Uint32(b.in[b.off-4:][:4])
		b.value = b.value<<32 | uint64(word)
		b.bitsRead -= 32
		b.off -= 4
	default:
		// Tail of the input: take whatever bytes are left.
		for ; b.off > 0; b.off-- {
			b.value = b.value<<8 | uint64(b.in[b.off-1])
			b.bitsRead -= 8
		}
	}
}

// finished reports whether the input has been fully loaded and every
// buffered bit has been consumed.
func (b *bitReader) finished() bool {
	if b.off != 0 {
		return false
	}
	return b.bitsRead >= 64
}

// close releases the input buffer. It returns io.ErrUnexpectedEOF when
// more bits were consumed than the buffer ever held, and nil otherwise.
func (b *bitReader) close() error {
	overread := b.bitsRead > 64
	// Drop the reference so the input can be collected.
	b.in = nil
	if overread {
		return io.ErrUnexpectedEOF
	}
	return nil
}

// bitReader reads a bitstream in reverse.
// The last set bit indicates the start of the stream and is used
// for aligning the input.
Expand Down Expand Up @@ -213,10 +108,17 @@ func (b *bitReaderBytes) finished() bool {
return b.off == 0 && b.bitsRead >= 64
}

// remaining returns the number of bits not yet consumed: the bytes still
// to be loaded from the input plus the unread part of the bit buffer.
//
// NOTE(review): if bitsRead is an unsigned type, a value above 64 makes
// 64-bitsRead wrap to a large positive number rather than go negative, so
// an over-read would be reported as bits remaining — confirm against the
// struct definition.
func (b *bitReaderBytes) remaining() uint {
	buffered := uint(64 - b.bitsRead)
	return b.off*8 + buffered
}

// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReaderBytes) close() error {
// Release reference.
b.in = nil
if b.remaining() > 0 {
return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining())
}
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
Expand Down Expand Up @@ -318,10 +220,17 @@ func (b *bitReaderShifted) finished() bool {
return b.off == 0 && b.bitsRead >= 64
}

// remaining returns the number of bits not yet consumed: the bytes still
// to be loaded from the input plus the unread part of the bit buffer.
//
// NOTE(review): if bitsRead is an unsigned type, a value above 64 makes
// 64-bitsRead wrap to a large positive number rather than go negative, so
// an over-read would be reported as bits remaining — confirm against the
// struct definition.
func (b *bitReaderShifted) remaining() uint {
	buffered := uint(64 - b.bitsRead)
	return b.off*8 + buffered
}

// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReaderShifted) close() error {
// Release reference.
b.in = nil
if b.remaining() > 0 {
return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining())
}
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
Expand Down
63 changes: 38 additions & 25 deletions huff0/decompress.go
Expand Up @@ -741,6 +741,7 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
}

var br [4]bitReaderShifted
// Decode "jump table"
start := 6
for i := 0; i < 3; i++ {
length := int(src[i*2]) | (int(src[i*2+1]) << 8)
Expand Down Expand Up @@ -865,30 +866,18 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
}

// Decode remaining.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
}
br := &br[i]
bitsLeft := br.off*8 + uint(64-br.bitsRead)
bitsLeft := br.remaining()
for bitsLeft > 0 {
br.fill()
if false && br.bitsRead >= 32 {
if br.off >= 4 {
v := br.in[br.off-4:]
v = v[:4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
br.value = (br.value << 32) | uint64(low)
br.bitsRead -= 32
br.off -= 4
} else {
for br.off > 0 {
br.value = (br.value << 8) | uint64(br.in[br.off-1])
br.bitsRead -= 8
br.off--
}
}
}
// end inline...
if offset >= len(out) {
if offset >= endsAt {
d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 4")
}
Expand All @@ -902,6 +891,10 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
out[offset] = uint8(v >> 8)
offset++
}
if offset != endsAt {
d.bufs.Put(buf)
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
}
decoded += offset - dstEvery*i
err = br.close()
if err != nil {
Expand Down Expand Up @@ -1091,10 +1084,16 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
}

// Decode remaining.
// Decode remaining.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
}
br := &br[i]
bitsLeft := int(br.off*8) + int(64-br.bitsRead)
bitsLeft := br.remaining()
for bitsLeft > 0 {
if br.finished() {
d.bufs.Put(buf)
Expand All @@ -1117,7 +1116,7 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
}
}
// end inline...
if offset >= len(out) {
if offset >= endsAt {
d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 4")
}
Expand All @@ -1126,10 +1125,14 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
v := single[uint8(br.value>>shift)].entry
nBits := uint8(v)
br.advance(nBits)
bitsLeft -= int(nBits)
bitsLeft -= uint(nBits)
out[offset] = uint8(v >> 8)
offset++
}
if offset != endsAt {
d.bufs.Put(buf)
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
}
decoded += offset - dstEvery*i
err = br.close()
if err != nil {
Expand Down Expand Up @@ -1315,10 +1318,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
}

// Decode remaining.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
}
br := &br[i]
bitsLeft := int(br.off*8) + int(64-br.bitsRead)
bitsLeft := br.remaining()
for bitsLeft > 0 {
if br.finished() {
d.bufs.Put(buf)
Expand All @@ -1341,7 +1349,7 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
}
}
// end inline...
if offset >= len(out) {
if offset >= endsAt {
d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 4")
}
Expand All @@ -1350,10 +1358,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
v := single[br.peekByteFast()].entry
nBits := uint8(v)
br.advance(nBits)
bitsLeft -= int(nBits)
bitsLeft -= uint(nBits)
out[offset] = uint8(v >> 8)
offset++
}
if offset != endsAt {
d.bufs.Put(buf)
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
}

decoded += offset - dstEvery*i
err = br.close()
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions zstd/bitreader.go
Expand Up @@ -7,6 +7,7 @@ package zstd
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math/bits"
)
Expand Down Expand Up @@ -132,6 +133,9 @@ func (b *bitReader) remain() uint {
func (b *bitReader) close() error {
// Release reference.
b.in = nil
if !b.finished() {
return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
}
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
Expand Down