Experiment with using bytes::Bytes to back bytes and string fields #190

Closed · wants to merge 4 commits
Changes from 3 commits
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -37,7 +37,7 @@ default = ["prost-derive"]

[dependencies]
byteorder = "1"
bytes = "0.4.7"
bytes = "0.4"
prost-derive = { version = "0.5.0", path = "prost-derive", optional = true }

[dev-dependencies]
2 changes: 1 addition & 1 deletion README.md
@@ -102,7 +102,7 @@ Scalar value types are converted as follows:
| `sfixed64` | `i64` |
| `bool` | `bool` |
| `string` | `String` |
| `bytes` | `Vec<u8>` |
| `bytes` | `Bytes` |

#### Enumerations

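To make the table change concrete, here is a sketch of what generated code could look like under this experiment, assuming prost-build emits `::prost::BytesString` for `string` fields and `::bytes::Bytes` for `bytes` fields (as the code_generator.rs change further down does). The message and field names are made up for illustration.

```rust
// Illustrative sketch only: a proto message such as
//
//     message Blob {
//         string name   = 1;
//         bytes payload = 2;
//     }
//
// would, under this experiment, be generated with the new backing types.
// `Blob`, `name`, and `payload` are hypothetical names, not from this PR.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Blob {
    #[prost(string, tag = "1")]
    pub name: ::prost::BytesString, // previously ::std::string::String
    #[prost(bytes, tag = "2")]
    pub payload: ::bytes::Bytes, // previously ::std::vec::Vec<u8>
}
```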
5 changes: 3 additions & 2 deletions benches/benchmark.rs
@@ -2,6 +2,7 @@ use std::fs::File;
use std::io::Read;
use std::result;

use bytes::Bytes;
use criterion::{Benchmark, Criterion, Throughput};
use failure::bail;
use prost::Message;
@@ -13,7 +14,7 @@ fn benchmark_dataset<M>(criterion: &mut Criterion, dataset: BenchmarkDataset) ->
where
M: prost::Message + Default + 'static,
{
let payload_len = dataset.payload.iter().map(Vec::len).sum::<usize>();
let payload_len = dataset.payload.iter().map(Bytes::len).sum::<usize>();

let messages = dataset
.payload
@@ -82,7 +83,7 @@ fn main() -> Result {
protobuf::benchmarks::BenchmarkDataset::decode(buf)?
};

match dataset.message_name.as_str() {
match dataset.message_name.as_ref() {
"benchmarks.proto2.GoogleMessage1" => {
benchmark_dataset::<proto2::GoogleMessage1>(&mut criterion, dataset)?
}
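The benchmark needs only small adjustments because the field types change: `Bytes::len` replaces `Vec::len`, and dispatch on the message name goes through `AsRef<str>` since the new string wrapper presumably has no inherent `as_str`. A minimal sketch of that dispatch pattern, under the `AsRef<str>` assumption:

```rust
// Minimal sketch of the `as_str` -> `as_ref` change above, assuming the new
// string wrapper implements `AsRef<str>` (as the code_generator.rs change
// later in this diff relies on) rather than exposing `String::as_str`.
fn dispatch(message_name: impl AsRef<str>) -> &'static str {
    match message_name.as_ref() {
        "benchmarks.proto2.GoogleMessage1" => "GoogleMessage1 (proto2)",
        _ => "unhandled",
    }
}
```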
2 changes: 1 addition & 1 deletion conformance/Cargo.toml
@@ -6,7 +6,7 @@ publish = false
edition = "2018"

[dependencies]
bytes = "0.4.7"
bytes = "0.4"
env_logger = { version = "0.6", default-features = false }
log = "0.3"
prost = { path = ".." }
31 changes: 15 additions & 16 deletions conformance/src/main.rs
@@ -29,7 +29,7 @@ fn main() {

let result = match ConformanceRequest::decode(&bytes) {
Ok(request) => handle_request(request),
Err(error) => conformance_response::Result::ParseError(format!("{:?}", error)),
Err(error) => conformance_response::Result::ParseError(format!("{:?}", error).into()),
};

let mut response = ConformanceResponse::default();
@@ -52,42 +52,42 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result {
match request.requested_output_format() {
WireFormat::Unspecified => {
return conformance_response::Result::ParseError(
"output format unspecified".to_string(),
"output format unspecified".to_string().into(),
);
}
WireFormat::Json => {
return conformance_response::Result::Skipped(
"JSON output is not supported".to_string(),
"JSON output is not supported".to_string().into(),
);
}
WireFormat::Jspb => {
return conformance_response::Result::Skipped(
"JSPB output is not supported".to_string(),
"JSPB output is not supported".to_string().into(),
);
}
WireFormat::TextFormat => {
return conformance_response::Result::Skipped(
"TEXT_FORMAT output is not supported".to_string(),
"TEXT_FORMAT output is not supported".to_string().into(),
);
}
WireFormat::Protobuf => (),
};

let buf = match request.payload {
None => return conformance_response::Result::ParseError("no payload".to_string()),
None => return conformance_response::Result::ParseError("no payload".to_string().into()),
Some(conformance_request::Payload::JsonPayload(_)) => {
return conformance_response::Result::Skipped(
"JSON input is not supported".to_string(),
"JSON input is not supported".to_string().into(),
);
}
Some(conformance_request::Payload::JspbPayload(_)) => {
return conformance_response::Result::Skipped(
"JSON input is not supported".to_string(),
"JSON input is not supported".to_string().into(),
);
}
Some(conformance_request::Payload::TextPayload(_)) => {
return conformance_response::Result::Skipped(
"JSON input is not supported".to_string(),
"JSON input is not supported".to_string().into(),
);
}
Some(conformance_request::Payload::ProtobufPayload(buf)) => buf,
@@ -97,20 +97,19 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result {
"protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip::<TestAllTypesProto2>(&buf),
"protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip::<TestAllTypesProto3>(&buf),
_ => {
return conformance_response::Result::ParseError(format!(
"unknown message type: {}",
request.message_type
));
return conformance_response::Result::ParseError(
format!("unknown message type: {}", request.message_type).into(),
);
}
};

match roundtrip {
RoundtripResult::Ok(buf) => conformance_response::Result::ProtobufPayload(buf),
RoundtripResult::Ok(buf) => conformance_response::Result::ProtobufPayload(buf.into()),
RoundtripResult::DecodeError(error) => {
conformance_response::Result::ParseError(error.to_string())
conformance_response::Result::ParseError(error.to_string().into())
}
RoundtripResult::Error(error) => {
conformance_response::Result::RuntimeError(error.to_string())
conformance_response::Result::RuntimeError(error.to_string().into())
}
}
}
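The repeated `.into()` calls suggest the generated oneof variants now hold the new string type rather than `String`, so plain `String` values must be converted. A tiny sketch assuming `BytesString` implements `From<String>` (the impl itself is not part of this diff):

```rust
// Sketch of the conversion the repeated `.into()` calls above rely on,
// assuming `BytesString: From<String>` and hence `String: Into<BytesString>`.
use prost::BytesString;

fn to_field_value(msg: String) -> BytesString {
    // Same shape as `format!("{:?}", error).into()` in the diff above.
    msg.into()
}
```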
17 changes: 9 additions & 8 deletions prost-build/src/code_generator.rs
@@ -6,6 +6,7 @@ use std::iter;
use itertools::{Either, Itertools};
use log::debug;
use multimap::MultiMap;
use prost::BytesString;
use prost_types::field_descriptor_proto::{Label, Type};
use prost_types::source_code_info::Location;
use prost_types::{
@@ -56,15 +57,15 @@ impl<'a> CodeGenerator<'a> {
.location
.sort_by_key(|location| location.path.clone());

let syntax = match file.syntax.as_ref().map(String::as_str) {
let syntax = match file.syntax.as_ref().map(BytesString::as_ref) {
None | Some("proto2") => Syntax::Proto2,
Some("proto3") => Syntax::Proto3,
Some(s) => panic!("unknown syntax: {}", s),
};

let mut code_gen = CodeGenerator {
config,
package: file.package.unwrap(),
package: file.package.unwrap().to_string(),
source_info,
syntax,
message_graph,
@@ -191,7 +192,7 @@ impl<'a> CodeGenerator<'a> {
match field
.type_name
.as_ref()
.and_then(|type_name| map_types.get(type_name))
.and_then(|type_name| map_types.get::<str>(type_name.as_ref()))
{
Some(&(ref key, ref value)) => {
self.append_map_field(&fq_message_name, field, key, value)
@@ -636,9 +637,9 @@ impl<'a> CodeGenerator<'a> {
let comments = Comments::from_location(self.location());
self.path.pop();

let name = method.name.take().unwrap();
let input_proto_type = method.input_type.take().unwrap();
let output_proto_type = method.output_type.take().unwrap();
let name = method.name.take().unwrap().to_string();
let input_proto_type = method.input_type.take().unwrap().to_string();
let output_proto_type = method.output_type.take().unwrap().to_string();
let input_type = self.resolve_ident(&input_proto_type);
let output_type = self.resolve_ident(&output_proto_type);
let client_streaming = method.client_streaming();
@@ -713,8 +714,8 @@ impl<'a> CodeGenerator<'a> {
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => String::from("std::string::String"),
Type::Bytes => String::from("std::vec::Vec<u8>"),
Type::String => String::from("::prost::BytesString"),
Type::Bytes => String::from("::bytes::Bytes"),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
}
}
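For context, the presumed payoff of switching the emitted types (not demonstrated in this diff) is that decoding from a `bytes::Bytes` input could let string and bytes fields share the input buffer instead of copying into fresh allocations. A rough sketch under that assumption:

```rust
// Rough sketch of the presumed benefit, under the assumption that decoding
// from a `Bytes` input lets Bytes-backed fields hold reference-counted slices
// of the input buffer (the decode path itself is not part of this diff).
use bytes::Bytes;
use prost::Message;

fn decode_shared<M: Message + Default>(encoded: Bytes) -> Result<M, prost::DecodeError> {
    // If the experiment works as intended, string/bytes fields of `M` may be
    // cheap slices of `encoded` rather than copied `Vec<u8>`/`String` values.
    M::decode(encoded)
}
```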
4 changes: 2 additions & 2 deletions prost-build/src/extern_paths.rs
@@ -37,7 +37,7 @@ impl ExternPaths {
extern_paths.insert(".google.protobuf.BoolValue".to_string(), "bool".to_string())?;
extern_paths.insert(
".google.protobuf.BytesValue".to_string(),
"::std::vec::Vec<u8>".to_string(),
"::bytes::Bytes".to_string(),
)?;
extern_paths.insert(
".google.protobuf.DoubleValue".to_string(),
@@ -49,7 +49,7 @@ impl ExternPaths {
extern_paths.insert(".google.protobuf.Int64Value".to_string(), "i64".to_string())?;
extern_paths.insert(
".google.protobuf.StringValue".to_string(),
"::std::string::String".to_string(),
"prost::BytesString".to_string(),
)?;
extern_paths.insert(
".google.protobuf.UInt32Value".to_string(),
16 changes: 9 additions & 7 deletions prost-build/src/message_graph.rs
@@ -4,14 +4,15 @@ use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;

use prost::BytesString;
use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
index: HashMap<BytesString, NodeIndex>,
graph: Graph<BytesString, ()>,
}

impl MessageGraph {
@@ -31,7 +32,7 @@ impl MessageGraph {
msg_graph
}

fn get_or_insert_index(&mut self, msg_name: String) -> NodeIndex {
fn get_or_insert_index(&mut self, msg_name: BytesString) -> NodeIndex {
let MessageGraph {
ref mut index,
ref mut graph,
@@ -49,7 +50,8 @@ impl MessageGraph {
/// Since repeated messages are already put in a Vec, boxing them isn’t necessary even if the reference is recursive.
fn add_message(&mut self, package: &str, msg: &DescriptorProto) {
let msg_name = format!("{}.{}", package, msg.name.as_ref().unwrap());
let msg_index = self.get_or_insert_index(msg_name.clone());
let k = msg_name.clone().into();
let msg_index = self.get_or_insert_index(k);

for field in &msg.field {
if field.r#type() == field_descriptor_proto::Type::Message
@@ -66,12 +68,12 @@ }
}

/// Returns true if message type `inner` is nested in message type `outer`.
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
let outer = match self.index.get(outer) {
pub fn is_nested(&self, _outer: &str, _inner: &str) -> bool {
let outer = match self.index.get(_outer) {
Some(outer) => *outer,
None => return false,
};
let inner = match self.index.get(inner) {
let inner = match self.index.get(_inner) {
Some(inner) => *inner,
None => return false,
};
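The graph index is now keyed by `BytesString` while `is_nested` still receives `&str` names, which only compiles if the key type implements `Borrow<str>` (that impl is not shown here). A standalone sketch of the lookup pattern under that assumption:

```rust
// Standalone sketch of looking up a `&str` in a map keyed by an owned
// string-like wrapper; this requires `K: Borrow<str>`, which the new
// `BytesString` key type is assumed (not shown in this diff) to implement.
use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::Hash;

fn lookup<K>(index: &HashMap<K, usize>, name: &str) -> Option<usize>
where
    K: Borrow<str> + Hash + Eq,
{
    // `HashMap::get` accepts any `&Q` with `K: Borrow<Q>`, so a borrowed
    // `&str` can find keys stored as owned wrappers without allocating.
    index.get(name).copied()
}
```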
55 changes: 35 additions & 20 deletions prost-derive/src/field/scalar.rs
@@ -2,7 +2,7 @@ use std::fmt;

use failure::{bail, format_err, Error};
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{
self, parse_str, FloatSuffix, Ident, IntSuffix, Lit, LitByteStr, Meta, MetaList, MetaNameValue,
NestedMeta, Path,
@@ -124,8 +124,13 @@ impl Field {
match self.kind {
Kind::Plain(ref default) => {
let default = default.typed();
let deref_ident = if self.ty == Ty::String {
quote!(&*#ident)
} else {
ident.clone()
};
quote! {
if #ident != #default {
if #deref_ident != #default {
#encode_fn(#tag, &#ident, buf);
}
}
@@ -177,8 +182,13 @@ impl Field {
match self.kind {
Kind::Plain(ref default) => {
let default = default.typed();
let deref_ident = if self.ty == Ty::String {
quote!(&*#ident)
} else {
ident.clone()
};
quote! {
if #ident != #default {
if #deref_ident != #default {
#encoded_len_fn(#tag, &#ident)
} else {
0
@@ -196,13 +206,13 @@ impl Field {

pub fn clear(&self, ident: TokenStream) -> TokenStream {
match self.kind {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::String | Ty::Bytes => quote!(#ident.clear()),
_ => quote!(#ident = #default),
Kind::Plain(ref default) | Kind::Required(ref default) => match self.ty {
Ty::String | Ty::Bytes => quote!(#ident.clear()),
_ => {
let default = default.typed();
quote!(#ident = #default)
}
}
},
Kind::Optional(_) => quote!(#ident = ::std::option::Option::None),
Kind::Repeated | Kind::Packed => quote!(#ident.clear()),
}
@@ -465,8 +475,8 @@ impl Ty {
// TODO: rename to 'owned_type'.
pub fn rust_type(&self) -> TokenStream {
match *self {
Ty::String => quote!(::std::string::String),
Ty::Bytes => quote!(::std::vec::Vec<u8>),
Ty::String => quote!(::prost::BytesString),
Ty::Bytes => quote!(::bytes::Bytes),
_ => self.rust_ref_type(),
}
}
@@ -746,24 +756,28 @@ impl DefaultValue {
pub fn owned(&self) -> TokenStream {
match *self {
DefaultValue::String(ref value) if value.is_empty() => {
quote!(::std::string::String::new())
quote!(::prost::BytesString::new())
}
DefaultValue::String(ref value) => quote!(#value.to_owned()),
DefaultValue::Bytes(ref value) if value.is_empty() => quote!(::std::vec::Vec::new()),
DefaultValue::String(ref value) if value.is_empty() => {
quote!(::prost::BytesString::new())
}
DefaultValue::String(ref value) => {
quote!(::prost::BytesString::from(#value.to_owned()))
}
DefaultValue::Bytes(ref value) if value.is_empty() => quote!(::bytes::Bytes::new()),
DefaultValue::Bytes(ref value) => {
let lit = LitByteStr::new(value, Span::call_site());
quote!(#lit.to_owned())
quote!(::bytes::Bytes::from(#lit))
}

ref other => other.typed(),
}
}

pub fn typed(&self) -> TokenStream {
if let DefaultValue::Enumeration(_) = *self {
quote!(#self as i32)
} else {
quote!(#self)
match *self {
DefaultValue::Enumeration(_) => quote!(#self as i32),
_ => quote!(#self),
}
}
}
@@ -780,7 +794,8 @@ impl ToTokens for DefaultValue {
DefaultValue::Bool(value) => value.to_tokens(tokens),
DefaultValue::String(ref value) => value.to_tokens(tokens),
DefaultValue::Bytes(ref value) => {
LitByteStr::new(value, Span::call_site()).to_tokens(tokens)
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
}
DefaultValue::Enumeration(ref value) => value.to_tokens(tokens),
DefaultValue::Path(ref value) => value.to_tokens(tokens),
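The new `&*#ident` guard in `encode` and `encoded_len` suggests the string wrapper is compared with its default through a `Deref<Target = str>` coercion rather than directly against a `&str` literal. A stand-in sketch of the generated guard, using a dummy wrapper in place of prost's real type:

```rust
// Stand-in sketch only: `BytesString` here is a dummy wrapper, not prost's
// real type, used to show why `&*field` is needed before comparing against
// the default string literal.
use std::ops::Deref;

struct BytesString(String);

impl Deref for BytesString {
    type Target = str;
    fn deref(&self) -> &str {
        &self.0
    }
}

struct Msg {
    name: BytesString,
}

fn should_encode(msg: &Msg) -> bool {
    // Mirrors the macro's `if &*#ident != #default { ... }` for string fields:
    // `&*` coerces through `Deref` to `&str`, which compares with the literal.
    &*msg.name != ""
}
```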
2 changes: 1 addition & 1 deletion prost-derive/src/lib.rs
@@ -112,7 +112,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
.into_iter()
.map(|tag| quote!(#tag))
.intersperse(quote!(|));
quote!(#(#tags)* => #merge.map_err(|mut error| {
quote!(#(#tags)* => #merge.map_err(|mut error: _prost::DecodeError| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
}),)