Skip to content

Commit

Permalink
Add MaxEncodedSize to encoder (#691)
Browse files Browse the repository at this point in the history
Adds function that will return the expected maximum size of a given input size with current settings.

See #688
  • Loading branch information
klauspost committed Nov 15, 2022
1 parent cbc850f commit 6f95269
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
35 changes: 35 additions & 0 deletions zstd/encoder.go
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/rand"
"fmt"
"io"
"math"
rdebug "runtime/debug"
"sync"

Expand Down Expand Up @@ -639,3 +640,37 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte {
}
return dst
}

// MaxEncodedSize returns the expected maximum
// size of an encoded block or stream.
func (e *Encoder) MaxEncodedSize(size int) int {
frameHeader := 4 + 2 // magic + frame header & window descriptor
if e.o.dict != nil {
frameHeader += 4
}
// Frame content size:
if size < 256 {
frameHeader++
} else if size < 65536+256 {
frameHeader += 2
} else if size < math.MaxInt32 {
frameHeader += 4
} else {
frameHeader += 8
}
// Final crc
if e.o.crc {
frameHeader += 4
}

// Max overhead is 3 bytes/block.
// There cannot be 0 blocks.
blocks := (size + e.o.blockSize) / e.o.blockSize

// Combine, add padding.
maxSz := frameHeader + 3*blocks + size
if e.o.pad > 1 {
maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
}
return maxSz
}
22 changes: 17 additions & 5 deletions zstd/encoder_test.go
Expand Up @@ -85,7 +85,7 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
defer e.Close()
start := time.Now()
dst := e.EncodeAll(in, nil)
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)

Expand All @@ -98,7 +98,7 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
os.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
t.Fatal("Decoded does not match")
}
t.Log("Encoded content matched")
//t.Log("Encoded content matched")
})
}
}
Expand Down Expand Up @@ -136,6 +136,9 @@ func TestEncoder_EncodeAllConcurrent(t *testing.T) {
go func() {
defer wg.Done()
dst := e.EncodeAll(in, nil)
if len(dst) > e.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(dst), e.MaxEncodedSize(len(in)))
}
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
decoded, err := dec.DecodeAll(dst, nil)
if err != nil {
Expand All @@ -150,7 +153,7 @@ func TestEncoder_EncodeAllConcurrent(t *testing.T) {
}()
}
wg.Wait()
t.Log("Encoded content matched.", n, "goroutines")
//t.Log("Encoded content matched.", n, "goroutines")
})
}
}
Expand Down Expand Up @@ -185,7 +188,10 @@ func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
defer e.Close()
start := time.Now()
dst := e.EncodeAll(in, nil)
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
if len(dst) > e.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(dst), e.MaxEncodedSize(len(in)))
}
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)

Expand All @@ -198,7 +204,7 @@ func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
t.Error("Decoded does not match")
return
}
t.Log("Encoded content matched")
//t.Log("Encoded content matched")
})
}
}
Expand Down Expand Up @@ -250,6 +256,9 @@ func TestEncoderRegression(t *testing.T) {
t.Error(err)
}
encoded := enc.EncodeAll(in, nil)
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
}
// Usually too small...
got, err := dec.DecodeAll(encoded, make([]byte, 0, len(in)))
if err != nil {
Expand All @@ -268,6 +277,9 @@ func TestEncoderRegression(t *testing.T) {
t.Error(err)
}
encoded = dst.Bytes()
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
}
got, err = dec.DecodeAll(encoded, make([]byte, 0, len(in)/2))
if err != nil {
t.Logf("error: %v\nwant: %v\ngot: %v", err, in, got)
Expand Down
13 changes: 13 additions & 0 deletions zstd/fuzz_test.go
Expand Up @@ -210,6 +210,10 @@ func FuzzEncoding(f *testing.F) {
}

encoded := enc.EncodeAll(data, make([]byte, 0, bufSize))
if len(encoded) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
}

got, err := dec.DecodeAll(encoded, make([]byte, 0, bufSize))
if err != nil {
t.Fatal(fmt.Sprintln("Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
Expand All @@ -223,6 +227,9 @@ func FuzzEncoding(f *testing.F) {
t.Fatal(fmt.Sprintln("Level", level, "Close (buffer) error:", err))
}
encoded2 := dst.Bytes()
if len(encoded2) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
}
if !bytes.Equal(encoded, encoded2) {
got, err = dec.DecodeAll(encoded2, got[:0])
if err != nil {
Expand All @@ -247,6 +254,9 @@ func FuzzEncoding(f *testing.F) {
}

encoded = enc.EncodeAll(data, encoded[:0])
if len(encoded) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
}
got, err = dec.DecodeAll(encoded, got[:0])
if err != nil {
t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
Expand All @@ -260,6 +270,9 @@ func FuzzEncoding(f *testing.F) {
t.Fatal(fmt.Sprintln("Dict Level", level, "Close (buffer) error:", err))
}
encoded2 = dst.Bytes()
if len(encoded2) > enc.MaxEncodedSize(len(data)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
}
if !bytes.Equal(encoded, encoded2) {
got, err = dec.DecodeAll(encoded2, got[:0])
if err != nil {
Expand Down

0 comments on commit 6f95269

Please sign in to comment.