Skip to content

Commit

Permalink
prost-build: address comments on reuse of Field
Browse files Browse the repository at this point in the history
Make rust_field into a method computing the name on the fly.
In OneofField, make the vector of fields to have Field members.
Don't play reference renaming tricks with field.descriptor.
  • Loading branch information
mzabaluev committed Apr 28, 2024
1 parent 5333e02 commit 1991a11
Showing 1 changed file with 57 additions and 55 deletions.
112 changes: 57 additions & 55 deletions prost-build/src/code_generator.rs
Expand Up @@ -53,41 +53,41 @@ fn prost_path(config: &Config) -> &str {
}

struct Field {
rust_name: String,
descriptor: FieldDescriptorProto,
path_index: i32,
}

impl Field {
fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self {
Self {
rust_name: to_snake(descriptor.name()),
descriptor,
path_index,
}
}

fn rust_name(&self) -> String {
to_snake(self.descriptor.name())
}
}

struct OneofField {
rust_name: String,
descriptor: OneofDescriptorProto,
fields: Vec<(FieldDescriptorProto, i32)>,
fields: Vec<Field>,
path_index: i32,
}

impl OneofField {
fn new(
descriptor: OneofDescriptorProto,
fields: Vec<(FieldDescriptorProto, i32)>,
path_index: i32,
) -> Self {
fn new(descriptor: OneofDescriptorProto, fields: Vec<Field>, path_index: i32) -> Self {
Self {
rust_name: to_snake(descriptor.name()),
descriptor,
fields,
path_index,
}
}

fn rust_name(&self) -> String {
to_snake(self.descriptor.name())
}
}

impl<'a> CodeGenerator<'a> {
Expand Down Expand Up @@ -205,7 +205,7 @@ impl<'a> CodeGenerator<'a> {

// Split the fields into a vector of the normal fields, and oneof fields.
// Path indexes are preserved so that comments can be retrieved.
type OneofFieldsByIndex = MultiMap<i32, (FieldDescriptorProto, i32)>;
type OneofFieldsByIndex = MultiMap<i32, Field>;
let (fields, mut oneof_map): (Vec<Field>, OneofFieldsByIndex) = message
.field
.into_iter()
Expand All @@ -215,7 +215,7 @@ impl<'a> CodeGenerator<'a> {
if proto.proto3_optional.unwrap_or(false) {
Either::Left(Field::new(proto, idx))
} else if let Some(oneof_index) = proto.oneof_index {
Either::Right((oneof_index, (proto, idx)))
Either::Right((oneof_index, Field::new(proto, idx)))
} else {
Either::Left(Field::new(proto, idx))
}
Expand Down Expand Up @@ -407,33 +407,31 @@ impl<'a> CodeGenerator<'a> {
}

fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let rust_name = &field.rust_name;
let field = &field.descriptor;
let type_ = field.r#type();
let repeated = field.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(field);
let optional = self.optional(field);
let ty = self.resolve_type(field, fq_message_name);
let type_ = field.descriptor.r#type();
let repeated = field.descriptor.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let ty = self.resolve_type(&field.descriptor, fq_message_name);

let boxed = !repeated
&& ((type_ == Type::Message || type_ == Type::Group)
&& self
.message_graph
.is_nested(field.type_name(), fq_message_name))
.is_nested(field.descriptor.type_name(), fq_message_name))
|| (self
.config
.boxed
.get_first_field(fq_message_name, field.name())
.get_first_field(fq_message_name, field.descriptor.name())
.is_some());

debug!(
" field: {:?}, type: {:?}, boxed: {}",
field.name(),
field.descriptor.name(),
ty,
boxed
);

self.append_doc(fq_message_name, Some(field.name()));
self.append_doc(fq_message_name, Some(field.descriptor.name()));

if deprecated {
self.push_indent();
Expand All @@ -442,21 +440,21 @@ impl<'a> CodeGenerator<'a> {

self.push_indent();
self.buf.push_str("#[prost(");
let type_tag = self.field_type_tag(field);
let type_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
let bytes_type = self
.config
.bytes_type
.get_first_field(fq_message_name, field.name())
.get_first_field(fq_message_name, field.descriptor.name())
.copied()
.unwrap_or_default();
self.buf
.push_str(&format!("={:?}", bytes_type.annotation()));
}

match field.label() {
match field.descriptor.label() {
Label::Optional => {
if optional {
self.buf.push_str(", optional");
Expand All @@ -465,8 +463,9 @@ impl<'a> CodeGenerator<'a> {
Label::Required => self.buf.push_str(", required"),
Label::Repeated => {
self.buf.push_str(", repeated");
if can_pack(field)
if can_pack(&field.descriptor)
&& !field
.descriptor
.options
.as_ref()
.map_or(self.syntax == Syntax::Proto3, |options| options.packed())
Expand All @@ -480,9 +479,9 @@ impl<'a> CodeGenerator<'a> {
self.buf.push_str(", boxed");
}
self.buf.push_str(", tag=\"");
self.buf.push_str(&field.number().to_string());
self.buf.push_str(&field.descriptor.number().to_string());

if let Some(ref default) = field.default_value {
if let Some(ref default) = field.descriptor.default_value {
self.buf.push_str("\", default=\"");
if type_ == Type::Bytes {
self.buf.push_str("b\\\"");
Expand All @@ -499,6 +498,7 @@ impl<'a> CodeGenerator<'a> {
// the last segment and strip it from the left
// side of the default value.
let enum_type = field
.descriptor
.type_name
.as_ref()
.and_then(|ty| ty.split('.').last())
Expand All @@ -513,10 +513,10 @@ impl<'a> CodeGenerator<'a> {
}

self.buf.push_str("\")]\n");
self.append_field_attributes(fq_message_name, field.name());
self.append_field_attributes(fq_message_name, field.descriptor.name());
self.push_indent();
self.buf.push_str("pub ");
self.buf.push_str(rust_name);
self.buf.push_str(&field.rust_name());
self.buf.push_str(": ");

let prost_path = prost_path(self.config);
Expand Down Expand Up @@ -548,25 +548,23 @@ impl<'a> CodeGenerator<'a> {
key: &FieldDescriptorProto,
value: &FieldDescriptorProto,
) {
let rust_name = &field.rust_name;
let field = &field.descriptor;
let key_ty = self.resolve_type(key, fq_message_name);
let value_ty = self.resolve_type(value, fq_message_name);

debug!(
" map field: {:?}, key type: {:?}, value type: {:?}",
field.name(),
field.descriptor.name(),
key_ty,
value_ty
);

self.append_doc(fq_message_name, Some(field.name()));
self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.push_indent();

let map_type = self
.config
.map_type
.get_first_field(fq_message_name, field.name())
.get_first_field(fq_message_name, field.descriptor.name())
.copied()
.unwrap_or_default();
let key_tag = self.field_type_tag(key);
Expand All @@ -577,13 +575,13 @@ impl<'a> CodeGenerator<'a> {
map_type.annotation(),
key_tag,
value_tag,
field.number()
field.descriptor.number()
));
self.append_field_attributes(fq_message_name, field.name());
self.append_field_attributes(fq_message_name, field.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: {}<{}, {}>,\n",
rust_name,
field.rust_name(),
map_type.rust_type(),
key_ty,
value_ty
Expand All @@ -609,14 +607,15 @@ impl<'a> CodeGenerator<'a> {
oneof
.fields
.iter()
.map(|(field, _)| field.number())
.map(|field| field.descriptor.number())
.join(", "),
));
self.append_field_attributes(fq_message_name, oneof.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: ::core::option::Option<{}>,\n",
oneof.rust_name, type_name
oneof.rust_name(),
type_name
));
}

Expand Down Expand Up @@ -645,51 +644,54 @@ impl<'a> CodeGenerator<'a> {

self.path.push(2);
self.depth += 1;
for (field, idx) in &oneof.fields {
let type_ = field.r#type();
for field in &oneof.fields {
let type_ = field.descriptor.r#type();

self.path.push(*idx);
self.append_doc(fq_message_name, Some(field.name()));
self.path.push(field.path_index);
self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.path.pop();

self.push_indent();
let ty_tag = self.field_type_tag(field);
let ty_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
field.number()
field.descriptor.number()
));
self.append_field_attributes(&oneof_name, field.name());
self.append_field_attributes(&oneof_name, field.descriptor.name());

self.push_indent();
let ty = self.resolve_type(field, fq_message_name);
let ty = self.resolve_type(&field.descriptor, fq_message_name);

let boxed = ((type_ == Type::Message || type_ == Type::Group)
&& self
.message_graph
.is_nested(field.type_name(), fq_message_name))
.is_nested(field.descriptor.type_name(), fq_message_name))
|| (self
.config
.boxed
.get_first_field(&oneof_name, field.name())
.get_first_field(&oneof_name, field.descriptor.name())
.is_some());

debug!(
" oneof: {:?}, type: {:?}, boxed: {}",
field.name(),
field.descriptor.name(),
ty,
boxed
);

if boxed {
self.buf.push_str(&format!(
"{}(::prost::alloc::boxed::Box<{}>),\n",
to_upper_camel(field.name()),
to_upper_camel(field.descriptor.name()),
ty
));
} else {
self.buf
.push_str(&format!("{}({}),\n", to_upper_camel(field.name()), ty));
self.buf.push_str(&format!(
"{}({}),\n",
to_upper_camel(field.descriptor.name()),
ty
));
}
}
self.depth -= 1;
Expand Down

0 comments on commit 1991a11

Please sign in to comment.