diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 5c3e7e8b5..2239336e1 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -276,6 +276,14 @@ required-features = ["cancellation"] name = "cancellation-client" path = "src/cancellation/client.rs" +[[bin]] +name = "codec-buffers-server" +path = "src/codec_buffers/server.rs" + +[[bin]] +name = "codec-buffers-client" +path = "src/codec_buffers/client.rs" + [features] gcp = ["dep:prost-types", "tonic/tls"] diff --git a/examples/build.rs b/examples/build.rs index 892b0d96c..454a77214 100644 --- a/examples/build.rs +++ b/examples/build.rs @@ -33,6 +33,14 @@ fn main() { .unwrap(); build_json_codec_service(); + + let smallbuff_copy = out_dir.join("smallbuf"); + let _ = std::fs::create_dir(smallbuff_copy.clone()); // This will panic below if the directory failed to create + tonic_build::configure() + .out_dir(smallbuff_copy) + .codec_path("crate::common::SmallBufferCodec") + .compile(&["proto/helloworld/helloworld.proto"], &["proto"]) + .unwrap(); } // Manually define the json.helloworld.Greeter service which used a custom JsonCodec to use json diff --git a/examples/src/codec_buffers/client.rs b/examples/src/codec_buffers/client.rs new file mode 100644 index 000000000..267e19dbf --- /dev/null +++ b/examples/src/codec_buffers/client.rs @@ -0,0 +1,30 @@ +//! A HelloWorld example that uses a custom codec instead of the default Prost codec. +//! +//! Generated code is the output of codegen as defined in the `examples/build.rs` file. +//! The generation is the one with .codec_path("crate::common::SmallBufferCodec") +//! The generated code assumes that a module `crate::common` exists which defines +//! `SmallBufferCodec`, and `SmallBufferCodec` must have a Default implementation. + +pub mod common; + +pub mod small_buf { + include!(concat!(env!("OUT_DIR"), "/smallbuf/helloworld.rs")); +} +use small_buf::greeter_client::GreeterClient; + +use crate::small_buf::HelloRequest; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let mut client = GreeterClient::connect("http://[::1]:50051").await?; + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = client.say_hello(request).await?; + + println!("RESPONSE={:?}", response); + + Ok(()) +} diff --git a/examples/src/codec_buffers/common.rs b/examples/src/codec_buffers/common.rs new file mode 100644 index 000000000..c8f9ed777 --- /dev/null +++ b/examples/src/codec_buffers/common.rs @@ -0,0 +1,41 @@ +//! This module defines a common encoder with small buffers. This is useful +//! when you have many concurrent RPC's, and not a huge volume of data per +//! rpc normally. +//! +//! Note that you can customize your codecs per call to the code generator's +//! compile function. This lets you group services by their codec needs. +//! +//! While this codec demonstrates customizing the built-in Prost codec, you +//! can use this to implement other codecs as well, as long as they have a +//! Default implementation. + +use std::marker::PhantomData; + +use prost::Message; +use tonic::codec::{BufferSettings, Codec, ProstCodec}; + +#[derive(Debug, Clone, Copy, Default)] +pub struct SmallBufferCodec(PhantomData<(T, U)>); + +impl Codec for SmallBufferCodec +where + T: Message + Send + 'static, + U: Message + Default + Send + 'static, +{ + type Encode = T; + type Decode = U; + + type Encoder = as Codec>::Encoder; + type Decoder = as Codec>::Decoder; + + fn encoder(&mut self) -> Self::Encoder { + // Here, we will just customize the prost codec's internal buffer settings. + // You can of course implement a complete Codec, Encoder, and Decoder if + // you wish! + ProstCodec::::raw_encoder(BufferSettings::new(512, 4096)) + } + + fn decoder(&mut self) -> Self::Decoder { + ProstCodec::::raw_decoder(BufferSettings::new(512, 4096)) + } +} diff --git a/examples/src/codec_buffers/server.rs b/examples/src/codec_buffers/server.rs new file mode 100644 index 000000000..b30c797d3 --- /dev/null +++ b/examples/src/codec_buffers/server.rs @@ -0,0 +1,51 @@ +//! A HelloWorld example that uses a custom codec instead of the default Prost codec. +//! +//! Generated code is the output of codegen as defined in the `examples/build.rs` file. +//! The generation is the one with .codec_path("crate::common::SmallBufferCodec") +//! The generated code assumes that a module `crate::common` exists which defines +//! `SmallBufferCodec`, and `SmallBufferCodec` must have a Default implementation. + +use tonic::{transport::Server, Request, Response, Status}; + +pub mod common; + +pub mod small_buf { + include!(concat!(env!("OUT_DIR"), "/smallbuf/helloworld.rs")); +} +use small_buf::{ + greeter_server::{Greeter, GreeterServer}, + HelloReply, HelloRequest, +}; + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + println!("Got a request from {:?}", request.remote_addr()); + + let reply = HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "[::1]:50051".parse().unwrap(); + let greeter = MyGreeter::default(); + + println!("GreeterServer listening on {}", addr); + + Server::builder() + .add_service(GreeterServer::new(greeter)) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/tonic-build/src/compile_settings.rs b/tonic-build/src/compile_settings.rs new file mode 100644 index 000000000..e2de97910 --- /dev/null +++ b/tonic-build/src/compile_settings.rs @@ -0,0 +1,14 @@ +#[derive(Debug, Clone)] +pub(crate) struct CompileSettings { + #[cfg(feature = "prost")] + pub(crate) codec_path: String, +} + +impl Default for CompileSettings { + fn default() -> Self { + Self { + #[cfg(feature = "prost")] + codec_path: "tonic::codec::ProstCodec".to_string(), + } + } +} diff --git a/tonic-build/src/lib.rs b/tonic-build/src/lib.rs index 015796508..afc9e0ced 100644 --- a/tonic-build/src/lib.rs +++ b/tonic-build/src/lib.rs @@ -97,6 +97,8 @@ pub mod server; mod code_gen; pub use code_gen::CodeGenBuilder; +mod compile_settings; + /// Service generation trait. /// /// This trait can be implemented and consumed diff --git a/tonic-build/src/prost.rs b/tonic-build/src/prost.rs index dafbe9814..1c76e103b 100644 --- a/tonic-build/src/prost.rs +++ b/tonic-build/src/prost.rs @@ -1,4 +1,4 @@ -use crate::code_gen::CodeGenBuilder; +use crate::{code_gen::CodeGenBuilder, compile_settings::CompileSettings}; use super::Attributes; use proc_macro2::TokenStream; @@ -41,6 +41,7 @@ pub fn configure() -> Builder { disable_comments: HashSet::default(), use_arc_self: false, generate_default_stubs: false, + compile_settings: CompileSettings::default(), } } @@ -61,61 +62,98 @@ pub fn compile_protos(proto: impl AsRef) -> io::Result<()> { Ok(()) } -const PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec"; - /// Non-path Rust types allowed for request/response types. const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"]; -impl crate::Service for Service { - type Method = Method; +/// Newtype wrapper for prost to add tonic-specific extensions +struct TonicBuildService { + prost_service: Service, + methods: Vec, +} + +impl TonicBuildService { + fn new(prost_service: Service, settings: CompileSettings) -> Self { + Self { + // CompileSettings are currently only consumed method-by-method but if you need them in the Service, here's your spot. + // The tonic_build::Service trait specifies that methods are borrowed, so they have to reified up front. + methods: prost_service + .methods + .iter() + .map(|prost_method| TonicBuildMethod { + prost_method: prost_method.clone(), + settings: settings.clone(), + }) + .collect(), + prost_service, + } + } +} + +/// Newtype wrapper for prost to add tonic-specific extensions +struct TonicBuildMethod { + prost_method: Method, + settings: CompileSettings, +} + +impl crate::Service for TonicBuildService { + type Method = TonicBuildMethod; type Comment = String; fn name(&self) -> &str { - &self.name + &self.prost_service.name } fn package(&self) -> &str { - &self.package + &self.prost_service.package } fn identifier(&self) -> &str { - &self.proto_name + &self.prost_service.proto_name } fn comment(&self) -> &[Self::Comment] { - &self.comments.leading[..] + &self.prost_service.comments.leading[..] } fn methods(&self) -> &[Self::Method] { - &self.methods[..] + &self.methods } } -impl crate::Method for Method { +impl crate::Method for TonicBuildMethod { type Comment = String; fn name(&self) -> &str { - &self.name + &self.prost_method.name } fn identifier(&self) -> &str { - &self.proto_name + &self.prost_method.proto_name } + /// For code generation, you can override the codec. + /// + /// You should set the codec path to an import path that has a free + /// function like `fn default()`. The default value is tonic::codec::ProstCodec, + /// which returns a default-configured ProstCodec. You may wish to configure + /// the codec, e.g., with a buffer configuration. + /// + /// Though ProstCodec implements Default, it is currently only required that + /// the function match the Default trait's function spec. fn codec_path(&self) -> &str { - PROST_CODEC_PATH + &self.settings.codec_path } fn client_streaming(&self) -> bool { - self.client_streaming + self.prost_method.client_streaming } fn server_streaming(&self) -> bool { - self.server_streaming + self.prost_method.server_streaming } fn comment(&self) -> &[Self::Comment] { - &self.comments.leading[..] + &self.prost_method.comments.leading[..] } fn request_response_name( @@ -140,8 +178,14 @@ impl crate::Method for Method { } }; - let request = convert_type(&self.input_proto_type, &self.input_type); - let response = convert_type(&self.output_proto_type, &self.output_type); + let request = convert_type( + &self.prost_method.input_proto_type, + &self.prost_method.input_type, + ); + let response = convert_type( + &self.prost_method.output_proto_type, + &self.prost_method.output_type, + ); (request, response) } } @@ -176,7 +220,10 @@ impl prost_build::ServiceGenerator for ServiceGenerator { .disable_comments(self.builder.disable_comments.clone()) .use_arc_self(self.builder.use_arc_self) .generate_default_stubs(self.builder.generate_default_stubs) - .generate_server(&service, &self.builder.proto_path); + .generate_server( + &TonicBuildService::new(service.clone(), self.builder.compile_settings.clone()), + &self.builder.proto_path, + ); self.servers.extend(server); } @@ -188,7 +235,10 @@ impl prost_build::ServiceGenerator for ServiceGenerator { .attributes(self.builder.client_attributes.clone()) .disable_comments(self.builder.disable_comments.clone()) .build_transport(self.builder.build_transport) - .generate_client(&service, &self.builder.proto_path); + .generate_client( + &TonicBuildService::new(service, self.builder.compile_settings.clone()), + &self.builder.proto_path, + ); self.clients.extend(client); } @@ -252,6 +302,7 @@ pub struct Builder { pub(crate) disable_comments: HashSet, pub(crate) use_arc_self: bool, pub(crate) generate_default_stubs: bool, + pub(crate) compile_settings: CompileSettings, out_dir: Option, } @@ -524,6 +575,16 @@ impl Builder { self } + /// Override the default codec. + /// + /// If set, writes `{codec_path}::default()` in generated code wherever a codec is created. + /// + /// This defaults to `"tonic::codec::ProstCodec"` + pub fn codec_path(mut self, codec_path: impl Into) -> Self { + self.compile_settings.codec_path = codec_path.into(); + self + } + /// Compile the .proto files and execute code generation. pub fn compile( self, diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 70d758415..e00b8ca8f 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,4 +1,3 @@ -use super::encode::BUFFER_SIZE; use crate::{metadata::MetadataValue, Status}; use bytes::{Buf, BytesMut}; #[cfg(feature = "gzip")] @@ -70,6 +69,14 @@ impl EnabledCompressionEncodings { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) struct CompressionSettings { + pub(crate) encoding: CompressionEncoding, + /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste. + /// The default buffer growth interval is 8 kilobytes. + pub(crate) buffer_growth_interval: usize, +} + /// The compression encodings Tonic supports. #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] @@ -195,20 +202,22 @@ fn split_by_comma(s: &str) -> impl Iterator { } /// Compress `len` bytes from `decompressed_buf` into `out_buf`. +/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it. #[allow(unused_variables, unreachable_code)] pub(crate) fn compress( - encoding: CompressionEncoding, + settings: CompressionSettings, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { - let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; + let buffer_growth_interval = settings.buffer_growth_interval; + let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval; out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); - match encoding { + match settings.encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_encoder = GzEncoder::new( @@ -237,19 +246,21 @@ pub(crate) fn compress( /// Decompress `len` bytes from `compressed_buf` into `out_buf`. #[allow(unused_variables, unreachable_code)] pub(crate) fn decompress( - encoding: CompressionEncoding, + settings: CompressionSettings, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { + let buffer_growth_interval = settings.buffer_growth_interval; let estimate_decompressed_len = len * 2; - let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; + let capacity = + ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval; out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); - match encoding { + match settings.encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index cb88a0649..081f6193d 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,5 +1,5 @@ -use super::compression::{decompress, CompressionEncoding}; -use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; +use super::compression::{decompress, CompressionEncoding, CompressionSettings}; +use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http::StatusCode; @@ -13,8 +13,6 @@ use std::{ use tokio_stream::Stream; use tracing::{debug, trace}; -const BUFFER_SIZE: usize = 8 * 1024; - /// Streaming requests and responses. /// /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface @@ -118,6 +116,7 @@ impl Streaming { B::Error: Into, D: Decoder + Send + 'static, { + let buffer_size = decoder.buffer_settings().buffer_size; Self { decoder: Box::new(decoder), inner: StreamingInner { @@ -127,7 +126,7 @@ impl Streaming { .boxed_unsync(), state: State::ReadHeader, direction, - buf: BytesMut::with_capacity(BUFFER_SIZE), + buf: BytesMut::with_capacity(buffer_size), trailers: None, decompress_buf: BytesMut::new(), encoding, @@ -138,7 +137,10 @@ impl Streaming { } impl StreamingInner { - fn decode_chunk(&mut self) -> Result>, Status> { + fn decode_chunk( + &mut self, + buffer_settings: BufferSettings, + ) -> Result>, Status> { if let State::ReadHeader = self.state { if self.buf.remaining() < HEADER_SIZE { return Ok(None); @@ -205,8 +207,15 @@ impl StreamingInner { let decode_buf = if let Some(encoding) = compression { self.decompress_buf.clear(); - if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) - { + if let Err(err) = decompress( + CompressionSettings { + encoding, + buffer_growth_interval: buffer_settings.buffer_size, + }, + &mut self.buf, + &mut self.decompress_buf, + len, + ) { let message = if let Direction::Response(status) = self.direction { format!( "Error decompressing: {}, while receiving response with status: {}", @@ -364,7 +373,7 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - match self.inner.decode_chunk()? { + match self.inner.decode_chunk(self.decoder.buffer_settings())? { Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { Some(msg) => { self.inner.state = State::ReadHeader; diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 13eb2c96d..396f77399 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,5 +1,7 @@ -use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; -use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; +use super::compression::{ + compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride, +}; +use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderMap; @@ -11,9 +13,6 @@ use std::{ }; use tokio_stream::{Stream, StreamExt}; -pub(super) const BUFFER_SIZE: usize = 8 * 1024; -const YIELD_THRESHOLD: usize = 32 * 1024; - pub(crate) fn encode_server( encoder: T, source: U, @@ -90,7 +89,8 @@ where compression_override: SingleMessageCompressionOverride, max_message_size: Option, ) -> Self { - let buf = BytesMut::with_capacity(BUFFER_SIZE); + let buffer_settings = encoder.buffer_settings(); + let buf = BytesMut::with_capacity(buffer_settings.buffer_size); let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable { @@ -100,7 +100,7 @@ where }; let uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) + BytesMut::with_capacity(buffer_settings.buffer_size) } else { BytesMut::new() }; @@ -132,6 +132,7 @@ where buf, uncompression_buf, } = self.project(); + let buffer_settings = encoder.buffer_settings(); loop { match source.as_mut().poll_next(cx) { @@ -151,12 +152,13 @@ where uncompression_buf, *compression_encoding, *max_message_size, + buffer_settings, item, ) { return Poll::Ready(Some(Err(status))); } - if buf.len() >= YIELD_THRESHOLD { + if buf.len() >= buffer_settings.yield_threshold { return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); } } @@ -174,6 +176,7 @@ fn encode_item( uncompression_buf: &mut BytesMut, compression_encoding: Option, max_message_size: Option, + buffer_settings: BufferSettings, item: T::Item, ) -> Result<(), Status> where @@ -195,8 +198,16 @@ where let uncompressed_len = uncompression_buf.len(); - compress(encoding, uncompression_buf, buf, uncompressed_len) - .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + compress( + CompressionSettings { + encoding, + buffer_growth_interval: buffer_settings.buffer_size, + }, + uncompression_buf, + buf, + uncompressed_len, + ) + .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; } else { encoder .encode(item, &mut EncodeBuf::new(buf)) diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 306621329..998899a6c 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -22,6 +22,75 @@ pub use self::decode::Streaming; #[cfg_attr(docsrs, doc(cfg(feature = "prost")))] pub use self::prost::ProstCodec; +/// Unless overridden, this is the buffer size used for encoding requests. +/// This is spent per-rpc, so you may wish to adjust it. The default is +/// pretty good for most uses, but if you have a ton of concurrent rpcs +/// you may find it too expensive. +const DEFAULT_CODEC_BUFFER_SIZE: usize = 8 * 1024; +const DEFAULT_YIELD_THRESHOLD: usize = 32 * 1024; + +/// Settings for how tonic allocates and grows buffers. +/// +/// Tonic eagerly allocates the buffer_size per RPC, and grows +/// the buffer by buffer_size increments to handle larger messages. +/// Buffer size defaults to 8KiB. +/// +/// Example: +/// ```ignore +/// Buffer start: | 8kb | +/// Message received: | 24612 bytes | +/// Buffer grows: | 8kb | 8kb | 8kb | 8kb | +/// ``` +/// +/// The buffer grows to the next largest buffer_size increment of +/// 32768 to hold 24612 bytes, which is just slightly too large for +/// the previous buffer increment of 24576. +/// +/// If you use a smaller buffer size you will waste less memory, but +/// you will allocate more frequently. If one way or the other matters +/// more to you, you may wish to customize your tonic Codec (see +/// codec_buffers example). +/// +/// Yield threshold is an optimization for streaming rpcs. Sometimes +/// you may have many small messages ready to send. When they are ready, +/// it is a much more efficient use of system resources to batch them +/// together into one larger send(). The yield threshold controls how +/// much you want to bulk up such a batch of ready-to-send messages. +/// The larger your yield threshold the more you will batch - and +/// consequentially allocate contiguous memory, which might be relevant +/// if you're considering large numbers here. +/// If your server streaming rpc does not reach the yield threshold +/// before it reaches Poll::Pending (meaning, it's waiting for more +/// data from wherever you're streaming from) then Tonic will just send +/// along a smaller batch. Yield threshold is an upper-bound, it will +/// not affect the responsiveness of your streaming rpc (for reasonable +/// sizes of yield threshold). +/// Yield threshold defaults to 32 KiB. +#[derive(Clone, Copy, Debug)] +pub struct BufferSettings { + buffer_size: usize, + yield_threshold: usize, +} + +impl BufferSettings { + /// Create a new `BufferSettings` + pub fn new(buffer_size: usize, yield_threshold: usize) -> Self { + Self { + buffer_size, + yield_threshold, + } + } +} + +impl Default for BufferSettings { + fn default() -> Self { + Self { + buffer_size: DEFAULT_CODEC_BUFFER_SIZE, + yield_threshold: DEFAULT_YIELD_THRESHOLD, + } + } +} + // 5 bytes const HEADER_SIZE: usize = // compression flag @@ -63,6 +132,11 @@ pub trait Encoder { /// Encodes a message into the provided buffer. fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error>; + + /// Controls how tonic creates and expands encode buffers. + fn buffer_settings(&self) -> BufferSettings { + BufferSettings::default() + } } /// Decodes gRPC message types @@ -79,4 +153,9 @@ pub trait Decoder { /// is no need to get the length from the bytes, gRPC framing is handled /// for you. fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result, Self::Error>; + + /// Controls how tonic creates and expands decode buffers. + fn buffer_settings(&self) -> BufferSettings { + BufferSettings::default() + } } diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index aa872a9ba..217934e9e 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -1,4 +1,4 @@ -use super::{Codec, DecodeBuf, Decoder, Encoder}; +use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder}; use crate::codec::EncodeBuf; use crate::{Code, Status}; use prost::Message; @@ -10,9 +10,41 @@ pub struct ProstCodec { _pd: PhantomData<(T, U)>, } +impl ProstCodec { + /// Configure a ProstCodec with encoder/decoder buffer settings. This is used to control + /// how memory is allocated and grows per RPC. + pub fn new() -> Self { + Self { _pd: PhantomData } + } +} + impl Default for ProstCodec { fn default() -> Self { - Self { _pd: PhantomData } + Self::new() + } +} + +impl ProstCodec +where + T: Message + Send + 'static, + U: Message + Default + Send + 'static, +{ + /// A tool for building custom codecs based on prost encoding and decoding. + /// See the codec_buffers example for one possible way to use this. + pub fn raw_encoder(buffer_settings: BufferSettings) -> ::Encoder { + ProstEncoder { + _pd: PhantomData, + buffer_settings, + } + } + + /// A tool for building custom codecs based on prost encoding and decoding. + /// See the codec_buffers example for one possible way to use this. + pub fn raw_decoder(buffer_settings: BufferSettings) -> ::Decoder { + ProstDecoder { + _pd: PhantomData, + buffer_settings, + } } } @@ -28,17 +60,36 @@ where type Decoder = ProstDecoder; fn encoder(&mut self) -> Self::Encoder { - ProstEncoder(PhantomData) + ProstEncoder { + _pd: PhantomData, + buffer_settings: BufferSettings::default(), + } } fn decoder(&mut self) -> Self::Decoder { - ProstDecoder(PhantomData) + ProstDecoder { + _pd: PhantomData, + buffer_settings: BufferSettings::default(), + } } } /// A [`Encoder`] that knows how to encode `T`. #[derive(Debug, Clone, Default)] -pub struct ProstEncoder(PhantomData); +pub struct ProstEncoder { + _pd: PhantomData, + buffer_settings: BufferSettings, +} + +impl ProstEncoder { + /// Get a new encoder with explicit buffer settings + pub fn new(buffer_settings: BufferSettings) -> Self { + Self { + _pd: PhantomData, + buffer_settings, + } + } +} impl Encoder for ProstEncoder { type Item = T; @@ -50,11 +101,28 @@ impl Encoder for ProstEncoder { Ok(()) } + + fn buffer_settings(&self) -> BufferSettings { + self.buffer_settings + } } /// A [`Decoder`] that knows how to decode `U`. #[derive(Debug, Clone, Default)] -pub struct ProstDecoder(PhantomData); +pub struct ProstDecoder { + _pd: PhantomData, + buffer_settings: BufferSettings, +} + +impl ProstDecoder { + /// Get a new decoder with explicit buffer settings + pub fn new(buffer_settings: BufferSettings) -> Self { + Self { + _pd: PhantomData, + buffer_settings, + } + } +} impl Decoder for ProstDecoder { type Item = U; @@ -67,6 +135,10 @@ impl Decoder for ProstDecoder { Ok(item) } + + fn buffer_settings(&self) -> BufferSettings { + self.buffer_settings + } } fn from_decode_error(error: prost::DecodeError) -> crate::Status { @@ -244,6 +316,10 @@ mod tests { buf.put(&item[..]); Ok(()) } + + fn buffer_settings(&self) -> crate::codec::BufferSettings { + Default::default() + } } #[derive(Debug, Clone, Default)] @@ -258,6 +334,10 @@ mod tests { buf.advance(LEN); Ok(Some(out)) } + + fn buffer_settings(&self) -> crate::codec::BufferSettings { + Default::default() + } } mod body {