diff --git a/encoding/encoding.go b/encoding/encoding.go
index 30a75da99d5e..0b3fe2c7fac1 100644
--- a/encoding/encoding.go
+++ b/encoding/encoding.go
@@ -48,6 +48,14 @@ type Compressor interface {
 	Name() string
 }
 
+// CompressorSizer is an optional interface a Compressor can implement
+// to improve decompression efficiency. This API is EXPERIMENTAL.
+type CompressorSizer interface {
+	// DecompressedSize returns the exact size the message will
+	// uncompress into, if known.
+	DecompressedSize(buf []byte, maxSize int) (int, error)
+}
+
 var registeredCompressor = make(map[string]Compressor)
 
 // RegisterCompressor registers the compressor with gRPC by its name. It can
diff --git a/encoding/gzip/gzip.go b/encoding/gzip/gzip.go
index 09564db197fe..265d0c5719dc 100644
--- a/encoding/gzip/gzip.go
+++ b/encoding/gzip/gzip.go
@@ -23,9 +23,11 @@ package gzip
 
 import (
 	"compress/gzip"
+	"encoding/binary"
 	"fmt"
 	"io"
 	"io/ioutil"
+	"math"
 	"sync"
 
 	"google.golang.org/grpc/encoding"
@@ -107,6 +109,20 @@ func (z *reader) Read(p []byte) (n int, err error) {
 	return n, err
 }
 
+// RFC1952 specifies that the last four bytes "contains the size of
+// the original (uncompressed) input data modulo 2^32."
+func (c *compressor) DecompressedSize(buf []byte, maxSize int) (int, error) {
+	if maxSize > math.MaxUint32 {
+		return 0, fmt.Errorf("grpc: message size not known when messages can be longer than 4GB")
+	}
+	last := len(buf)
+	if last < 4 {
+		return 0, fmt.Errorf("grpc: invalid gzip buffer")
+	}
+	size := binary.LittleEndian.Uint32(buf[last-4 : last])
+	return int(size), nil
+}
+
 func (c *compressor) Name() string {
 	return Name
 }
diff --git a/rpc_util.go b/rpc_util.go
index 088c3f1b2528..86c75a6e22f4 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -648,35 +648,60 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
 		return nil, st.Err()
 	}
 
+	var size int
 	if pf == compressionMade {
 		// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
 		// use this decompressor as the default.
 		if dc != nil {
 			d, err = dc.Do(bytes.NewReader(d))
-			if err != nil {
-				return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
-			}
+			size = len(d)
 		} else {
-			dcReader, err := compressor.Decompress(bytes.NewReader(d))
-			if err != nil {
-				return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
-			}
-			// Read from LimitReader with limit max+1. So if the underlying
-			// reader is over limit, the result will be bigger than max.
-			d, err = ioutil.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
-			if err != nil {
-				return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
-			}
+			d, size, err = decompress(compressor, d, maxReceiveMessageSize)
+		}
+		if err != nil {
+			return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
 		}
+	} else {
+		size = len(d)
 	}
-	if len(d) > maxReceiveMessageSize {
+	if size > maxReceiveMessageSize {
 		// TODO: Revisit the error code. Currently keep it consistent with java
 		// implementation.
-		return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
+		return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", size, maxReceiveMessageSize)
 	}
 	return d, nil
 }
 
+// decompress uses compressor to decompress d, returning the data and its size.
+// If the data would exceed maxReceiveMessageSize, it may return only the size.
+func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) {
+	dcReader, err := compressor.Decompress(bytes.NewReader(d))
+	if err != nil {
+		return nil, 0, err
+	}
+	if dcFromBytes, ok := compressor.(encoding.CompressorSizer); ok {
+		if size, err := dcFromBytes.DecompressedSize(d, maxReceiveMessageSize); err == nil {
+			if size > maxReceiveMessageSize {
+				return nil, size, nil
+			}
+			var buf bytes.Buffer
+			buf.Grow(size + bytes.MinRead) // extra space guarantees no reallocation
+			bytesRead, err := buf.ReadFrom(dcReader)
+			if err != nil {
+				return nil, size, err
+			}
+			if bytesRead != int64(size) {
+				return nil, size, fmt.Errorf("read different size than expected (%d vs. %d)", bytesRead, size)
+			}
+			return buf.Bytes(), size, nil
+		}
+	}
+	// Read from LimitReader with limit max+1. So if the underlying
+	// reader is over limit, the result will be bigger than max.
+	d, err = ioutil.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
+	return d, len(d), err
+}
+
 // For the two compressor parameters, both should not be set, but if they are,
 // dc takes precedence over compressor.
 // TODO(dfawley): wrap the old compressor/decompressor using the new API?
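For reference, the gzip DecompressedSize above relies on the RFC 1952 ISIZE footer rather than decompressing the stream. The standalone sketch below is not part of the patch (the package name and payload are illustrative); it only shows that footer lining up with the original payload length:

package main

import (
	"bytes"
	"compress/gzip"
	"encoding/binary"
	"fmt"
	"log"
)

func main() {
	payload := bytes.Repeat([]byte("hello gzip "), 1000) // 11000 bytes

	// Compress with the standard library; the gRPC gzip codec wraps the same writer.
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	if _, err := zw.Write(payload); err != nil {
		log.Fatal(err)
	}
	if err := zw.Close(); err != nil {
		log.Fatal(err)
	}

	// RFC 1952: the final four bytes (ISIZE) hold the uncompressed length
	// modulo 2^32 in little-endian order; this is the field DecompressedSize reads.
	buf := compressed.Bytes()
	isize := binary.LittleEndian.Uint32(buf[len(buf)-4:])
	fmt.Printf("uncompressed=%d ISIZE=%d\n", len(payload), isize) // both print 11000
}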
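A third-party codec can opt into the same fast path by implementing the new optional interface alongside encoding.Compressor. The sketch below is hypothetical: the "prefix" codec and its framing are invented for illustration, and only the Compressor and CompressorSizer method sets come from gRPC and this patch.

package prefix

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"

	"google.golang.org/grpc/encoding"
)

// codec "compresses" by storing the payload behind a 4-byte little-endian
// length prefix, purely to show how DecompressedSize can be cheap.
type codec struct{}

func (codec) Name() string { return "prefix" }

func (codec) Compress(w io.Writer) (io.WriteCloser, error) {
	return &prefixWriter{w: w}, nil
}

func (codec) Decompress(r io.Reader) (io.Reader, error) {
	var hdr [4]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}
	return io.LimitReader(r, int64(binary.LittleEndian.Uint32(hdr[:]))), nil
}

// DecompressedSize reads the length prefix so the receive path can size its
// buffer exactly, or reject an oversized message, before reading the stream.
func (codec) DecompressedSize(buf []byte, maxSize int) (int, error) {
	if len(buf) < 4 {
		return 0, fmt.Errorf("prefix: buffer too short")
	}
	return int(binary.LittleEndian.Uint32(buf[:4])), nil
}

type prefixWriter struct {
	w   io.Writer
	buf bytes.Buffer
}

func (p *prefixWriter) Write(b []byte) (int, error) { return p.buf.Write(b) }

func (p *prefixWriter) Close() error {
	var hdr [4]byte
	binary.LittleEndian.PutUint32(hdr[:], uint32(p.buf.Len()))
	if _, err := p.w.Write(hdr[:]); err != nil {
		return err
	}
	_, err := p.w.Write(p.buf.Bytes())
	return err
}

func init() { encoding.RegisterCompressor(codec{}) }

With such a codec registered, the decompress helper in rpc_util.go takes the type-assertion branch and pre-sizes its buffer (or rejects the message by size alone) instead of falling back to the LimitReader path.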