Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: derive Copy trait for messages where possible #950

Merged
merged 2 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 13 additions & 2 deletions prost-build/src/code_generator.rs
Expand Up @@ -190,7 +190,12 @@ impl<'a> CodeGenerator<'a> {
self.buf
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
self.buf.push_str(&format!(
"#[derive(Clone, PartialEq, {}::Message)]\n",
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
if self.message_graph.can_message_derive_copy(&fq_message_name) {
"Copy, "
} else {
""
},
prost_path(self.config)
));
self.append_skip_debug(&fq_message_name);
Expand Down Expand Up @@ -595,8 +600,14 @@ impl<'a> CodeGenerator<'a> {
self.push_indent();
self.buf
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");

let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| {
self.message_graph
.can_field_derive_copy(fq_message_name, field)
});
self.buf.push_str(&format!(
"#[derive(Clone, PartialEq, {}::Oneof)]\n",
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
if can_oneof_derive_copy { "Copy, " } else { "" },
prost_path(self.config)
));
self.append_skip_debug(fq_message_name);
Expand Down
Expand Up @@ -23,12 +23,12 @@ pub struct Foo {
pub foo: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag="1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Qux {
}
Expand Up @@ -23,11 +23,11 @@ pub struct Foo {
pub foo: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Qux {}
59 changes: 55 additions & 4 deletions prost-build/src/message_graph.rs
Expand Up @@ -4,14 +4,18 @@ use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;

use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};
use prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
};

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
messages: HashMap<String, DescriptorProto>,
}

