Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Add stream encoding without goroutines #505

Merged
merged 2 commits into from Feb 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions zstd/decoder.go
Expand Up @@ -238,7 +238,9 @@ func (d *Decoder) Reset(r io.Reader) error {
// drainOutput will drain the output until errEndOfStream is sent.
func (d *Decoder) drainOutput() {
if d.current.cancel != nil {
println("cancelling current")
if debugDecoder {
println("cancelling current")
}
d.current.cancel()
d.current.cancel = nil
}
Expand Down Expand Up @@ -816,7 +818,9 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
do.err = ErrFrameSizeMismatch
hasErr = true
} else {
println("fcs ok", block.Last, fcs, decodedFrame)
if debugDecoder {
println("fcs ok", block.Last, fcs, decodedFrame)
}
}
}
output <- do
Expand Down
66 changes: 54 additions & 12 deletions zstd/encoder.go
Expand Up @@ -98,23 +98,25 @@ func (e *Encoder) Reset(w io.Writer) {
if cap(s.filling) == 0 {
s.filling = make([]byte, 0, e.o.blockSize)
}
if cap(s.current) == 0 {
s.current = make([]byte, 0, e.o.blockSize)
}
if cap(s.previous) == 0 {
s.previous = make([]byte, 0, e.o.blockSize)
if e.o.concurrent > 1 {
if cap(s.current) == 0 {
s.current = make([]byte, 0, e.o.blockSize)
}
if cap(s.previous) == 0 {
s.previous = make([]byte, 0, e.o.blockSize)
}
s.current = s.current[:0]
s.previous = s.previous[:0]
if s.writing == nil {
s.writing = &blockEnc{lowMem: e.o.lowMem}
s.writing.init()
}
s.writing.initNewEncode()
}
if s.encoder == nil {
s.encoder = e.o.encoder()
}
if s.writing == nil {
s.writing = &blockEnc{lowMem: e.o.lowMem}
s.writing.init()
}
s.writing.initNewEncode()
s.filling = s.filling[:0]
s.current = s.current[:0]
s.previous = s.previous[:0]
s.encoder.Reset(e.o.dict, false)
s.headerWritten = false
s.eofWritten = false
Expand Down Expand Up @@ -258,6 +260,46 @@ func (e *Encoder) nextBlock(final bool) error {
return s.err
}

// SYNC:
if e.o.concurrent == 1 {
src := s.filling
s.nInput += int64(len(s.filling))
if debugEncoder {
println("Adding sync block,", len(src), "bytes, final:", final)
}
enc := s.encoder
blk := enc.Block()
blk.reset(nil)
enc.Encode(blk, src)
blk.last = final
if final {
s.eofWritten = true
}

err := errIncompressible
// If we got the exact same number of literals as input,
// assume the literals cannot be compressed.
if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
}
switch err {
case errIncompressible:
if debugEncoder {
println("Storing incompressible block as raw")
}
blk.encodeRaw(src)
// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
case nil:
default:
s.err = err
return err
}
_, s.err = s.w.Write(blk.output)
s.nWritten += int64(len(blk.output))
s.filling = s.filling[:0]
return s.err
}

// Move blocks forward.
s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
s.nInput += int64(len(s.current))
Expand Down
1 change: 1 addition & 0 deletions zstd/encoder_options.go
Expand Up @@ -76,6 +76,7 @@ func WithEncoderCRC(b bool) EOption {
// WithEncoderConcurrency will set the concurrency,
// meaning the maximum number of encoders to run concurrently.
// The value supplied must be at least 1.
// For streams, setting a value of 1 will disable async compression.
// By default this will be set to GOMAXPROCS.
func WithEncoderConcurrency(n int) EOption {
return func(o *encoderOptions) error {
Expand Down
55 changes: 52 additions & 3 deletions zstd/encoder_test.go
Expand Up @@ -506,7 +506,7 @@ func testEncoderRoundtrip(t *testing.T, file string, wantCRC []byte) {
for _, opt := range getEncOpts(1) {
t.Run(opt.name, func(t *testing.T) {
opt := opt
t.Parallel()
//t.Parallel()
f, err := os.Open(file)
if err != nil {
if os.IsNotExist(err) {
Expand Down Expand Up @@ -851,7 +851,7 @@ func TestEncoder_EncodeAllEmpty(t *testing.T) {
}

func TestEncoder_EncodeAllEnwik9(t *testing.T) {
if false || testing.Short() {
if testing.Short() {
t.SkipNow()
}
file := "testdata/enwik9.zst"
Expand All @@ -873,8 +873,11 @@ func TestEncoder_EncodeAllEnwik9(t *testing.T) {
}

start := time.Now()
var e Encoder
e, err := NewWriter(nil)
dst := e.EncodeAll(in, nil)
if err != nil {
t.Fatal(err)
}
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
Expand All @@ -889,6 +892,52 @@ func TestEncoder_EncodeAllEnwik9(t *testing.T) {
t.Log("Encoded content matched")
}

// TestEncoder_EncoderStreamEnwik9 decodes testdata/enwik9.zst and then
// compresses the decoded bytes through the streaming Writer, logging the
// compressed size and the achieved throughput.
//
// The test is skipped in short mode and when the test data file is missing.
func TestEncoder_EncoderStreamEnwik9(t *testing.T) {
	if testing.Short() {
		t.SkipNow()
	}
	file := "testdata/enwik9.zst"
	f, err := os.Open(file)
	if err != nil {
		if os.IsNotExist(err) {
			t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" +
				"compress it with 'zstd -15 -T0 enwik9' and place it in " + file)
		}
		// Report any other open failure directly instead of continuing with
		// an invalid file handle and failing later with a misleading error.
		t.Fatal(err)
	}
	defer f.Close()

	dec, err := NewReader(f)
	if err != nil {
		t.Fatal(err)
	}
	defer dec.Close()
	in, err := ioutil.ReadAll(dec)
	if err != nil {
		t.Fatal(err)
	}

	start := time.Now()
	var dst bytes.Buffer
	e, err := NewWriter(&dst)
	if err != nil {
		t.Fatal(err)
	}
	if _, err = io.Copy(e, bytes.NewBuffer(in)); err != nil {
		t.Fatal(err)
	}
	// Close writes the remaining data, so its error determines whether the
	// stream in dst is complete.
	if err = e.Close(); err != nil {
		t.Fatal(err)
	}
	t.Log("Full Encoder len", len(in), "-> zstd len", dst.Len())
	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)

	// Round-trip verification is currently disabled; flip the condition to
	// decode the output again and compare it with the input.
	if false {
		decoded, err := dec.DecodeAll(dst.Bytes(), nil)
		if err != nil {
			t.Error(err, len(decoded))
		}
		if !bytes.Equal(decoded, in) {
			// Keep the mismatching output around for inspection.
			if werr := ioutil.WriteFile("testdata/"+t.Name()+"-enwik9.got", decoded, os.ModePerm); werr != nil {
				t.Log(werr)
			}
			t.Fatal("Decoded does not match")
		}
		t.Log("Encoded content matched")
	}
}

func BenchmarkEncoder_EncodeAllXML(b *testing.B) {
f, err := os.Open("testdata/xml.zst")
if err != nil {
Expand Down