diff --git a/s2/decode.go b/s2/decode.go index 2aba9e272d..b5fa4d3f00 100644 --- a/s2/decode.go +++ b/s2/decode.go @@ -791,6 +791,7 @@ func (r *Reader) Skip(n int64) error { } else { // Skip block completely n -= int64(dLen) + r.blockStart += int64(dLen) dLen = 0 } r.i, r.j = 0, dLen @@ -921,6 +922,15 @@ func (r *Reader) ReadSeeker(random bool, index []byte) (*ReadSeeker, error) { err = r.index.LoadStream(rs) if err != nil { if err == ErrUnsupported { + // If we don't require random seeking, reset input and return. + if !random { + _, err = rs.Seek(pos, io.SeekStart) + if err != nil { + return nil, ErrCantSeek{Reason: "resetting stream returned: " + err.Error()} + } + r.index = nil + return &ReadSeeker{Reader: r}, nil + } return nil, ErrCantSeek{Reason: "input stream does not contain an index"} } return nil, ErrCantSeek{Reason: "reading index returned: " + err.Error()} diff --git a/s2/index_test.go b/s2/index_test.go index 38975df21b..9a43c2be74 100644 --- a/s2/index_test.go +++ b/s2/index_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "math/rand" "sync" + "testing" "github.com/klauspost/compress/s2" ) @@ -99,3 +100,126 @@ func ExampleIndex_Load() { //Successfully skipped forward to 4444440 //Successfully skipped forward to 4999995 } + +func TestSeeking(t *testing.T) { + compressed := bytes.Buffer{} + + // Use small blocks so there are plenty of them. + enc := s2.NewWriter(&compressed, s2.WriterBlockSize(16<<10)) + var nElems = 1_000_000 + var testSizes = []int{100, 1_000, 10_000, 20_000, 100_000, 200_000, 400_000} + if testing.Short() { + nElems = 100_000 + testSizes = []int{100, 1_000, 10_000, 20_000} + } + testSizes = append(testSizes, nElems-1) + //24 bytes per item plus \n = 25 bytes per record + for i := 0; i < nElems; i++ { + fmt.Fprintf(enc, "Item %019d\n", i) + } + + index, err := enc.CloseIndex() + if err != nil { + t.Fatal(err) + } + + for _, skip := range testSizes { + t.Run(fmt.Sprintf("noSeekSkip=%d", skip), func(t *testing.T) { + dec := s2.NewReader(io.NopCloser(bytes.NewReader(compressed.Bytes()))) + seeker, err := dec.ReadSeeker(false, nil) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 25) + for rec := 0; rec < nElems; rec += skip { + offset := int64(rec * 25) + //t.Logf("Reading record %d", rec) + _, err := seeker.Seek(offset, io.SeekStart) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + _, err = io.ReadFull(dec, buf) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + expected := fmt.Sprintf("Item %019d\n", rec) + if string(buf) != expected { + t.Fatalf("Expected %q, got %q", expected, buf) + } + } + }) + t.Run(fmt.Sprintf("seekSkip=%d", skip), func(t *testing.T) { + dec := s2.NewReader(io.ReadSeeker(bytes.NewReader(compressed.Bytes()))) + seeker, err := dec.ReadSeeker(false, nil) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 25) + for rec := 0; rec < nElems; rec += skip { + offset := int64(rec * 25) + //t.Logf("Reading record %d", rec) + _, err := seeker.Seek(offset, io.SeekStart) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + _, err = io.ReadFull(dec, buf) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + expected := fmt.Sprintf("Item %019d\n", rec) + if string(buf) != expected { + t.Fatalf("Expected %q, got %q", expected, buf) + } + } + }) + t.Run(fmt.Sprintf("noSeekIndexSkip=%d", skip), func(t *testing.T) { + dec := s2.NewReader(io.NopCloser(bytes.NewReader(compressed.Bytes()))) + seeker, err := dec.ReadSeeker(false, index) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 25) + for rec := 0; rec < nElems; rec += skip { + offset := int64(rec * 25) + //t.Logf("Reading record %d", rec) + _, err := seeker.Seek(offset, io.SeekStart) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + _, err = io.ReadFull(dec, buf) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + expected := fmt.Sprintf("Item %019d\n", rec) + if string(buf) != expected { + t.Fatalf("Expected %q, got %q", expected, buf) + } + } + }) + t.Run(fmt.Sprintf("seekIndexSkip=%d", skip), func(t *testing.T) { + dec := s2.NewReader(io.ReadSeeker(bytes.NewReader(compressed.Bytes()))) + + seeker, err := dec.ReadSeeker(false, index) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 25) + for rec := 0; rec < nElems; rec += skip { + offset := int64(rec * 25) + //t.Logf("Reading record %d", rec) + _, err := seeker.Seek(offset, io.SeekStart) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + _, err = io.ReadFull(dec, buf) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + expected := fmt.Sprintf("Item %019d\n", rec) + if string(buf) != expected { + t.Fatalf("Expected %q, got %q", expected, buf) + } + } + }) + } +}