Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: fix race between client-side stream cancellation and compressed server data arriving #3054

Merged
merged 3 commits into from Oct 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/transport/http2_client.go
Expand Up @@ -352,6 +352,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
ct: t,
done: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
Expand Down Expand Up @@ -1191,6 +1192,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {

// If headerChan hasn't been closed yet
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.headerValid = true
if !endStream {
// HEADERS frame block carries a Response-Headers.
isHeader = true
Expand Down
53 changes: 20 additions & 33 deletions internal/transport/transport.go
Expand Up @@ -233,6 +233,7 @@ const (
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ct *http2Client // nil for server side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
Expand All @@ -251,6 +252,10 @@ type Stream struct {

headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value).
headerValid bool

// hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex
Expand Down Expand Up @@ -303,34 +308,26 @@ func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}

func (s *Stream) waitOnHeader() error {
func (s *Stream) waitOnHeader() {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
return
}
select {
case <-s.ctx.Done():
// We prefer success over failure when reading messages because we delay
// context error in stream.Read(). To keep behavior consistent, we also
// prefer success here.
select {
case <-s.headerChan:
return nil
default:
}
return ContextErr(s.ctx.Err())
// Close the stream to prevent headers/trailers from changing after
// this function returns.
err := ContextErr(s.ctx.Err())
s.ct.closeStream(s, err, false, 0, status.Convert(err), nil, false)
case <-s.headerChan:
return nil
}
}

// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
s.waitOnHeader()
return s.recvCompress
}

Expand Down Expand Up @@ -358,29 +355,19 @@ func (s *Stream) Header() (metadata.MD, error) {
// header after t.WriteHeader is called.
return s.header.Copy(), nil
}
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
s.waitOnHeader()
if !s.headerValid {
return nil, s.status.Err()
}
return nil, err
return s.header.Copy(), nil
}

// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true, nil. If a context error happens
// first, returns it as a status error. Client-side only.
func (s *Stream) TrailersOnly() (bool, error) {
err := s.waitOnHeader()
if err != nil {
return false, err
}
return s.noHeaders, nil
// before headers are received, returns true, nil. Client-side only.
func (s *Stream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}

// Trailer returns the cached trailer metedata. Note that if it is not called
Expand Down
2 changes: 1 addition & 1 deletion stream.go
Expand Up @@ -488,7 +488,7 @@ func (cs *clientStream) shouldRetry(err error) error {
pushback := 0
hasPushback := false
if cs.attempt.s != nil {
if to, toErr := cs.attempt.s.TrailersOnly(); toErr != nil || !to {
if !cs.attempt.s.TrailersOnly() {
return err
}

Expand Down
3 changes: 1 addition & 2 deletions test/context_canceled_test.go
Expand Up @@ -139,13 +139,12 @@ func (s) TestCancelWhileRecvingWithCompression(t *testing.T) {

for i := 0; i < 10; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name))
if err != nil {
t.Fatalf("failed to start bidi streaming RPC: %v", err)
}
// Cancel the stream while receiving to trigger the internal error.
time.AfterFunc(time.Millisecond*10, cancel)
time.AfterFunc(time.Millisecond, cancel)
for {
_, err := s.Recv()
if err != nil {
Expand Down