Skip to content

Commit

Permalink
feat: derive Copy trait for messages where possible
Browse files Browse the repository at this point in the history
Rust primitive types can be copied by simply copying the bits. Rust structs can also have this property by deriving the Copy trait.

Automatically derive Copy for:
- messages that only have fields with primitive types
- the Rust enum for one-of fields
- messages whose field type are messages that also implement Copy

Generated code for Protobuf enums already derives Copy.
  • Loading branch information
caspermeijn committed Feb 16, 2024
1 parent 94a3ecf commit 5952163
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 14 deletions.
15 changes: 13 additions & 2 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,12 @@ impl<'a> CodeGenerator<'a> {
self.buf
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
self.buf.push_str(&format!(
"#[derive(Clone, PartialEq, {}::Message)]\n",
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
if self.message_graph.can_message_derive_copy(&fq_message_name) {
"Copy, "
} else {
""
},
self.config.prost_path.as_deref().unwrap_or("::prost")
));
self.append_skip_debug(&fq_message_name);
Expand Down Expand Up @@ -597,8 +602,14 @@ impl<'a> CodeGenerator<'a> {
self.push_indent();
self.buf
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");

let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| {
self.message_graph
.can_field_derive_copy(fq_message_name, field)
});
self.buf.push_str(&format!(
"#[derive(Clone, PartialEq, {}::Oneof)]\n",
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
if can_oneof_derive_copy { "Copy, " } else { "" },
self.config.prost_path.as_deref().unwrap_or("::prost")
));
self.append_skip_debug(&fq_message_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ pub struct Foo {
pub foo: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag="1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Qux {
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ pub struct Foo {
pub foo: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Qux {}
59 changes: 55 additions & 4 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;

use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};
use prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
};

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
messages: HashMap<String, DescriptorProto>,
}

impl MessageGraph {
Expand All @@ -21,6 +25,7 @@ impl MessageGraph {
let mut msg_graph = MessageGraph {
index: HashMap::new(),
graph: Graph::new(),
messages: HashMap::new(),
};

for file in files {
Expand All @@ -41,6 +46,7 @@ impl MessageGraph {
let MessageGraph {
ref mut index,
ref mut graph,
..
} = *self;
assert_eq!(b'.', msg_name.as_bytes()[0]);
*index
Expand All @@ -58,13 +64,12 @@ impl MessageGraph {
let msg_index = self.get_or_insert_index(msg_name.clone());

for field in &msg.field {
if field.r#type() == field_descriptor_proto::Type::Message
&& field.label() != field_descriptor_proto::Label::Repeated
{
if field.r#type() == Type::Message && field.label() != Label::Repeated {
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
}
}
self.messages.insert(msg_name.clone(), msg.clone());

for msg in &msg.nested_type {
self.add_message(&msg_name, msg);
Expand All @@ -84,4 +89,50 @@ impl MessageGraph {

has_path_connecting(&self.graph, outer, inner, None)
}

/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
let msg = self.messages.get(fq_message_name).unwrap();
msg.field
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Copy trait.
pub fn can_field_derive_copy(
&self,
fq_message_name: &str,
field: &FieldDescriptorProto,
) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.label() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
if self.is_nested(field.type_name(), fq_message_name) {
false
} else {
self.can_message_derive_copy(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Float
| Type::Double
| Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
)
}
}
}
8 changes: 4 additions & 4 deletions prost-types/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub mod descriptor_proto {
/// fields or extension ranges in the same message. Reserved ranges may
/// not overlap.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct ReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -359,7 +359,7 @@ pub mod enum_descriptor_proto {
/// is inclusive such that it can appropriately represent the entire int32
/// domain.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct EnumReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -1852,7 +1852,7 @@ pub struct Mixin {
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
/// microsecond should be expressed in JSON format as "3.000001s".
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Duration {
/// Signed seconds of the span of time. Must be from -315,576,000,000
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
Expand Down Expand Up @@ -2292,7 +2292,7 @@ impl NullValue {
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Timestamp {
/// Represents seconds of UTC time since Unix epoch
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
Expand Down
4 changes: 4 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ fn main() {
.compile_protos(&[src.join("deprecated_field.proto")], includes)
.unwrap();

config
.compile_protos(&[src.join("derive_copy.proto")], includes)
.unwrap();

config
.compile_protos(&[src.join("default_string_escape.proto")], includes)
.unwrap();
Expand Down
51 changes: 51 additions & 0 deletions tests/src/derive_copy.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
syntax = "proto3";

import "google/protobuf/timestamp.proto";

package derive_copy;

message EmptyMsg {}

message IntegerMsg {
int32 field1 = 1;
int64 field2 = 2;
uint32 field3 = 3;
uint64 field4 = 4;
sint32 field5 = 5;
sint64 field6 = 6;
fixed32 field7 = 7;
fixed64 field8 = 8;
sfixed32 field9 = 9;
sfixed64 field10 = 10;
}

message FloatMsg {
double field1 = 1;
float field2 = 2;
}

message BoolMsg { bool field1 = 1; }

enum AnEnum {
A = 0;
B = 1;
};

message EnumMsg { AnEnum field1 = 1; }

message OneOfMsg {
oneof data {
int32 field1 = 1;
int64 field2 = 2;
}
}

message ComposedMsg {
IntegerMsg field1 = 1;
EnumMsg field2 = 2;
OneOfMsg field3 = 3;
}

message WellKnownMsg {
google.protobuf.Timestamp timestamp = 1;
}
21 changes: 21 additions & 0 deletions tests/src/derive_copy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
include!(concat!(env!("OUT_DIR"), "/derive_copy.rs"));

trait TestCopyIsImplemented: Copy {}

impl TestCopyIsImplemented for EmptyMsg {}

impl TestCopyIsImplemented for IntegerMsg {}

impl TestCopyIsImplemented for FloatMsg {}

impl TestCopyIsImplemented for BoolMsg {}

impl TestCopyIsImplemented for AnEnum {}

impl TestCopyIsImplemented for EnumMsg {}

impl TestCopyIsImplemented for OneOfMsg {}

impl TestCopyIsImplemented for ComposedMsg {}

impl TestCopyIsImplemented for WellKnownMsg {}
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ mod debug;
#[cfg(test)]
mod deprecated_field;
#[cfg(test)]
mod derive_copy;
#[cfg(test)]
mod generic_derive;
#[cfg(test)]
mod message_encoding;
Expand Down

0 comments on commit 5952163

Please sign in to comment.