Skip to content

Commit

Permalink
server: expose API to set send compressor (#5744)
Browse files Browse the repository at this point in the history
Fixes #5792
  • Loading branch information
jronak committed Jan 31, 2023
1 parent a7058f7 commit 0954097
Show file tree
Hide file tree
Showing 6 changed files with 489 additions and 17 deletions.
28 changes: 17 additions & 11 deletions internal/transport/handler_server_test.go
Expand Up @@ -280,31 +280,36 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
t.Errorf("stream method = %q; want %q", s.method, want)
}

err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
if err != nil {
if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
t.Error(err)
}
err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
if err != nil {

if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}

if err := s.SetSendCompress("gzip"); err != nil {
t.Error(err)
}

md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
delete(md, "custom-header")
if err != nil {
if err := s.SendHeader(md); err != nil {
t.Error(err)
}
delete(md, "custom-header")

err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
if err == nil {
if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
t.Error("expected SetHeader call after SendHeader to fail")
}
err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
if err == nil {

if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
t.Error("expected second SendHeader call to fail")
}

if err := s.SetSendCompress("snappy"); err == nil {
t.Error("expected second SetSendCompress call to fail")
}

st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
Expand All @@ -317,6 +322,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
"Grpc-Encoding": {"gzip"},
}
wantTrailer := http.Header{
"Grpc-Status": {"0"},
Expand Down
11 changes: 11 additions & 0 deletions internal/transport/http2_server.go
Expand Up @@ -404,6 +404,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
s.contentSubtype = contentSubtype
isGRPC = true

case "grpc-accept-encoding":
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
if hf.Value == "" {
continue
}
compressors := hf.Value
if s.clientAdvertisedCompressors != "" {
compressors = s.clientAdvertisedCompressors + "," + compressors
}
s.clientAdvertisedCompressors = compressors
case "grpc-encoding":
s.recvCompress = hf.Value
case ":method":
Expand Down
23 changes: 21 additions & 2 deletions internal/transport/transport.go
Expand Up @@ -257,6 +257,9 @@ type Stream struct {
fc *inFlow
wq *writeQuota

// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
Expand Down Expand Up @@ -345,8 +348,24 @@ func (s *Stream) RecvCompress() string {
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = name
return nil
}

// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}

// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}

// Done returns a channel which is closed when it receives the final status
Expand Down
100 changes: 96 additions & 4 deletions server.go
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -1263,6 +1264,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
var sendCompressorName string

// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
Expand All @@ -1283,12 +1285,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}

if sendCompressorName != "" {
if err := stream.SetSendCompress(sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}

Expand Down Expand Up @@ -1375,6 +1383,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
Expand Down Expand Up @@ -1597,12 +1610,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}

if ss.sendCompressorName != "" {
if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}

Expand Down Expand Up @@ -1935,6 +1954,60 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets a compressor for outbound messages from the server.
// It must not be called after any event that causes headers to be sent
// (see ServerStream.SetHeader for the complete list). Provided compressor is
// used when below conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name must exist in the client advertised compressor names
// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to
// get client supported compressor names.
//
// The context provided must be the context passed to the server's handler.
// It must be noted that compressor name encoding.Identity disables the
// outbound compression.
// By default, server messages will be sent using the same compressor with
// which request messages were sent.
//
// It is not safe to call SetSendCompressor concurrently with SendHeader and
// SendMsg.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}

if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
return fmt.Errorf("unable to set send compressor: %w", err)
}

return stream.SetSendCompress(name)
}

// ClientSupportedCompressors returns compressor names advertised by the client
// via grpc-accept-encoding header.
//
// The context provided must be the context passed to the server's handler.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}

return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil
}

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
Expand Down Expand Up @@ -1969,3 +2042,22 @@ type channelzServer struct {
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}

// validateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func validateSendCompressor(name, clientCompressors string) error {
if name == encoding.Identity {
return nil
}

if !grpcutil.IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered %q", name)
}

for _, c := range strings.Split(clientCompressors, ",") {
if c == name {
return nil // found match
}
}
return fmt.Errorf("client does not support compressor %q", name)
}
9 changes: 9 additions & 0 deletions stream.go
Expand Up @@ -1511,6 +1511,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor

sendCompressorName string

maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
Expand Down Expand Up @@ -1603,6 +1605,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
}

// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {
Expand Down

0 comments on commit 0954097

Please sign in to comment.