diff --git a/clientconn.go b/clientconn.go index 28f09dc8707..33fa748c971 100644 --- a/clientconn.go +++ b/clientconn.go @@ -578,7 +578,7 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error { case <-cc.firstResolveEvent.Done(): return nil case <-ctx.Done(): - return status.FromContextError(ctx.Err()).Err() + return status.MustFromContextError(ctx.Err()) case <-cc.ctx.Done(): return ErrClientConnClosing } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index f0c72d33710..886a8c2a2c8 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -765,7 +765,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea select { case <-ch: case <-ctx.Done(): - return nil, &NewStreamError{Err: ContextErr(ctx.Err())} + return nil, &NewStreamError{Err: status.MustFromContextError(ctx.Err())} case <-t.goAway: return nil, &NewStreamError{Err: errStreamDrain} case <-t.ctx.Done(): diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 2c6eaf0e59c..55b5db94223 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -1072,7 +1072,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e return ErrConnClosing default: } - return ContextErr(s.ctx.Err()) + return status.MustFromContextError(s.ctx.Err()) } } df := &dataFrame{ @@ -1087,7 +1087,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e return ErrConnClosing default: } - return ContextErr(s.ctx.Err()) + return status.MustFromContextError(s.ctx.Err()) } return t.controlBuf.put(df) } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index d3bf65b2bdf..eff282e0457 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -177,7 +177,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { func (r *recvBufferReader) read(p []byte) (n int, err error) { select { case <-r.ctxDone: - return 0, ContextErr(r.ctx.Err()) + return 0, status.MustFromContextError(r.ctx.Err()) case m := <-r.recv.get(): return r.readAdditional(m, p) } @@ -202,7 +202,7 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) { // TODO: delaying ctx error seems like a unnecessary side effect. What // we really want is to mark the stream as done, and return ctx error // faster. - r.closeStream(ContextErr(r.ctx.Err())) + r.closeStream(status.MustFromContextError(r.ctx.Err())) m := <-r.recv.get() return r.readAdditional(m, p) case m := <-r.recv.get(): @@ -324,7 +324,7 @@ func (s *Stream) waitOnHeader() { case <-s.ctx.Done(): // Close the stream to prevent headers/trailers from changing after // this function returns. - s.ct.CloseStream(s, ContextErr(s.ctx.Err())) + s.ct.CloseStream(s, status.MustFromContextError(s.ctx.Err())) // headerChan could possibly not be closed yet if closeStream raced // with operateHeaders; wait until it is closed explicitly here. <-s.headerChan @@ -793,14 +793,3 @@ type channelzData struct { lastMsgSentTime int64 lastMsgRecvTime int64 } - -// ContextErr converts the error from context package into a status error. -func ContextErr(err error) error { - switch err { - case context.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) - case context.Canceled: - return status.Error(codes.Canceled, err.Error()) - } - return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) -} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index ec864afb6e9..e657fcb72b8 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -818,7 +818,7 @@ func (s) TestLargeMessageSuspension(t *testing.T) { // stream.go to keep track of context timeout and call CloseStream. go func() { <-ctx.Done() - ct.CloseStream(s, ContextErr(ctx.Err())) + ct.CloseStream(s, status.MustFromContextError(ctx.Err())) }() // Write should not be done successfully due to flow control. msg := make([]byte, initialWindowSize*8) @@ -1426,7 +1426,7 @@ func (s) TestContextErr(t *testing.T) { {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, } { - err := ContextErr(test.errIn) + err := status.MustFromContextError(test.errIn) if err.Error() != test.errOut.Error() { t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } diff --git a/status/status.go b/status/status.go index 6d163b6e384..7a0e324f745 100644 --- a/status/status.go +++ b/status/status.go @@ -133,3 +133,12 @@ func FromContextError(err error) *Status { } return New(codes.Unknown, err.Error()) } + +// MustFromContextError is like FromContextError, except that it expects err to +// be non-nil, and it returns the status in form of an error. +func MustFromContextError(err error) error { + if err == nil { + return status.New(codes.Internal, "Expected non-nil context error").Err() + } + return FromContextError(err).Err() +} diff --git a/stream.go b/stream.go index 625d47b34e5..a59e7c34e28 100644 --- a/stream.go +++ b/stream.go @@ -638,7 +638,7 @@ func (cs *clientStream) shouldRetry(err error) (bool, error) { return false, nil case <-cs.ctx.Done(): t.Stop() - return false, status.FromContextError(cs.ctx.Err()).Err() + return false, status.MustFromContextError(cs.ctx.Err()) } }