From 15b48b683237a8ef4c85f262c877d8e1349d5902 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sun, 27 Feb 2022 01:31:07 -0800 Subject: [PATCH] zstd: Add stream encoding without goroutines (#505) * zstd: Add stream encoding without goroutines Does not use goroutines when encoder concurrency is 1. Fixes #264 Can probably be cleaned up a bit. * Reduce allocs for concurrent buffers when not used. --- zstd/decoder.go | 8 +++-- zstd/encoder.go | 66 +++++++++++++++++++++++++++++++++-------- zstd/encoder_options.go | 1 + zstd/encoder_test.go | 55 ++++++++++++++++++++++++++++++++-- 4 files changed, 113 insertions(+), 17 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 310edbec25..b6f29a5335 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -238,7 +238,9 @@ func (d *Decoder) Reset(r io.Reader) error { // drainOutput will drain the output until errEndOfStream is sent. func (d *Decoder) drainOutput() { if d.current.cancel != nil { - println("cancelling current") + if debugDecoder { + println("cancelling current") + } d.current.cancel() d.current.cancel = nil } @@ -816,7 +818,9 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch do.err = ErrFrameSizeMismatch hasErr = true } else { - println("fcs ok", block.Last, fcs, decodedFrame) + if debugDecoder { + println("fcs ok", block.Last, fcs, decodedFrame) + } } } output <- do diff --git a/zstd/encoder.go b/zstd/encoder.go index e6e315969b..dcc987a7cb 100644 --- a/zstd/encoder.go +++ b/zstd/encoder.go @@ -98,23 +98,25 @@ func (e *Encoder) Reset(w io.Writer) { if cap(s.filling) == 0 { s.filling = make([]byte, 0, e.o.blockSize) } - if cap(s.current) == 0 { - s.current = make([]byte, 0, e.o.blockSize) - } - if cap(s.previous) == 0 { - s.previous = make([]byte, 0, e.o.blockSize) + if e.o.concurrent > 1 { + if cap(s.current) == 0 { + s.current = make([]byte, 0, e.o.blockSize) + } + if cap(s.previous) == 0 { + s.previous = make([]byte, 0, e.o.blockSize) + } + s.current = s.current[:0] + 
s.previous = s.previous[:0] + if s.writing == nil { + s.writing = &blockEnc{lowMem: e.o.lowMem} + s.writing.init() + } + s.writing.initNewEncode() } if s.encoder == nil { s.encoder = e.o.encoder() } - if s.writing == nil { - s.writing = &blockEnc{lowMem: e.o.lowMem} - s.writing.init() - } - s.writing.initNewEncode() s.filling = s.filling[:0] - s.current = s.current[:0] - s.previous = s.previous[:0] s.encoder.Reset(e.o.dict, false) s.headerWritten = false s.eofWritten = false @@ -258,6 +260,46 @@ func (e *Encoder) nextBlock(final bool) error { return s.err } + // SYNC: + if e.o.concurrent == 1 { + src := s.filling + s.nInput += int64(len(s.filling)) + if debugEncoder { + println("Adding sync block,", len(src), "bytes, final:", final) + } + enc := s.encoder + blk := enc.Block() + blk.reset(nil) + enc.Encode(blk, src) + blk.last = final + if final { + s.eofWritten = true + } + + err := errIncompressible + // If we got the exact same number of literals as input, + // assume the literals cannot be compressed. + if len(src) != len(blk.literals) || len(src) != e.o.blockSize { + err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) + } + switch err { + case errIncompressible: + if debugEncoder { + println("Storing incompressible block as raw") + } + blk.encodeRaw(src) + // In fast mode, we do not transfer offsets, so we don't have to deal with changing the. + case nil: + default: + s.err = err + return err + } + _, s.err = s.w.Write(blk.output) + s.nWritten += int64(len(blk.output)) + s.filling = s.filling[:0] + return s.err + } + // Move blocks forward. 
s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current s.nInput += int64(len(s.current)) diff --git a/zstd/encoder_options.go b/zstd/encoder_options.go index 5f2e1d020e..44d8dbd199 100644 --- a/zstd/encoder_options.go +++ b/zstd/encoder_options.go @@ -76,6 +76,7 @@ func WithEncoderCRC(b bool) EOption { // WithEncoderConcurrency will set the concurrency, // meaning the maximum number of encoders to run concurrently. // The value supplied must be at least 1. +// For streams, setting a value of 1 will disable async compression. // By default this will be set to GOMAXPROCS. func WithEncoderConcurrency(n int) EOption { return func(o *encoderOptions) error { diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go index 12bf207a15..c913dec23e 100644 --- a/zstd/encoder_test.go +++ b/zstd/encoder_test.go @@ -506,7 +506,7 @@ func testEncoderRoundtrip(t *testing.T, file string, wantCRC []byte) { for _, opt := range getEncOpts(1) { t.Run(opt.name, func(t *testing.T) { opt := opt - t.Parallel() + //t.Parallel() f, err := os.Open(file) if err != nil { if os.IsNotExist(err) { @@ -851,7 +851,7 @@ func TestEncoder_EncodeAllEmpty(t *testing.T) { } func TestEncoder_EncodeAllEnwik9(t *testing.T) { - if false || testing.Short() { + if testing.Short() { t.SkipNow() } file := "testdata/enwik9.zst" @@ -873,8 +873,11 @@ func TestEncoder_EncodeAllEnwik9(t *testing.T) { } start := time.Now() - var e Encoder + e, err := NewWriter(nil) dst := e.EncodeAll(in, nil) + if err != nil { + t.Fatal(err) + } t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst)) mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second))) t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec) @@ -889,6 +892,52 @@ func TestEncoder_EncodeAllEnwik9(t *testing.T) { t.Log("Encoded content matched") } +func TestEncoder_EncoderStreamEnwik9(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + file := "testdata/enwik9.zst" + f, err := os.Open(file) + 
if err != nil { + if os.IsNotExist(err) { + t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" + + "compress it with 'zstd -15 -T0 enwik9' and place it in " + file) + } + } + dec, err := NewReader(f) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + in, err := ioutil.ReadAll(dec) + if err != nil { + t.Fatal(err) + } + + start := time.Now() + var dst bytes.Buffer + e, err := NewWriter(&dst) + _, err = io.Copy(e, bytes.NewBuffer(in)) + if err != nil { + t.Fatal(err) + } + e.Close() + t.Log("Full Encoder len", len(in), "-> zstd len", dst.Len()) + mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second))) + t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec) + if false { + decoded, err := dec.DecodeAll(dst.Bytes(), nil) + if err != nil { + t.Error(err, len(decoded)) + } + if !bytes.Equal(decoded, in) { + ioutil.WriteFile("testdata/"+t.Name()+"-enwik9.got", decoded, os.ModePerm) + t.Fatal("Decoded does not match") + } + t.Log("Encoded content matched") + } +} + func BenchmarkEncoder_EncodeAllXML(b *testing.B) { f, err := os.Open("testdata/xml.zst") if err != nil {