Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow HTTP clients and servers to take over the request stream #3437

Merged
merged 4 commits into from Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
117 changes: 56 additions & 61 deletions http3/body.go
Expand Up @@ -2,13 +2,21 @@ package http3

import (
"context"
"fmt"
"io"
"net"

"github.com/lucas-clemente/quic-go"
)

// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by:
// * for the server: the http.Request.Body
// * for the client: the http.Response.Body
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
HTTPStream() Stream
}

type StreamCreator interface {
OpenStream() (quic.Stream, error)
OpenStreamSync(context.Context) (quic.Stream, error)
Expand All @@ -30,92 +38,75 @@ type Hijacker interface {
type body struct {
str quic.Stream

// only set for the http.Response
// The channel is closed when the user is done with this response:
// either when Read() errors, or when Close() is called.
reqDone chan<- struct{}
reqDoneClosed bool
wasHijacked bool // set when HTTPStream is called
}

var (
_ io.ReadCloser = &body{}
_ HTTPStreamer = &body{}
)

onFrameError func()
func newRequestBody(str Stream) *body {
return &body{str: str}
}

func (r *body) HTTPStream() Stream {
r.wasHijacked = true
return r.str
}

func (r *body) wasStreamHijacked() bool {
return r.wasHijacked
}

bytesRemainingInFrame uint64
func (r *body) Read(b []byte) (int, error) {
return r.str.Read(b)
}

var _ io.ReadCloser = &body{}
func (r *body) Close() error {
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
return nil
}

type hijackableBody struct {
body
conn quic.Connection // only needed to implement Hijacker
}

var _ Hijacker = &hijackableBody{}

func newRequestBody(str quic.Stream, onFrameError func()) *body {
return &body{
str: str,
onFrameError: onFrameError,
}
// only set for the http.Response
// The channel is closed when the user is done with this response:
// either when Read() errors, or when Close() is called.
reqDone chan<- struct{}
reqDoneClosed bool
}

func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody {
var (
_ Hijacker = &hijackableBody{}
_ HTTPStreamer = &hijackableBody{}
)

func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
return &hijackableBody{
body: body{
str: str,
onFrameError: onFrameError,
reqDone: done,
str: str,
},
conn: conn,
reqDone: done,
conn: conn,
}
}

func (r *hijackableBody) StreamCreator() StreamCreator {
return r.conn
}

func (r *body) Read(b []byte) (int, error) {
n, err := r.readImpl(b)
func (r *hijackableBody) Read(b []byte) (int, error) {
n, err := r.str.Read(b)
if err != nil {
r.requestDone()
}
return n, err
}

func (r *body) readImpl(b []byte) (int, error) {
if r.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(r.str, nil)
if err != nil {
return 0, err
}
switch f := frame.(type) {
case *headersFrame:
// skip HEADERS frames
continue
case *dataFrame:
r.bytesRemainingInFrame = f.Length
break parseLoop
default:
r.onFrameError()
// parseNextFrame skips over unknown frame types
// Therefore, this condition is only entered when we parsed another known frame type.
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
}
}
}

var n int
var err error
if r.bytesRemainingInFrame < uint64(len(b)) {
n, err = r.str.Read(b[:r.bytesRemainingInFrame])
} else {
n, err = r.str.Read(b)
}
r.bytesRemainingInFrame -= uint64(n)
return n, err
}

func (r *body) requestDone() {
func (r *hijackableBody) requestDone() {
if r.reqDoneClosed || r.reqDone == nil {
return
}
Expand All @@ -127,9 +118,13 @@ func (r *body) StreamID() quic.StreamID {
return r.str.StreamID()
}

func (r *body) Close() error {
func (r *hijackableBody) Close() error {
r.requestDone()
// If the EOF was read, CancelRead() is a no-op.
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
return nil
}

func (r *hijackableBody) HTTPStream() Stream {
return r.str
}
207 changes: 36 additions & 171 deletions http3/body_test.go
@@ -1,189 +1,54 @@
package http3

import (
"bytes"
"fmt"
"io"
"errors"

"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"

"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

type bodyType uint8

const (
bodyTypeRequest bodyType = iota
bodyTypeResponse
)

func (t bodyType) String() string {
if t == bodyTypeRequest {
return "request"
}
return "response"
}

var _ = Describe("Body", func() {
var (
rb io.ReadCloser
str *mockquic.MockStream
buf *bytes.Buffer
reqDone chan struct{}
errorCbCalled bool
)
var _ = Describe("Response Body", func() {
var reqDone chan struct{}

errorCb := func() { errorCbCalled = true }
BeforeEach(func() { reqDone = make(chan struct{}) })

getDataFrame := func(data []byte) []byte {
b := &bytes.Buffer{}
(&dataFrame{Length: uint64(len(data))}).Write(b)
b.Write(data)
return b.Bytes()
}

BeforeEach(func() {
buf = &bytes.Buffer{}
errorCbCalled = false
It("closes the reqDone channel when Read errors", func() {
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error"))
rb := newResponseBody(str, nil, reqDone)
_, err := rb.Read([]byte{0})
Expect(err).To(MatchError("test error"))
Expect(reqDone).To(BeClosed())
})

for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
bodyType := bt

Context(fmt.Sprintf("using a %s body", bodyType), func() {
BeforeEach(func() {
str = mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Write(b)
}).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Read(b)
}).AnyTimes()

switch bodyType {
case bodyTypeRequest:
rb = newRequestBody(str, errorCb)
case bodyTypeResponse:
reqDone = make(chan struct{})
rb = newResponseBody(str, nil, reqDone, errorCb)
}
})

It("reads DATA frames in a single run", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 6)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})

It("reads DATA frames in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 3)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("foo")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("bar")))
})

It("reads DATA frames into too large buffers", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 10)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b[:n]).To(Equal([]byte("foobar")))
})

It("reads DATA frames into too large buffers, in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 4)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte("foob")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte("ar")))
})

It("reads multiple DATA frames", func() {
buf.Write(getDataFrame([]byte("foo")))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("foo")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("bar")))
})

