Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jronak committed Dec 7, 2022
1 parent b19a201 commit 8be817f
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 101 deletions.
39 changes: 0 additions & 39 deletions internal/grpcutil/compressor.go
Expand Up @@ -19,7 +19,6 @@
package grpcutil

import (
"fmt"
"strings"

"google.golang.org/grpc/internal/envconfig"
Expand All @@ -46,41 +45,3 @@ func RegisteredCompressors() string {
}
return strings.Join(RegisteredCompressorNames, ",")
}

// ValidateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func ValidateSendCompressor(name, clientAdvertisedCompressors string) error {
if name == "identity" {
return nil
}

if !IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered: %s", name)
}

if !compressorExists(name, clientAdvertisedCompressors) {
return fmt.Errorf("client does not support compressor: %s", name)
}

return nil
}

// compressorExists returns true when the given name exists in the comma
// separated compressor list.
func compressorExists(name, compressors string) bool {
var (
i = 0
length = len(compressors)
)
for j := 0; j <= length; j++ {
if j < length && compressors[j] != ',' {
continue
}

if compressors[i:j] == name {
return true
}
i = j + 1
}
return false
}
42 changes: 0 additions & 42 deletions internal/grpcutil/compressor_test.go
Expand Up @@ -19,7 +19,6 @@
package grpcutil

import (
"fmt"
"testing"

"google.golang.org/grpc/internal/envconfig"
Expand All @@ -45,44 +44,3 @@ func TestRegisteredCompressors(t *testing.T) {
}
}
}

func TestValidateSendCompressors(t *testing.T) {
defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames)
RegisteredCompressorNames = []string{"gzip", "snappy"}
tests := []struct {
desc string
name string
advertisedCompressors string
wantErr error
}{
{
desc: "success_when_identity_compressor",
name: "identity",
advertisedCompressors: "gzip,snappy",
},
{
desc: "success_when_compressor_exists",
name: "snappy",
advertisedCompressors: "testcomp,gzip,snappy",
},
{
desc: "failure_when_compressor_not_registered",
name: "testcomp",
advertisedCompressors: "testcomp,gzip,snappy",
wantErr: fmt.Errorf("compressor not registered: testcomp"),
},
{
desc: "failure_when_compressor_not_advertised",
name: "gzip",
advertisedCompressors: "testcomp,snappy",
wantErr: fmt.Errorf("client does not support compressor: gzip"),
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
if err := ValidateSendCompressor(tt.name, tt.advertisedCompressors); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) {
t.Fatalf("Unexpected validation got:%v, want:%v", err, tt.wantErr)
}
})
}
}
9 changes: 4 additions & 5 deletions internal/transport/http2_server.go
Expand Up @@ -27,7 +27,6 @@ import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -404,6 +403,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
s.contentSubtype = contentSubtype
isGRPC = true

case "grpc-accept-encoding":
s.clientAdvertisedCompressors = hf.Value
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
case "grpc-encoding":
s.recvCompress = hf.Value
case ":method":
Expand Down Expand Up @@ -457,10 +460,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return false
}

if encodings := mdata["grpc-accept-encoding"]; len(encodings) != 0 {
s.clientAdvertisedCompressors = strings.Join(encodings, ",")
}

if !isGRPC || headerError {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
Expand Down
8 changes: 4 additions & 4 deletions internal/transport/transport.go
Expand Up @@ -344,12 +344,12 @@ func (s *Stream) RecvCompress() string {
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) error {
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = str
s.sendCompress = name
return nil
}

Expand All @@ -358,8 +358,8 @@ func (s *Stream) SendCompress() string {
return s.sendCompress
}

// ClientAdvertisedCompressors returns the advertised compressor names by the
// client.
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via :grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}
Expand Down
50 changes: 39 additions & 11 deletions server.go
Expand Up @@ -1289,17 +1289,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
_ = stream.SetSendCompress(cp.Type())
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
_ = stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}

if sendCompressorName != "" {
// Safe to ignore returned error value as we are guaranteed to succeed here
_ = stream.SetSendCompress(sendCompressorName)
}

var payInfo *payloadInfo
if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{}
Expand Down Expand Up @@ -1383,6 +1386,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
Expand Down Expand Up @@ -1613,17 +1618,20 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
_ = stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
_ = stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}

if ss.sendCompressorName != "" {
// Safe to ignore returned error value as we are guaranteed to succeed here
_ = stream.SetSendCompress(ss.sendCompressorName)
}

ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp)

if trInfo != nil {
Expand Down Expand Up @@ -1953,28 +1961,37 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets the compressor that will be used when sending
// RPC payload back to the client. It may be called at most once, and must not
// be called after any event that causes headers to be sent (see SetHeader for
// a complete list). Provided compressor is used when below conditions are met:
// SetSendCompressor sets a compressor for outbound messages.
// It must not be called after any event that causes headers to be sent
// (see SetHeader for a complete list). Provided compressor is used when below
// conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name exists in the client advertised compressor names sent in
// grpc-accept-encoding metadata.
// :grpc-accept-encoding header.
//
// The context provided must be the context passed to the server's handler.
//
// The error returned is compatible with the status package. However, the
// status code will often not match the RPC status as seen by the client
// application, and therefore, should not be relied upon for this purpose.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
}

if err := grpcutil.ValidateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err)
if !grpcutil.IsCompressorNameRegistered(name) {
return status.Errorf(codes.Internal, "grpc: compressor not registered %s", name)
}

if !compressorExists(name, stream.ClientAdvertisedCompressors()) {
return status.Errorf(codes.Internal, "grpc: client does not support compressor %s", name)
}

return stream.SetSendCompress(name)
Expand Down Expand Up @@ -2014,3 +2031,14 @@ type channelzServer struct {
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}

// compressorExists returns true when the given name exists in the comma
// separated client compressors.
func compressorExists(name, clientCompressors string) bool {
for _, clientCompressor := range strings.Split(clientCompressors, ",") {
if clientCompressor == name {
return true
}
}
return false
}
2 changes: 2 additions & 0 deletions stream.go
Expand Up @@ -1575,6 +1575,8 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
Expand Down

0 comments on commit 8be817f

Please sign in to comment.