Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2: Fix absolute forward seeks #633

Merged
merged 2 commits into from Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions s2/decode.go
Expand Up @@ -791,6 +791,7 @@ func (r *Reader) Skip(n int64) error {
} else {
// Skip block completely
n -= int64(dLen)
r.blockStart += int64(dLen)
dLen = 0
}
r.i, r.j = 0, dLen
Expand Down Expand Up @@ -921,6 +922,15 @@ func (r *Reader) ReadSeeker(random bool, index []byte) (*ReadSeeker, error) {
err = r.index.LoadStream(rs)
if err != nil {
if err == ErrUnsupported {
// If we don't require random seeking, reset input and return.
if !random {
_, err = rs.Seek(pos, io.SeekStart)
if err != nil {
return nil, ErrCantSeek{Reason: "resetting stream returned: " + err.Error()}
}
r.index = nil
return &ReadSeeker{Reader: r}, nil
}
return nil, ErrCantSeek{Reason: "input stream does not contain an index"}
}
return nil, ErrCantSeek{Reason: "reading index returned: " + err.Error()}
Expand Down
124 changes: 124 additions & 0 deletions s2/index_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"math/rand"
"sync"
"testing"

"github.com/klauspost/compress/s2"
)
Expand Down Expand Up @@ -99,3 +100,126 @@ func ExampleIndex_Load() {
//Successfully skipped forward to 4444440
//Successfully skipped forward to 4999995
}

// TestSeeking checks absolute forward seeks (io.SeekStart) on an s2 stream.
//
// It compresses nElems fixed-size text records using small blocks, then for
// every stride in testSizes seeks to each stride-th record and verifies the
// record read back at that position. Each stride is exercised with four
// reader configurations: with and without an underlying io.Seeker, and with
// and without the index returned by CloseIndex.
func TestSeeking(t *testing.T) {
	// "Item " (5 bytes) + 19 digits = 24 bytes per item, plus '\n' = 25 bytes per record.
	const recordSize = 25
	const recordFormat = "Item %019d\n"

	var compressed bytes.Buffer

	// Use small blocks so there are plenty of them.
	enc := s2.NewWriter(&compressed, s2.WriterBlockSize(16<<10))
	nElems := 1_000_000
	testSizes := []int{100, 1_000, 10_000, 20_000, 100_000, 200_000, 400_000}
	if testing.Short() {
		nElems = 100_000
		testSizes = []int{100, 1_000, 10_000, 20_000}
	}
	// A stride of nElems-1 reads only the first and the last record.
	testSizes = append(testSizes, nElems-1)
	for i := 0; i < nElems; i++ {
		if _, err := fmt.Fprintf(enc, recordFormat, i); err != nil {
			t.Fatalf("Failed to write record %d: %v", i, err)
		}
	}

	index, err := enc.CloseIndex()
	if err != nil {
		t.Fatal(err)
	}

	// hideSeek wraps the input so that only Read/Close are exposed;
	// keepSeek passes the *bytes.Reader through so Seek stays available.
	hideSeek := func(r *bytes.Reader) io.Reader { return io.NopCloser(r) }
	keepSeek := func(r *bytes.Reader) io.Reader { return io.ReadSeeker(r) }

	variants := []struct {
		name  string
		wrap  func(*bytes.Reader) io.Reader
		index []byte
	}{
		{name: "noSeekSkip", wrap: hideSeek, index: nil},
		{name: "seekSkip", wrap: keepSeek, index: nil},
		{name: "noSeekIndexSkip", wrap: hideSeek, index: index},
		{name: "seekIndexSkip", wrap: keepSeek, index: index},
	}

	// checkRecords seeks to every skip-th record from the start of the
	// stream and verifies the record found there.
	checkRecords := func(t *testing.T, src io.Reader, idx []byte, skip int) {
		dec := s2.NewReader(src)
		// random=false: only forward seeking is required.
		seeker, err := dec.ReadSeeker(false, idx)
		if err != nil {
			t.Fatal(err)
		}
		buf := make([]byte, recordSize)
		for rec := 0; rec < nElems; rec += skip {
			offset := int64(rec) * recordSize
			if _, err := seeker.Seek(offset, io.SeekStart); err != nil {
				t.Fatalf("Failed to seek: %v", err)
			}
			if _, err := io.ReadFull(dec, buf); err != nil {
				t.Fatalf("Failed to read: %v", err)
			}
			expected := fmt.Sprintf(recordFormat, rec)
			if string(buf) != expected {
				t.Fatalf("Expected %q, got %q", expected, buf)
			}
		}
	}

	for _, skip := range testSizes {
		for _, v := range variants {
			skip, v := skip, v // Own copies for the closure on pre-1.22 toolchains.
			t.Run(fmt.Sprintf("%s=%d", v.name, skip), func(t *testing.T) {
				checkRecords(t, v.wrap(bytes.NewReader(compressed.Bytes())), v.index, skip)
			})
		}
	}
}