Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: improve header decoder #476

Merged
merged 1 commit into from Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
84 changes: 56 additions & 28 deletions zstd/decodeheader.go
Expand Up @@ -5,6 +5,7 @@ package zstd

import (
"bytes"
"encoding/binary"
"errors"
"io"
)
Expand All @@ -15,18 +16,50 @@ const HeaderMaxSize = 14 + 3

// Header contains information about the first frame and the first block within it.
type Header struct {
// Window Size the window of data to keep while decoding.
// Will only be set if HasFCS is false.
WindowSize uint64
// SingleSegment specifies whether the data is to be decompressed into a
// single contiguous memory segment.
// It implies that WindowSize is invalid and that FrameContentSize is valid.
SingleSegment bool

// Frame content size.
// Expected size of the entire frame.
FrameContentSize uint64
// WindowSize is the window of data to keep while decoding.
// Will only be set if SingleSegment is false.
WindowSize uint64

// Dictionary ID.
// If 0, no dictionary.
DictionaryID uint32

// HasFCS specifies whether FrameContentSize has a valid value.
HasFCS bool

// FrameContentSize is the expected uncompressed size of the entire frame.
FrameContentSize uint64

// Skippable will be true if the frame is meant to be skipped.
// This implies that FirstBlock.OK is false.
Skippable bool

// SkippableID is the user-specific ID for the skippable frame.
// Valid values are between 0 and 15, inclusive.
SkippableID int

// SkippableSize is the length of the user data to skip following
// the header.
SkippableSize uint32

// HeaderSize is the raw size of the frame header.
//
// For normal frames, it includes the size of the magic number and
// the size of the header (per section 3.1.1.1).
// It does not include the size for any data blocks (section 3.1.1.2) nor
// the size for the trailing content checksum.
//
// For skippable frames, this counts the size of the magic number
// along with the size of the size field of the payload.
// It does not include the size of the skippable payload itself.
// The total frame size is the HeaderSize plus the SkippableSize.
HeaderSize int

// First block information.
FirstBlock struct {
// OK will be set if first block could be decoded.
Expand All @@ -51,17 +84,9 @@ type Header struct {
CompressedSize int
}

// Skippable will be true if the frame is meant to be skipped.
// No other information will be populated.
Skippable bool

// If set, there is a checksum present for the block content.
// The checksum field at the end is always 4 bytes long.
HasCheckSum bool

// If this is true FrameContentSize will have a valid value
HasFCS bool

SingleSegment bool
}

// Decode the header from the beginning of the stream.
Expand All @@ -71,39 +96,46 @@ type Header struct {
// If there isn't enough input, io.ErrUnexpectedEOF is returned.
// The FirstBlock.OK will indicate if enough information was available to decode the first block header.
func (h *Header) Decode(in []byte) error {
*h = Header{}
if len(in) < 4 {
return io.ErrUnexpectedEOF
}
h.HeaderSize += 4
b, in := in[:4], in[4:]
if !bytes.Equal(b, frameMagic) {
if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
return ErrMagicMismatch
}
*h = Header{Skippable: true}
if len(in) < 4 {
return io.ErrUnexpectedEOF
}
h.HeaderSize += 4
h.Skippable = true
h.SkippableID = int(b[0] & 0xf)
h.SkippableSize = binary.LittleEndian.Uint32(in)
return nil
}

// Read Window_Descriptor
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
if len(in) < 1 {
return io.ErrUnexpectedEOF
}

// Clear output
*h = Header{}
fhd, in := in[0], in[1:]
h.HeaderSize++
h.SingleSegment = fhd&(1<<5) != 0
h.HasCheckSum = fhd&(1<<2) != 0

if fhd&(1<<3) != 0 {
return errors.New("reserved bit set on frame header")
}

// Read Window_Descriptor
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
if !h.SingleSegment {
if len(in) < 1 {
return io.ErrUnexpectedEOF
}
var wd byte
wd, in = in[0], in[1:]
h.HeaderSize++
windowLog := 10 + (wd >> 3)
windowBase := uint64(1) << windowLog
windowAdd := (windowBase / 8) * uint64(wd&0x7)
Expand All @@ -120,9 +152,7 @@ func (h *Header) Decode(in []byte) error {
return io.ErrUnexpectedEOF
}
b, in = in[:size], in[size:]
if b == nil {
return io.ErrUnexpectedEOF
}
h.HeaderSize += int(size)
switch size {
case 1:
h.DictionaryID = uint32(b[0])
Expand Down Expand Up @@ -152,9 +182,7 @@ func (h *Header) Decode(in []byte) error {
return io.ErrUnexpectedEOF
}
b, in = in[:fcsSize], in[fcsSize:]
if b == nil {
return io.ErrUnexpectedEOF
}
h.HeaderSize += int(fcsSize)
switch fcsSize {
case 1:
h.FrameContentSize = uint64(b[0])
Expand Down
2 changes: 1 addition & 1 deletion zstd/decodeheader_test.go
Expand Up @@ -82,7 +82,7 @@ func TestHeader_Decode(t *testing.T) {
t.Errorf("want error, got result: %v", got)
}
if want != got {
t.Errorf("want %#v, got %#v", want, got)
t.Errorf("header mismatch:\nwant %#v\ngot %#v", want, got)
}
})
}
Expand Down
Binary file modified zstd/testdata/headers-want.json.zst
Binary file not shown.