Skip to content

Commit

Permalink
transport/http: Add fix for HTTP request stream not seekable
Browse files Browse the repository at this point in the history
Adds a fix for the HTTP Request setStream to correctly handle empty
and empty unseekable teams. This fixes a bug where the HTTP request
would attempt to rewind an empty unseekable stream due to assumptions
that the stream would be nil if content length is 0.

Related to #356
  • Loading branch information
jasdel committed Mar 8, 2022
1 parent 90a0225 commit 26b41c9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 15 deletions.
5 changes: 5 additions & 0 deletions transport/http/checksum_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ func (m *contentMD5Checksum) HandleBuild(
stream := req.GetStream()
// compute checksum if payload is explicit
if stream != nil {
if !req.IsStreamSeekable() {
return out, metadata, fmt.Errorf(
"unseekable stream is not supported for computing md5 checksum")
}

v, err := computeMD5Checksum(stream)
if err != nil {
return out, metadata, fmt.Errorf("error computing md5 checksum, %w", err)
Expand Down
3 changes: 2 additions & 1 deletion transport/http/checksum_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestChecksumMiddleware(t *testing.T) {
"nil body": {},
"unseekable payload": {
payload: bytes.NewBuffer([]byte(`xyz`)),
expectError: "error rewinding request stream",
expectError: "unseekable stream is not supported",
},
}

Expand All @@ -61,6 +61,7 @@ func TestChecksumMiddleware(t *testing.T) {
if e, a := c.expectError, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %q, got %v", e, a)
}
return
} else if err != nil {
t.Fatalf("expect no error, got %v", err)
}
Expand Down
40 changes: 29 additions & 11 deletions transport/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,23 @@ func (r *Request) Clone() *Request {
// to the request and ok set. If the length cannot be determined, an error will
// be returned.
func (r *Request) StreamLength() (size int64, ok bool, err error) {
if r.stream == nil {
return streamLength(r.stream, r.isStreamSeekable, r.streamStartPos)
}

func streamLength(stream io.Reader, seekable bool, startPos int64) (size int64, ok bool, err error) {
if stream == nil {
return 0, true, nil
}

if l, ok := r.stream.(interface{ Len() int }); ok {
if l, ok := stream.(interface{ Len() int }); ok {
return int64(l.Len()), true, nil
}

if !r.isStreamSeekable {
if !seekable {
return 0, false, nil
}

s := r.stream.(io.Seeker)
s := stream.(io.Seeker)
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, false, err
Expand All @@ -69,12 +73,12 @@ func (r *Request) StreamLength() (size int64, ok bool, err error) {
// file, and wants to skip the first N bytes uploading the rest. The
// application would move the file's offset N bytes, then hand it off to
// the SDK to send the remaining. The SDK should respect that initial offset.
_, err = s.Seek(r.streamStartPos, io.SeekStart)
_, err = s.Seek(startPos, io.SeekStart)
if err != nil {
return 0, false, err
}

return endOffset - r.streamStartPos, true, nil
return endOffset - startPos, true, nil
}

// RewindStream will rewind the io.Reader to the relative start position if it
Expand Down Expand Up @@ -103,27 +107,41 @@ func (r *Request) IsStreamSeekable() bool {
return r.isStreamSeekable
}

// SetStream returns a clone of the request with the stream set to the provided reader.
// May return an error if the provided reader is seekable but returns an error.
// SetStream returns a clone of the request with the stream set to the provided
// reader. May return an error if the provided reader is seekable but returns
// an error.
func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) {
rc = r.Clone()

if reader == http.NoBody {
reader = nil
}

var isStreamSeekable bool
var streamStartPos int64
switch v := reader.(type) {
case io.Seeker:
n, err := v.Seek(0, io.SeekCurrent)
if err != nil {
return r, err
}
rc.isStreamSeekable = true
rc.streamStartPos = n
isStreamSeekable = true
streamStartPos = n
default:
rc.isStreamSeekable = false
// If the stream length can be determined, and is determined to be empty,
// use a nil stream to prevent confusion between empty vs not-empty
// streams.
length, ok, err := streamLength(reader, false, 0)
if err != nil {
return nil, err
} else if ok && length == 0 {
reader = nil
}
}

rc.stream = reader
rc.isStreamSeekable = isStreamSeekable
rc.streamStartPos = streamStartPos

return rc, err
}
Expand Down
10 changes: 7 additions & 3 deletions transport/http/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ func TestRequestRewindable(t *testing.T) {
"rewindable": {
Stream: bytes.NewReader([]byte{}),
},
"not rewindable": {
Stream: bytes.NewBuffer([]byte{}),
"empty not rewindable": {
Stream: bytes.NewBuffer([]byte{}),
// ExpectErr: "stream is not seekable",
},
"not empty not rewindable": {
Stream: bytes.NewBuffer([]byte("abc123")),
ExpectErr: "stream is not seekable",
},
"nil stream": {},
Expand Down Expand Up @@ -121,7 +125,7 @@ func TestRequestSetStream(t *testing.T) {
},
"empty unseekable stream": {
reader: bytes.NewBuffer([]byte{}),
expectNilStream: false,
expectNilStream: true,
expectNilBody: true,
},
"empty seekable stream": {
Expand Down

0 comments on commit 26b41c9

Please sign in to comment.