From dba1390f8e038212df2303b3dea3fe2a8285f6cc Mon Sep 17 00:00:00 2001 From: Andrei Matei Date: Fri, 7 Jan 2022 17:15:05 -0500 Subject: [PATCH] internal/transport: remove duplicate code handling ctx errors transport.ContextErr() was very similar to status.FromContextError(). Besides the code duplication, the latter is arguably better because it handles errors wrapping context errors, and the former only supports raw context errors. This patch does away with transport.ContextErr(). --- clientconn.go | 2 +- internal/transport/http2_client.go | 2 +- internal/transport/http2_server.go | 4 ++-- internal/transport/transport.go | 17 +++-------------- internal/transport/transport_test.go | 4 ++-- status/status.go | 9 +++++++++ stream.go | 2 +- 7 files changed, 19 insertions(+), 21 deletions(-) diff --git a/clientconn.go b/clientconn.go index 28f09dc8707..33fa748c971 100644 --- a/clientconn.go +++ b/clientconn.go @@ -578,7 +578,7 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error { case <-cc.firstResolveEvent.Done(): return nil case <-ctx.Done(): - return status.FromContextError(ctx.Err()).Err() + return status.MustFromContextError(ctx.Err()) case <-cc.ctx.Done(): return ErrClientConnClosing } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index f0c72d33710..886a8c2a2c8 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -765,7 +765,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea select { case <-ch: case <-ctx.Done(): - return nil, &NewStreamError{Err: ContextErr(ctx.Err())} + return nil, &NewStreamError{Err: status.MustFromContextError(ctx.Err())} case <-t.goAway: return nil, &NewStreamError{Err: errStreamDrain} case <-t.ctx.Done(): diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 2c6eaf0e59c..55b5db94223 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -1072,7 +1072,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e return ErrConnClosing default: } - return ContextErr(s.ctx.Err()) + return status.MustFromContextError(s.ctx.Err()) } } df := &dataFrame{ @@ -1087,7 +1087,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e return ErrConnClosing default: } - return ContextErr(s.ctx.Err()) + return status.MustFromContextError(s.ctx.Err()) } return t.controlBuf.put(df) } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index d3bf65b2bdf..eff282e0457 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -177,7 +177,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { func (r *recvBufferReader) read(p []byte) (n int, err error) { select { case <-r.ctxDone: - return 0, ContextErr(r.ctx.Err()) + return 0, status.MustFromContextError(r.ctx.Err()) case m := <-r.recv.get(): return r.readAdditional(m, p) } @@ -202,7 +202,7 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) { // TODO: delaying ctx error seems like a unnecessary side effect. What // we really want is to mark the stream as done, and return ctx error // faster. - r.closeStream(ContextErr(r.ctx.Err())) + r.closeStream(status.MustFromContextError(r.ctx.Err())) m := <-r.recv.get() return r.readAdditional(m, p) case m := <-r.recv.get(): @@ -324,7 +324,7 @@ func (s *Stream) waitOnHeader() { case <-s.ctx.Done(): // Close the stream to prevent headers/trailers from changing after // this function returns. - s.ct.CloseStream(s, ContextErr(s.ctx.Err())) + s.ct.CloseStream(s, status.MustFromContextError(s.ctx.Err())) // headerChan could possibly not be closed yet if closeStream raced // with operateHeaders; wait until it is closed explicitly here. <-s.headerChan @@ -793,14 +793,3 @@ type channelzData struct { lastMsgSentTime int64 lastMsgRecvTime int64 } - -// ContextErr converts the error from context package into a status error. -func ContextErr(err error) error { - switch err { - case context.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) - case context.Canceled: - return status.Error(codes.Canceled, err.Error()) - } - return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) -} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index ec864afb6e9..e657fcb72b8 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -818,7 +818,7 @@ func (s) TestLargeMessageSuspension(t *testing.T) { // stream.go to keep track of context timeout and call CloseStream. go func() { <-ctx.Done() - ct.CloseStream(s, ContextErr(ctx.Err())) + ct.CloseStream(s, status.MustFromContextError(ctx.Err())) }() // Write should not be done successfully due to flow control. msg := make([]byte, initialWindowSize*8) @@ -1426,7 +1426,7 @@ func (s) TestContextErr(t *testing.T) { {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, } { - err := ContextErr(test.errIn) + err := status.MustFromContextError(test.errIn) if err.Error() != test.errOut.Error() { t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } diff --git a/status/status.go b/status/status.go index 6d163b6e384..7a0e324f745 100644 --- a/status/status.go +++ b/status/status.go @@ -133,3 +133,12 @@ func FromContextError(err error) *Status { } return New(codes.Unknown, err.Error()) } + +// MustFromContextError is like FromContextError, except that it expects err to +// be non-nil, and it returns the status in form of an error. +func MustFromContextError(err error) error { + if err == nil { + return status.New(codes.Internal, "Expected non-nil context error").Err() + } + return FromContextError(err).Err() +} diff --git a/stream.go b/stream.go index 625d47b34e5..a59e7c34e28 100644 --- a/stream.go +++ b/stream.go @@ -638,7 +638,7 @@ func (cs *clientStream) shouldRetry(err error) (bool, error) { return false, nil case <-cs.ctx.Done(): t.Stop() - return false, status.FromContextError(cs.ctx.Err()).Err() + return false, status.MustFromContextError(cs.ctx.Err()) } }