From 6f33edba24919f882c62367cfd1e7df576e2e502 Mon Sep 17 00:00:00 2001 From: Isaac Diamond Date: Wed, 6 Apr 2022 14:10:36 -0700 Subject: [PATCH] handle ContextErr inside WriteHeader --- internal/transport/http2_server.go | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index f64b602b87e..5a22b8c6673 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -21,7 +21,6 @@ package transport import ( "bytes" "context" - "errors" "fmt" "io" "math" @@ -53,10 +52,10 @@ import ( var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. - ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + ErrIllegalHeaderWrite = status.Error(codes.Internal, "transport: the stream is done or WriteHeader was already called") // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. - ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + ErrHeaderListSizeLimitViolation = status.Error(codes.Internal, "transport: trying to send header list size larger than the limit set by peer") ) // serverConnectionCounter counts the number of connections a server has seen @@ -933,9 +932,14 @@ func (t *http2Server) checkForHeaderListSize(it interface{}) bool { // WriteHeader sends the header metadata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { - if s.updateHeaderSent() || s.getState() == streamDone { + if s.getState() == streamDone { + return ContextErr(s.ctx.Err()) + } + + if s.updateHeaderSent() { return ErrIllegalHeaderWrite } + s.hdrMu.Lock() if md.Len() > 0 { if s.header.Len() > 0 { @@ -1062,14 +1066,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { if !s.isHeaderSent() { // Headers haven't been written yet. if err := t.WriteHeader(s, nil); err != nil { - if _, ok := err.(ConnectionError); ok { - return err - } - if s.ctx.Err() != nil { - return ContextErr(s.ctx.Err()) - } - // TODO(mmukhi, dfawley): Make sure this is the right code to return. - return status.Errorf(codes.Internal, "transport: %v", err) + return err } } else { // Writing headers checks for this condition.