diff --git a/internal/grpcutil/compressor.go b/internal/grpcutil/compressor.go index bab642081194..9f4090967980 100644 --- a/internal/grpcutil/compressor.go +++ b/internal/grpcutil/compressor.go @@ -19,7 +19,6 @@ package grpcutil import ( - "fmt" "strings" "google.golang.org/grpc/internal/envconfig" @@ -46,41 +45,3 @@ func RegisteredCompressors() string { } return strings.Join(RegisteredCompressorNames, ",") } - -// ValidateSendCompressor returns an error when given compressor name cannot be -// handled by the server or the client based on the advertised compressors. -func ValidateSendCompressor(name, clientAdvertisedCompressors string) error { - if name == "identity" { - return nil - } - - if !IsCompressorNameRegistered(name) { - return fmt.Errorf("compressor not registered: %s", name) - } - - if !compressorExists(name, clientAdvertisedCompressors) { - return fmt.Errorf("client does not support compressor: %s", name) - } - - return nil -} - -// compressorExists returns true when the given name exists in the comma -// separated compressor list. -func compressorExists(name, compressors string) bool { - var ( - i = 0 - length = len(compressors) - ) - for j := 0; j <= length; j++ { - if j < length && compressors[j] != ',' { - continue - } - - if compressors[i:j] == name { - return true - } - i = j + 1 - } - return false -} diff --git a/internal/grpcutil/compressor_test.go b/internal/grpcutil/compressor_test.go index 6c080976ed8c..0d639422a9a0 100644 --- a/internal/grpcutil/compressor_test.go +++ b/internal/grpcutil/compressor_test.go @@ -19,7 +19,6 @@ package grpcutil import ( - "fmt" "testing" "google.golang.org/grpc/internal/envconfig" @@ -45,44 +44,3 @@ func TestRegisteredCompressors(t *testing.T) { } } } - -func TestValidateSendCompressors(t *testing.T) { - defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames) - RegisteredCompressorNames = []string{"gzip", "snappy"} - tests := []struct { - desc string - name string - advertisedCompressors string - wantErr error - }{ - { - desc: "success_when_identity_compressor", - name: "identity", - advertisedCompressors: "gzip,snappy", - }, - { - desc: "success_when_compressor_exists", - name: "snappy", - advertisedCompressors: "testcomp,gzip,snappy", - }, - { - desc: "failure_when_compressor_not_registered", - name: "testcomp", - advertisedCompressors: "testcomp,gzip,snappy", - wantErr: fmt.Errorf("compressor not registered: testcomp"), - }, - { - desc: "failure_when_compressor_not_advertised", - name: "gzip", - advertisedCompressors: "testcomp,snappy", - wantErr: fmt.Errorf("client does not support compressor: gzip"), - }, - } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - if err := ValidateSendCompressor(tt.name, tt.advertisedCompressors); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { - t.Fatalf("Unexpected validation got:%v, want:%v", err, tt.wantErr) - } - }) - } -} diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index df89d511fe91..507a589e49f3 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -27,7 +27,6 @@ import ( "net" "net/http" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -404,6 +403,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( mdata[hf.Name] = append(mdata[hf.Name], hf.Value) s.contentSubtype = contentSubtype isGRPC = true + + case "grpc-accept-encoding": + s.clientAdvertisedCompressors = hf.Value + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) case "grpc-encoding": s.recvCompress = hf.Value case ":method": @@ -457,10 +460,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return false } - if encodings := mdata["grpc-accept-encoding"]; len(encodings) != 0 { - s.clientAdvertisedCompressors = strings.Join(encodings, ",") - } - if !isGRPC || headerError { t.controlBuf.put(&cleanupStream{ streamID: streamID, diff --git a/internal/transport/transport.go b/internal/transport/transport.go index ade0ff7049bb..f45cfa899b35 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -344,12 +344,12 @@ func (s *Stream) RecvCompress() string { } // SetSendCompress sets the compression algorithm to the stream. -func (s *Stream) SetSendCompress(str string) error { +func (s *Stream) SetSendCompress(name string) error { if s.isHeaderSent() || s.getState() == streamDone { return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") } - s.sendCompress = str + s.sendCompress = name return nil } @@ -358,8 +358,8 @@ func (s *Stream) SendCompress() string { return s.sendCompress } -// ClientAdvertisedCompressors returns the advertised compressor names by the -// client. +// ClientAdvertisedCompressors returns the compressor names advertised by the +// client via :grpc-accept-encoding header. func (s *Stream) ClientAdvertisedCompressors() string { return s.clientAdvertisedCompressors } diff --git a/server.go b/server.go index 13b346876ab9..2d0c3c51d884 100644 --- a/server.go +++ b/server.go @@ -1289,17 +1289,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { cp = s.opts.cp - _ = stream.SetSendCompress(cp.Type()) sendCompressorName = cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. comp = encoding.GetCompressor(rc) if comp != nil { - _ = stream.SetSendCompress(rc) sendCompressorName = comp.Name() } } + if sendCompressorName != "" { + // Safe to ignore returned error value as we are guaranteed to succeed here + _ = stream.SetSendCompress(sendCompressorName) + } + var payInfo *payloadInfo if len(shs) != 0 || len(binlogs) != 0 { payInfo = &payloadInfo{} @@ -1383,6 +1386,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } opts := &transport.Options{Last: true} + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. if stream.SendCompress() != sendCompressorName { comp = encoding.GetCompressor(stream.SendCompress()) } @@ -1613,17 +1618,20 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { ss.cp = s.opts.cp - _ = stream.SetSendCompress(s.opts.cp.Type()) ss.sendCompressorName = s.opts.cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. ss.comp = encoding.GetCompressor(rc) if ss.comp != nil { - _ = stream.SetSendCompress(rc) ss.sendCompressorName = rc } } + if ss.sendCompressorName != "" { + // Safe to ignore returned error value as we are guaranteed to succeed here + _ = stream.SetSendCompress(ss.sendCompressorName) + } + ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) if trInfo != nil { @@ -1953,28 +1961,37 @@ func SendHeader(ctx context.Context, md metadata.MD) error { return nil } -// SetSendCompressor sets the compressor that will be used when sending -// RPC payload back to the client. It may be called at most once, and must not -// be called after any event that causes headers to be sent (see SetHeader for -// a complete list). Provided compressor is used when below conditions are met: +// SetSendCompressor sets a compressor for outbound messages. +// It must not be called after any event that causes headers to be sent +// (see SetHeader for a complete list). Provided compressor is used when below +// conditions are met: // // - compressor is registered via encoding.RegisterCompressor // - compressor name exists in the client advertised compressor names sent in -// grpc-accept-encoding metadata. +// :grpc-accept-encoding header. // // The context provided must be the context passed to the server's handler. // // The error returned is compatible with the status package. However, the // status code will often not match the RPC status as seen by the client // application, and therefore, should not be relied upon for this purpose. +// +// # Experimental +// +// Notice: This type is EXPERIMENTAL and may be changed or removed in a +// later release. func SetSendCompressor(ctx context.Context, name string) error { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) if !ok || stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } - if err := grpcutil.ValidateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { - return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err) + if !grpcutil.IsCompressorNameRegistered(name) { + return status.Errorf(codes.Internal, "grpc: compressor not registered %s", name) + } + + if !compressorExists(name, stream.ClientAdvertisedCompressors()) { + return status.Errorf(codes.Internal, "grpc: client does not support compressor %s", name) } return stream.SetSendCompress(name) @@ -2014,3 +2031,14 @@ type channelzServer struct { func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { return c.s.channelzMetric() } + +// compressorExists returns true when the given name exists in the comma +// separated client compressors. +func compressorExists(name, clientCompressors string) bool { + for _, clientCompressor := range strings.Split(clientCompressors, ",") { + if clientCompressor == name { + return true + } + } + return false +} diff --git a/stream.go b/stream.go index 01c1066e2e45..d261f30ba07d 100644 --- a/stream.go +++ b/stream.go @@ -1575,6 +1575,8 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } }() + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName { ss.comp = encoding.GetCompressor(sendCompressorsName) ss.sendCompressorName = sendCompressorsName