Skip to content

Commit

Permalink
Merge pull request #197 from tie/extraio-tail-reader-ring-buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
tie committed May 28, 2022
2 parents e1c5393 + a33ea59 commit d6c836a
Show file tree
Hide file tree
Showing 13 changed files with 373 additions and 127 deletions.
3 changes: 2 additions & 1 deletion extraio/byte.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package extraio
// ByteReader is an io.Reader that reads the same byte indefinitely.
type ByteReader byte

// Read implements io.Reader interface.
// Read implements the io.Reader interface. It fills p with the b byte and
// returns len(p).
func (b ByteReader) Read(p []byte) (n int, err error) {
for i := range p {
p[i] = byte(b)
Expand Down
30 changes: 20 additions & 10 deletions extraio/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,39 @@ package extraio

import (
"io"
"math"
)

// CountReader counts bytes read from the underlying io.Reader.
type CountReader struct {
r io.Reader // underlying reader
n uint64
reader io.Reader
count uint64 // mutable
overflow bool // mutable
}

// NewCountReader returns a new reader that counts bytes read from r.
func NewCountReader(r io.Reader) *CountReader {
return &CountReader{r: r}
return &CountReader{reader: r}
}

// Count returns the count of bytes read.
func (r *CountReader) Count() uint64 {
return r.n
// Count returns the count of bytes read. It returns false if the count of read
// bytes cannot be represented as 64 bit unsigned integer. In practice, that
// would require counting thousands of petabytes to reach this limitation.
func (r *CountReader) Count() (uint64, bool) {
return r.count, !r.overflow
}

// Read implements io.Reader interface. It reads from the underlying io.Reader.
// Read implements the io.Reader interface. It reads from the underlying
// io.Reader and increments read bytes counter.
func (r *CountReader) Read(p []byte) (int, error) {
n, err := r.r.Read(p)
if n > 0 {
r.n += uint64(n)
n, err := r.reader.Read(p)
if n > 0 && !r.overflow {
nn := uint64(n)
r.count += nn
if r.count < nn {
r.overflow = true
r.count = math.MaxUint64
}
}
return n, err
}
17 changes: 10 additions & 7 deletions extraio/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@ import (
"io"
)

// DiscardReader is an io.Reader that discard all reads from the underlying reader.
// DiscardReader is an io.Reader that discard all read bytes from the underlying
// reader.
//
// Note that its Read method returns zero byte count. Some io.Reader client
// implementations return io.ErrNoProgress error when many calls to Read have
// failed to return any data or error.
type DiscardReader struct {
R io.Reader // underlying reader
reader io.Reader
}

// NewDiscardReader returns a new reader that discard all reads from r.
func NewDiscardReader(r io.Reader) *DiscardReader {
return &DiscardReader{r}
}

// Read implements io.Reader interface. It reads from the underlying io.Reader.
// Read implements the io.Reader interface. It reads from the underlying
// io.Reader but always returns zero byte count.
func (d *DiscardReader) Read(p []byte) (int, error) {
_, err := io.Copy(io.Discard, d.R)
if err == nil {
err = io.EOF
}
_, err := d.reader.Read(p)
return 0, err
}
4 changes: 2 additions & 2 deletions extraio/func.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package extraio

// ReaderFunc is an adapter to allow the use of ordinary function as io.Reader.
// If f is a function with appropriate signature, readerFunc(f) is an io.Reader
// If f is a function with appropriate signature, ReaderFunc(f) is an io.Reader
// that calls f.
type ReaderFunc func(p []byte) (n int, err error)

// Read implements io.Reader interface. It calls f(p).
// Read implements the io.Reader interface. It calls f(p).
func (f ReaderFunc) Read(p []byte) (n int, err error) {
return f(p)
}
2 changes: 1 addition & 1 deletion extraio/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func FuzzUnpadReader(f *testing.F) {
}

func FuzzTailReader(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte, n uint64) {
f.Fuzz(func(t *testing.T, data []byte, n uint) {
runTestTailReader(t, data, n)
})
}
41 changes: 25 additions & 16 deletions extraio/hardlimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ import (
// exceeds read limit.
var ErrExceededReadLimit = errors.New("extraio: exceeded read limit")

// HardLimitedReader reads from R but limits the amount of data returned to just
// N bytes.
// HardLimitedReader reads at most n bytes from the underlying reader and
// returns ErrExceededReadLimit if io.EOF is not reached once the limit is
// exceeded.
type HardLimitedReader struct {
R io.Reader // underlying reader
N uint64 // read limit

readCount uint64
reader io.Reader
limit uint64
readCount uint64 // mutable
}

// HardLimitReader returns a Reader that reads from r but stops with an error
// after n bytes.
func HardLimitReader(r io.Reader, n uint64) *HardLimitedReader {
return &HardLimitedReader{
R: r,
N: n,
reader: r,
limit: n,
}
}

Expand All @@ -32,18 +32,27 @@ func (r *HardLimitedReader) Reset() {
r.readCount = 0
}

// Read implements io.Reader interface. It reads from the underlying io.Reader.
// Read implements the io.Reader interface. If the limit has been reached, it
// returns ErrExceededReadLimit error. Otherwise it reads from the underlying
// io.Reader.
func (r *HardLimitedReader) Read(p []byte) (int, error) {
n, err := r.R.Read(p)
if n <= 0 {
return n, err
if r.readCount == r.limit {
return 0, ErrExceededReadLimit
}

nn := uint64(n)
if r.N-r.readCount < nn {
return n, ErrExceededReadLimit
// Do not read more than the remaining limit. This also guarantees that
// n will not overflow or exceed limit on addition to readCount.
limit := r.limit - r.readCount
if uint64(len(p)) > limit {
p = p[:limit]
}
r.readCount += nn

n, err := r.reader.Read(p)
if n > 0 {
r.readCount += uint64(n)
}
if err == io.EOF {
return n, err
}
return n, err
}
11 changes: 9 additions & 2 deletions extraio/hardlimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,17 @@ func TestHardLimitReader(t *testing.T) {
}
for _, tc := range testCases {
data, err := io.ReadAll(HardLimitReader(bytes.NewReader(tc.Data), tc.Limit))
if uint64(len(tc.Data)) <= tc.Limit {
switch {
// Special case: read returns (n, io.EOF) and together with
// previous reads we reach read limit but stop reading due to
// the explicit EOF.
case len(data) == len(tc.Data) && err == nil:
fallthrough
// Must succeed if data does not exceed the limit.
case uint64(len(tc.Data)) < tc.Limit:
assert.NilError(t, err)
assert.DeepEqual(t, tc.Data, data)
} else {
default:
assert.ErrorIs(t, err, ErrExceededReadLimit)
}
}
Expand Down
51 changes: 38 additions & 13 deletions extraio/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,56 @@ import (
// hash state does not change after the first Read call.
type HashReader struct {
h hash.Hash
p []byte

buf []byte // mutable
off int // mutable
}

// NewStrippedHashReader returns a new io.Reader that reads at most n byte of
// the given hash.
// NewStrippedHashReader returns a new io.Reader that reads at most n bytes of
// the given hash function.
func NewStrippedHashReader(h hash.Hash, n int64) io.Reader {
return io.LimitReader(NewHashReader(h), n)
}

// NewHashReader returns a new reader that reads from the given hash.
// NewHashReader returns a new reader that reads from the given hash function.
func NewHashReader(h hash.Hash) *HashReader {
return &HashReader{h: h}
buf := make([]byte, 0, h.Size())
return &HashReader{
h: h,
buf: buf,
}
}

// Hash returns the underlying hash function.
func (r *HashReader) Hash() hash.Hash {
return r.h
}

// Reset resets the reader’s state. It does not reset the state of the
// underlying hash function.
func (r *HashReader) Reset() {
r.buf = r.buf[:0]
r.off = 0
}

// Read implements io.Reader interface. It computes the hash on the first call
// and advances through the hash buffer on subsequent calls to Read.
func (h *HashReader) Read(p []byte) (int, error) {
if h.p == nil {
h.p = h.h.Sum(nil)
// Read implements the io.Reader interface. It computes the hash on the first
// call and advances through the hash buffer on subsequent calls to Read.
func (r *HashReader) Read(p []byte) (int, error) {
if len(r.buf) == 0 {
r.buf = r.h.Sum(r.buf)
}

size := cap(r.buf)
if size == r.off {
return 0, io.EOF
}

n := copy(p, h.p)
h.p = h.p[n:]
if len(h.p) == 0 {
n := copy(p, r.buf[r.off:])
r.off += n

if size == r.off {
return n, io.EOF
}

return n, nil
}
43 changes: 43 additions & 0 deletions extraio/hash_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package extraio

import (
"crypto/sha256"
"io"
"testing"
"testing/iotest"

"gotest.tools/v3/assert"
)

func TestHashReader(t *testing.T) {
data := []byte("test")
hash := sha256ToSlice(sha256.Sum256(data))

h := NewHashReader(sha256.New())

_, err := h.Hash().Write(data)
assert.NilError(t, err)

err = iotest.TestReader(h, hash)
assert.NilError(t, err)

h.Reset()

err = iotest.TestReader(h, hash)
assert.NilError(t, err)

h.Hash().Reset()

_, err = h.Read(nil)
assert.ErrorIs(t, err, io.EOF)

h.Reset()

hash = sha256ToSlice(sha256.Sum256(nil))
err = iotest.TestReader(h, hash)
assert.NilError(t, err)
}

func sha256ToSlice(b [sha256.Size]byte) []byte {
return b[:]
}
34 changes: 24 additions & 10 deletions extraio/pad.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package extraio

import (
"errors"
"io"
)

Expand All @@ -11,9 +10,9 @@ type PadReader struct {
reader io.Reader
blockSize uint8

readCount uint64
padding uint8
fillByte byte
incomplete int // mutable
padding uint8 // mutable
fillByte byte // mutable
}

// NewPadReader returns a new reader that pads r with the given block size.
Expand All @@ -26,7 +25,15 @@ func NewPadReader(r io.Reader, blockSize uint8) *PadReader {
}
}

// Read implements io.Reader interface. It reads from the underlying io.Reader.
// Reset resets the reader’s state.
func (r *PadReader) Reset() {
r.incomplete = 0
r.padding = 0
r.fillByte = 0
}

// Read implements the io.Reader interface. It reads from the underlying
// io.Reader until EOF and then writes padding into the read buffer.
func (r *PadReader) Read(p []byte) (int, error) {
if r.fillByte != 0 {
return r.pad(p)
Expand All @@ -37,13 +44,15 @@ func (r *PadReader) Read(p []byte) (int, error) {
return n, err
}
if n > 0 {
r.readCount += uint64(n)
bs := int(r.blockSize)
r.incomplete += n % bs
r.incomplete %= bs
}
if !errors.Is(err, io.EOF) {
if err != io.EOF {
return n, err
}

r.padding = r.blockSize - uint8(r.readCount%uint64(r.blockSize))
r.padding = r.blockSize - uint8(r.incomplete)
r.fillByte = byte(r.padding)

nn, err := r.pad(p[n:])
Expand All @@ -59,17 +68,22 @@ func (r *PadReader) pad(p []byte) (int, error) {
var err error

n := len(p)
if int(r.padding) <= n {
n = int(r.padding)
if k := int(r.padding); k <= n {
n = k
r.padding = 0
err = io.EOF
} else {
// Note that !(k <= n) means that n < k <= math.MaxUint8.
r.padding -= uint8(n)
}

for i := 0; i < n; i++ {
p[i] = r.fillByte
}

if r.padding == 0 {
r.Reset()
}

return n, err
}

0 comments on commit d6c836a

Please sign in to comment.