It("skips HEADERS frames", func() {
buf.Write(getDataFrame([]byte("foo")))
(&headersFrame{Length: 10}).Write(buf)
buf.Write(make([]byte, 10))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := io.ReadFull(rb, b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})

It("errors when it can't parse the frame", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})

It("errors on unexpected frames, and calls the error callback", func() {
(&settingsFrame{}).Write(buf)
_, err := rb.Read([]byte{0})
Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
Expect(errorCbCalled).To(BeTrue())
})

if bodyType == bodyTypeResponse {
It("closes the reqDone channel when Read errors", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(reqDone).To(BeClosed())
})

It("allows multiple calls to Read, when Read errors", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(reqDone).To(BeClosed())
_, err = rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})
It("allows multiple calls to Read, when Read errors", func() {
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2)
rb := newResponseBody(str, nil, reqDone)
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(reqDone).To(BeClosed())
_, err = rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})

It("closes responses", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
Expect(rb.Close()).To(Succeed())
})
It("closes responses", func() {
str := mockquic.NewMockStream(mockCtrl)
rb := newResponseBody(str, nil, reqDone)
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
Expect(rb.Close()).To(Succeed())
})

It("allows multiple calls to Close", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
Expect(rb.Close()).To(Succeed())
Expect(reqDone).To(BeClosed())
Expect(rb.Close()).To(Succeed())
})
}
})
}
It("allows multiple calls to Close", func() {
str := mockquic.NewMockStream(mockCtrl)
rb := newResponseBody(str, nil, reqDone)
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
Expect(rb.Close()).To(Succeed())
Expect(reqDone).To(BeClosed())
Expect(rb.Close()).To(Succeed())
})
})