Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Add MaxEncodedSize to encoder #691

Merged
merged 1 commit into from Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions zstd/encoder.go
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/rand"
"fmt"
"io"
"math"
rdebug "runtime/debug"
"sync"

Expand Down Expand Up @@ -639,3 +640,37 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte {
}
return dst
}

// MaxEncodedSize returns the expected maximum
// size of an encoded block or stream.
func (e *Encoder) MaxEncodedSize(size int) int {
frameHeader := 4 + 2 // magic + frame header & window descriptor
if e.o.dict != nil {
frameHeader += 4
}
// Frame content size:
if size < 256 {
frameHeader++
} else if size < 65536+256 {
frameHeader += 2
} else if size < math.MaxInt32 {
frameHeader += 4
} else {
frameHeader += 8
}
// Final crc
if e.o.crc {
frameHeader += 4
}

// Max overhead is 3 bytes/block.
// There cannot be 0 blocks.
blocks := (size + e.o.blockSize) / e.o.blockSize

// Combine, add padding.
maxSz := frameHeader + 3*blocks + size
if e.o.pad > 1 {
maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
}
return maxSz
}
22 changes: 17 additions & 5 deletions zstd/encoder_test.go
Expand Up @@ -85,7 +85,7 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
defer e.Close()
start := time.Now()
dst := e.EncodeAll(in, nil)
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)

Expand All @@ -98,7 +98,7 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
os.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
t.Fatal("Decoded does not match")
}
t.Log("Encoded content matched")
//t.Log("Encoded content matched")
})
}
}
Expand Down Expand Up @@ -136,6 +136,9 @@ func TestEncoder_EncodeAllConcurrent(t *testing.T) {
go func() {
defer wg.Done()
dst := e.EncodeAll(in, nil)
if len(dst) > e.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(dst), e.MaxEncodedSize(len(in)))
}
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
decoded, err := dec.DecodeAll(dst, nil)
if err != nil {
Expand All @@ -150,7 +153,7 @@ func TestEncoder_EncodeAllConcurrent(t *testing.T) {
}()
}
wg.Wait()
t.Log("Encoded content matched.", n, "goroutines")
//t.Log("Encoded content matched.", n, "goroutines")
})
}
}
Expand Down Expand Up @@ -185,7 +188,10 @@ func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
defer e.Close()
start := time.Now()
dst := e.EncodeAll(in, nil)
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
if len(dst) > e.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(dst), e.MaxEncodedSize(len(in)))
}
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)

Expand All @@ -198,7 +204,7 @@ func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
t.Error("Decoded does not match")
return
}
t.Log("Encoded content matched")
//t.Log("Encoded content matched")
})
}
}
Expand Down Expand Up @@ -250,6 +256,9 @@ func TestEncoderRegression(t *testing.T) {
t.Error(err)
}
encoded := enc.EncodeAll(in, nil)
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
}
// Usually too small...
got, err := dec.DecodeAll(encoded, make([]byte, 0, len(in)))
if err != nil {
Expand All @@ -268,6 +277,9 @@ func TestEncoderRegression(t *testing.T) {
t.Error(err)
}
encoded = dst.Bytes()
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
}
got, err = dec.DecodeAll(encoded, make([]byte, 0, len(in)/2))
if err != nil {
t.Logf("error: %v\nwant: %v\ngot: %v", err, in, got)
Expand Down
13 changes: 13 additions & 0 deletions zstd/fuzz_test.go
Expand Up @@ -210,6 +210,10 @@ func FuzzEncoding(f *testing.F) {
}

encoded := enc.EncodeAll(data, make([]byte, 0, bufSize))
if len(encoded) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
}

got, err := dec.DecodeAll(encoded, make([]byte, 0, bufSize))
if err != nil {
t.Fatal(fmt.Sprintln("Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
Expand All @@ -223,6 +227,9 @@ func FuzzEncoding(f *testing.F) {
t.Fatal(fmt.Sprintln("Level", level, "Close (buffer) error:", err))
}
encoded2 := dst.Bytes()
if len(encoded2) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
}
if !bytes.Equal(encoded, encoded2) {
got, err = dec.DecodeAll(encoded2, got[:0])
if err != nil {
Expand All @@ -247,6 +254,9 @@ func FuzzEncoding(f *testing.F) {
}

encoded = enc.EncodeAll(data, encoded[:0])
if len(encoded) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
}
got, err = dec.DecodeAll(encoded, got[:0])
if err != nil {
t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
Expand All @@ -260,6 +270,9 @@ func FuzzEncoding(f *testing.F) {
t.Fatal(fmt.Sprintln("Dict Level", level, "Close (buffer) error:", err))
}
encoded2 = dst.Bytes()
if len(encoded2) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
}
if !bytes.Equal(encoded, encoded2) {
got, err = dec.DecodeAll(encoded2, got[:0])
if err != nil {
Expand Down