impl MessageGraph {
Expand All @@ -21,6 +25,7 @@ impl MessageGraph {
let mut msg_graph = MessageGraph {
index: HashMap::new(),
graph: Graph::new(),
messages: HashMap::new(),
};

for file in files {
Expand All @@ -41,6 +46,7 @@ impl MessageGraph {
let MessageGraph {
ref mut index,
ref mut graph,
..
} = *self;
assert_eq!(b'.', msg_name.as_bytes()[0]);
*index
Expand All @@ -58,13 +64,12 @@ impl MessageGraph {
let msg_index = self.get_or_insert_index(msg_name.clone());

for field in &msg.field {
if field.r#type() == field_descriptor_proto::Type::Message
&& field.label() != field_descriptor_proto::Label::Repeated
{
if field.r#type() == Type::Message && field.label() != Label::Repeated {
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
}
}
self.messages.insert(msg_name.clone(), msg.clone());

for msg in &msg.nested_type {
self.add_message(&msg_name, msg);
Expand All @@ -84,4 +89,50 @@ impl MessageGraph {

has_path_connecting(&self.graph, outer, inner, None)
}

/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
let msg = self.messages.get(fq_message_name).unwrap();
msg.field
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Copy trait.
pub fn can_field_derive_copy(
&self,
fq_message_name: &str,
field: &FieldDescriptorProto,
) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.label() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
if self.is_nested(field.type_name(), fq_message_name) {
false
} else {
self.can_message_derive_copy(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Float
| Type::Double
| Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
)
}
}
}
2 changes: 1 addition & 1 deletion prost-types/src/datetime.rs
Expand Up @@ -614,7 +614,7 @@ mod tests {
};
assert_eq!(
expected,
format!("{}", DateTime::from(timestamp.clone())),
format!("{}", DateTime::from(timestamp)),
"timestamp: {:?}",
timestamp
);
Expand Down
6 changes: 3 additions & 3 deletions prost-types/src/duration.rs
Expand Up @@ -105,7 +105,7 @@ impl TryFrom<Duration> for time::Duration {

impl fmt::Display for Duration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut d = self.clone();
let mut d = *self;
d.normalize();
if self.seconds < 0 && self.nanos < 0 {
write!(f, "-")?;
Expand Down Expand Up @@ -193,7 +193,7 @@ mod tests {
Ok(duration) => duration,
Err(_) => return Err(TestCaseError::reject("duration out of range")),
};
prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration);
prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration);

if std_duration != time::Duration::default() {
let neg_prost_duration = Duration {
Expand All @@ -220,7 +220,7 @@ mod tests {
Ok(duration) => duration,
Err(_) => return Err(TestCaseError::reject("duration out of range")),
};
prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration);
prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration);

if std_duration != time::Duration::default() {
let neg_prost_duration = Duration {
Expand Down
8 changes: 4 additions & 4 deletions prost-types/src/protobuf.rs
Expand Up @@ -94,7 +94,7 @@ pub mod descriptor_proto {
/// fields or extension ranges in the same message. Reserved ranges may
/// not overlap.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct ReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -360,7 +360,7 @@ pub mod enum_descriptor_proto {
/// is inclusive such that it can appropriately represent the entire int32
/// domain.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct EnumReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -1853,7 +1853,7 @@ pub struct Mixin {
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
/// microsecond should be expressed in JSON format as "3.000001s".
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Duration {
/// Signed seconds of the span of time. Must be from -315,576,000,000
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
Expand Down Expand Up @@ -2293,7 +2293,7 @@ impl NullValue {
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Timestamp {
/// Represents seconds of UTC time since Unix epoch
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
Expand Down
11 changes: 5 additions & 6 deletions prost-types/src/timestamp.rs
Expand Up @@ -50,7 +50,7 @@ impl Timestamp {
///
/// [1]: https://github.com/google/protobuf/blob/v3.3.2/src/google/protobuf/util/time_util.cc#L59-L77
pub fn try_normalize(mut self) -> Result<Timestamp, Timestamp> {
let before = self.clone();
let before = self;
self.normalize();
// If the seconds value has changed, and is either i64::MIN or i64::MAX, then the timestamp
// normalization overflowed.
Expand Down Expand Up @@ -201,7 +201,7 @@ impl TryFrom<Timestamp> for std::time::SystemTime {
type Error = TimestampError;

fn try_from(mut timestamp: Timestamp) -> Result<std::time::SystemTime, Self::Error> {
let orig_timestamp = timestamp.clone();
let orig_timestamp = timestamp;
timestamp.normalize();

let system_time = if timestamp.seconds >= 0 {
Expand All @@ -211,8 +211,7 @@ impl TryFrom<Timestamp> for std::time::SystemTime {
timestamp
.seconds
.checked_neg()
.ok_or_else(|| TimestampError::OutOfSystemRange(timestamp.clone()))?
as u64,
.ok_or(TimestampError::OutOfSystemRange(timestamp))? as u64,
))
};

Expand All @@ -234,7 +233,7 @@ impl FromStr for Timestamp {

impl fmt::Display for Timestamp {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
datetime::DateTime::from(self.clone()).fmt(f)
datetime::DateTime::from(*self).fmt(f)
}
}
#[cfg(test)]
Expand Down Expand Up @@ -262,7 +261,7 @@ mod tests {
) {
let mut timestamp = Timestamp { seconds, nanos };
timestamp.normalize();
if let Ok(system_time) = SystemTime::try_from(timestamp.clone()) {
if let Ok(system_time) = SystemTime::try_from(timestamp) {
prop_assert_eq!(Timestamp::from(system_time), timestamp);
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/src/build.rs
Expand Up @@ -91,6 +91,10 @@ fn main() {
.compile_protos(&[src.join("deprecated_field.proto")], includes)
.unwrap();

config
.compile_protos(&[src.join("derive_copy.proto")], includes)
.unwrap();

config
.compile_protos(&[src.join("default_string_escape.proto")], includes)
.unwrap();
Expand Down
51 changes: 51 additions & 0 deletions tests/src/derive_copy.proto
@@ -0,0 +1,51 @@
syntax = "proto3";

import "google/protobuf/timestamp.proto";

package derive_copy;

message EmptyMsg {}

message IntegerMsg {
int32 field1 = 1;
int64 field2 = 2;
uint32 field3 = 3;
uint64 field4 = 4;
sint32 field5 = 5;
sint64 field6 = 6;
fixed32 field7 = 7;
fixed64 field8 = 8;
sfixed32 field9 = 9;
sfixed64 field10 = 10;
}

message FloatMsg {
double field1 = 1;
float field2 = 2;
}

message BoolMsg { bool field1 = 1; }

enum AnEnum {
A = 0;
B = 1;
};

message EnumMsg { AnEnum field1 = 1; }

message OneOfMsg {
oneof data {
int32 field1 = 1;
int64 field2 = 2;
}
}

message ComposedMsg {
IntegerMsg field1 = 1;
EnumMsg field2 = 2;
OneOfMsg field3 = 3;
}

message WellKnownMsg {
google.protobuf.Timestamp timestamp = 1;
}
21 changes: 21 additions & 0 deletions tests/src/derive_copy.rs
@@ -0,0 +1,21 @@
include!(concat!(env!("OUT_DIR"), "/derive_copy.rs"));

trait TestCopyIsImplemented: Copy {}

impl TestCopyIsImplemented for EmptyMsg {}

impl TestCopyIsImplemented for IntegerMsg {}

impl TestCopyIsImplemented for FloatMsg {}

impl TestCopyIsImplemented for BoolMsg {}

impl TestCopyIsImplemented for AnEnum {}

impl TestCopyIsImplemented for EnumMsg {}

impl TestCopyIsImplemented for OneOfMsg {}

impl TestCopyIsImplemented for ComposedMsg {}

impl TestCopyIsImplemented for WellKnownMsg {}
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Expand Up @@ -37,6 +37,8 @@ mod debug;
#[cfg(test)]
mod deprecated_field;
#[cfg(test)]
mod derive_copy;
#[cfg(test)]
mod enum_keyword_variant;
#[cfg(test)]
mod generic_derive;
Expand Down