From d8c668b79abfb4db21ccdfb32e6f101e2920af0f Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Fri, 27 Sep 2019 15:25:36 -0700 Subject: [PATCH 1/3] client: fix race between client-side stream cancellation and compressed server data arriving `transport/Stream.RecvCompress` returns what the header contains, if present, or empty string if a context error occurs. However, it "prefers" the header data even if there is a context error, to prevent a related race. What happens here is: 1. RPC starts. 2. Client cancels RPC. 3. `RecvCompress` tells `ClientStream.Recv` that compression used is "" because of the context error. `as.decomp` is left nil, because there is no compressor to look up in the registry. 4. Server's header and first message hit client. 5. Client sees the header and message and allows grpc's stream to see them. (We only provide context errors if we need to block.) 6. Client performs a successful `Read` on the stream, receiving the gzipped payload, then checks `as.decomp`. 7. We have no decompressor but the payload has a bit set indicating the message is compressed, so this is an error. However, when forming the error string, `RecvCompress` now returns "gzip" because it doesn't need to block to get this from the now-received header. This leads to the confusing message about how "gzip" is not installed even though it is. This change makes `RecvCompress` return an error instead of empty string (which is also a valid response), and makes `ClientStream.Recv` return an error when this happens. This effectively terminates the stream and prevents subsequent operations. This results in 10k/10k passing runs (previously observed failure rate of ~1/100). --- internal/transport/transport.go | 6 +++--- rpc_util.go | 6 +++++- server.go | 8 ++++---- stream.go | 12 ++++++++++-- test/context_canceled_test.go | 3 +-- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 1c1d106709a..7f87bff9a16 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -327,11 +327,11 @@ func (s *Stream) waitOnHeader() error { // RecvCompress returns the compression algorithm applied to the inbound // message. It is empty string if there is no compression applied. -func (s *Stream) RecvCompress() string { +func (s *Stream) RecvCompress() (string, error) { if err := s.waitOnHeader(); err != nil { - return "" + return "", err } - return s.recvCompress + return s.recvCompress, nil } // SetSendCompress sets the compression algorithm to the stream. diff --git a/rpc_util.go b/rpc_util.go index 088c3f1b252..8a2ed831ddf 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -644,7 +644,11 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei payInfo.wireLength = len(d) } - if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { + rc, err := s.RecvCompress() + if err != nil { + return nil, err + } + if st := checkRecvPayload(pf, rc, compressor != nil || dc != nil); st != nil { return nil, st.Err() } diff --git a/server.go b/server.go index f064b73e555..eb2938cbed5 100644 --- a/server.go +++ b/server.go @@ -926,7 +926,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // If dc is set and matches the stream's compression, use it. Otherwise, try // to find a matching registered compressor for decomp. - if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + if rc, _ := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { dc = s.opts.dc } else if rc != "" && rc != encoding.Identity { decomp = encoding.GetCompressor(rc) @@ -944,7 +944,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if s.opts.cp != nil { cp = s.opts.cp stream.SetSendCompress(cp.Type()) - } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + } else if rc, _ := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. comp = encoding.GetCompressor(rc) if comp != nil { @@ -1151,7 +1151,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp // If dc is set and matches the stream's compression, use it. Otherwise, try // to find a matching registered compressor for decomp. - if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + if rc, _ := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { ss.dc = s.opts.dc } else if rc != "" && rc != encoding.Identity { ss.decomp = encoding.GetCompressor(rc) @@ -1169,7 +1169,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if s.opts.cp != nil { ss.cp = s.opts.cp stream.SetSendCompress(s.opts.cp.Type()) - } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + } else if rc, _ := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. ss.comp = encoding.GetCompressor(rc) if ss.comp != nil { diff --git a/stream.go b/stream.go index 134a624a15d..203ce5626dc 100644 --- a/stream.go +++ b/stream.go @@ -867,8 +867,12 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { } if !a.decompSet { + ct, err := a.s.RecvCompress() + if err != nil { + return err + } // Block until we receive headers containing received message encoding. - if ct := a.s.RecvCompress(); ct != "" && ct != encoding.Identity { + if ct != "" && ct != encoding.Identity { if a.dc == nil || a.dc.Type() != ct { // No configured decompressor, or it does not match the incoming // message encoding; attempt to find a registered compressor that does. @@ -1202,7 +1206,11 @@ func (as *addrConnStream) RecvMsg(m interface{}) (err error) { if !as.decompSet { // Block until we receive headers containing received message encoding. - if ct := as.s.RecvCompress(); ct != "" && ct != encoding.Identity { + ct, err := as.s.RecvCompress() + if err != nil { + return err + } + if ct != "" && ct != encoding.Identity { if as.dc == nil || as.dc.Type() != ct { // No configured decompressor, or it does not match the incoming // message encoding; attempt to find a registered compressor that does. diff --git a/test/context_canceled_test.go b/test/context_canceled_test.go index 9715b5c203e..0705e9a3bcc 100644 --- a/test/context_canceled_test.go +++ b/test/context_canceled_test.go @@ -139,13 +139,12 @@ func (s) TestCancelWhileRecvingWithCompression(t *testing.T) { for i := 0; i < 10; i++ { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name)) if err != nil { t.Fatalf("failed to start bidi streaming RPC: %v", err) } // Cancel the stream while receiving to trigger the internal error. - time.AfterFunc(time.Millisecond*10, cancel) + time.AfterFunc(time.Millisecond*1, cancel) for { _, err := s.Recv() if err != nil { From 9ffa82e07db5eb2582eaa2c88855d36f21e5f461 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 30 Sep 2019 15:53:09 -0700 Subject: [PATCH 2/3] different approach --- internal/transport/http2_client.go | 2 ++ internal/transport/transport.go | 50 ++++++++++++------------------ rpc_util.go | 6 +--- server.go | 8 ++--- stream.go | 12 ++----- 5 files changed, 28 insertions(+), 50 deletions(-) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b365..5922750e81c 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -352,6 +352,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ + ct: t, done: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, @@ -1191,6 +1192,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { // If headerChan hasn't been closed yet if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { + s.headerValid = true if !endStream { // HEADERS frame block carries a Response-Headers. isHeader = true diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 7f87bff9a16..28d5abd666a 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -233,6 +233,7 @@ const ( type Stream struct { id uint32 st ServerTransport // nil for client side Stream + ct *http2Client // nil for server side Stream ctx context.Context // the associated context of the stream cancel context.CancelFunc // always nil for client side Stream done chan struct{} // closed at the end of stream to unblock writers. On the client side. @@ -251,6 +252,10 @@ type Stream struct { headerChan chan struct{} // closed to indicate the end of header metadata. headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + headerValid bool // set if headerChan is closed due to valid + // headers. false if headerChan is open or + // if the RPC failed before headers (or + // trailers-only) were received // hdrMu protects header and trailer metadata on the server-side. hdrMu sync.Mutex @@ -303,35 +308,27 @@ func (s *Stream) getState() streamState { return streamState(atomic.LoadUint32((*uint32)(&s.state))) } -func (s *Stream) waitOnHeader() error { +func (s *Stream) waitOnHeader() { if s.headerChan == nil { // On the server headerChan is always nil since a stream originates // only after having received headers. - return nil + return } select { case <-s.ctx.Done(): - // We prefer success over failure when reading messages because we delay - // context error in stream.Read(). To keep behavior consistent, we also - // prefer success here. - select { - case <-s.headerChan: - return nil - default: - } - return ContextErr(s.ctx.Err()) + // Close the stream to prevent headers/trailers from changing after + // this function returns. + err := ContextErr(s.ctx.Err()) + s.ct.closeStream(s, err, false, 0, status.Convert(err), nil, false) case <-s.headerChan: - return nil } } // RecvCompress returns the compression algorithm applied to the inbound // message. It is empty string if there is no compression applied. -func (s *Stream) RecvCompress() (string, error) { - if err := s.waitOnHeader(); err != nil { - return "", err - } - return s.recvCompress, nil +func (s *Stream) RecvCompress() string { + s.waitOnHeader() + return s.recvCompress } // SetSendCompress sets the compression algorithm to the stream. @@ -358,17 +355,11 @@ func (s *Stream) Header() (metadata.MD, error) { // header after t.WriteHeader is called. return s.header.Copy(), nil } - err := s.waitOnHeader() - // Even if the stream is closed, header is returned if available. - select { - case <-s.headerChan: - if s.header == nil { - return nil, nil - } - return s.header.Copy(), nil - default: + s.waitOnHeader() + if !s.headerValid { + return nil, s.status.Err() } - return nil, err + return s.header.Copy(), nil } // TrailersOnly blocks until a header or trailers-only frame is received and @@ -376,10 +367,7 @@ func (s *Stream) Header() (metadata.MD, error) { // before headers are received, returns true, nil. If a context error happens // first, returns it as a status error. Client-side only. func (s *Stream) TrailersOnly() (bool, error) { - err := s.waitOnHeader() - if err != nil { - return false, err - } + s.waitOnHeader() return s.noHeaders, nil } diff --git a/rpc_util.go b/rpc_util.go index 8a2ed831ddf..088c3f1b252 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -644,11 +644,7 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei payInfo.wireLength = len(d) } - rc, err := s.RecvCompress() - if err != nil { - return nil, err - } - if st := checkRecvPayload(pf, rc, compressor != nil || dc != nil); st != nil { + if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { return nil, st.Err() } diff --git a/server.go b/server.go index eb2938cbed5..f064b73e555 100644 --- a/server.go +++ b/server.go @@ -926,7 +926,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // If dc is set and matches the stream's compression, use it. Otherwise, try // to find a matching registered compressor for decomp. - if rc, _ := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { dc = s.opts.dc } else if rc != "" && rc != encoding.Identity { decomp = encoding.GetCompressor(rc) @@ -944,7 +944,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if s.opts.cp != nil { cp = s.opts.cp stream.SetSendCompress(cp.Type()) - } else if rc, _ := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. comp = encoding.GetCompressor(rc) if comp != nil { @@ -1151,7 +1151,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp // If dc is set and matches the stream's compression, use it. Otherwise, try // to find a matching registered compressor for decomp. - if rc, _ := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { ss.dc = s.opts.dc } else if rc != "" && rc != encoding.Identity { ss.decomp = encoding.GetCompressor(rc) @@ -1169,7 +1169,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if s.opts.cp != nil { ss.cp = s.opts.cp stream.SetSendCompress(s.opts.cp.Type()) - } else if rc, _ := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. ss.comp = encoding.GetCompressor(rc) if ss.comp != nil { diff --git a/stream.go b/stream.go index 203ce5626dc..134a624a15d 100644 --- a/stream.go +++ b/stream.go @@ -867,12 +867,8 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { } if !a.decompSet { - ct, err := a.s.RecvCompress() - if err != nil { - return err - } // Block until we receive headers containing received message encoding. - if ct != "" && ct != encoding.Identity { + if ct := a.s.RecvCompress(); ct != "" && ct != encoding.Identity { if a.dc == nil || a.dc.Type() != ct { // No configured decompressor, or it does not match the incoming // message encoding; attempt to find a registered compressor that does. @@ -1206,11 +1202,7 @@ func (as *addrConnStream) RecvMsg(m interface{}) (err error) { if !as.decompSet { // Block until we receive headers containing received message encoding. - ct, err := as.s.RecvCompress() - if err != nil { - return err - } - if ct != "" && ct != encoding.Identity { + if ct := as.s.RecvCompress(); ct != "" && ct != encoding.Identity { if as.dc == nil || as.dc.Type() != ct { // No configured decompressor, or it does not match the incoming // message encoding; attempt to find a registered compressor that does. From eef7557c2f17fdff7b160dbf8dd259eeb134c0c7 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 1 Oct 2019 10:27:11 -0700 Subject: [PATCH 3/3] review cleanups --- internal/transport/transport.go | 15 +++++++-------- stream.go | 2 +- test/context_canceled_test.go | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 28d5abd666a..965c76f18fa 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -252,10 +252,10 @@ type Stream struct { headerChan chan struct{} // closed to indicate the end of header metadata. headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. - headerValid bool // set if headerChan is closed due to valid - // headers. false if headerChan is open or - // if the RPC failed before headers (or - // trailers-only) were received + // headerValid indicates whether a valid header was received. Only + // meaningful after headerChan is closed (always call waitOnHeader() before + // reading its value). + headerValid bool // hdrMu protects header and trailer metadata on the server-side. hdrMu sync.Mutex @@ -364,11 +364,10 @@ func (s *Stream) Header() (metadata.MD, error) { // TrailersOnly blocks until a header or trailers-only frame is received and // then returns true if the stream was trailers-only. If the stream ends -// before headers are received, returns true, nil. If a context error happens -// first, returns it as a status error. Client-side only. -func (s *Stream) TrailersOnly() (bool, error) { +// before headers are received, returns true, nil. Client-side only. +func (s *Stream) TrailersOnly() bool { s.waitOnHeader() - return s.noHeaders, nil + return s.noHeaders } // Trailer returns the cached trailer metedata. Note that if it is not called diff --git a/stream.go b/stream.go index 134a624a15d..bb99940e36f 100644 --- a/stream.go +++ b/stream.go @@ -488,7 +488,7 @@ func (cs *clientStream) shouldRetry(err error) error { pushback := 0 hasPushback := false if cs.attempt.s != nil { - if to, toErr := cs.attempt.s.TrailersOnly(); toErr != nil || !to { + if !cs.attempt.s.TrailersOnly() { return err } diff --git a/test/context_canceled_test.go b/test/context_canceled_test.go index 0705e9a3bcc..781f63f0c04 100644 --- a/test/context_canceled_test.go +++ b/test/context_canceled_test.go @@ -144,7 +144,7 @@ func (s) TestCancelWhileRecvingWithCompression(t *testing.T) { t.Fatalf("failed to start bidi streaming RPC: %v", err) } // Cancel the stream while receiving to trigger the internal error. - time.AfterFunc(time.Millisecond*1, cancel) + time.AfterFunc(time.Millisecond, cancel) for { _, err := s.Recv() if err != nil {