Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better checking of tag duplicates, avoid discarding invalid variant errs #951

Merged
merged 11 commits into from
May 24, 2024
Merged
4 changes: 2 additions & 2 deletions prost-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ proc_macro = true

[dependencies]
anyhow = "1.0.1"
itertools = { version = ">=0.10, <=0.12", default-features = false, features = ["use_alloc"] }
proc-macro2 = "1"
itertools = ">=0.10.1, <=0.12"
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
proc-macro2 = "1.0.60"
quote = "1"
syn = { version = "2", features = ["extra-traits"] }
9 changes: 5 additions & 4 deletions prost-derive/src/field/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ impl Field {
return Ok(None);
}

match unknown_attrs.len() {
0 => (),
1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]),
_ => bail!("unknown attributes for group field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for group field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
12 changes: 5 additions & 7 deletions prost-derive/src/field/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ impl Field {
return Ok(None);
}

match unknown_attrs.len() {
0 => (),
1 => bail!(
"unknown attribute for message field: {:?}",
unknown_attrs[0]
),
_ => bail!("unknown attributes for message field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for message field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
12 changes: 5 additions & 7 deletions prost-derive/src/field/oneof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,11 @@ impl Field {
None => return Ok(None),
};

match unknown_attrs.len() {
0 => (),
1 => bail!(
"unknown attribute for message field: {:?}",
unknown_attrs[0]
),
_ => bail!("unknown attributes for message field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for message field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tags = match tags {
Expand Down
9 changes: 5 additions & 4 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ impl Field {
None => return Ok(None),
};

match unknown_attrs.len() {
0 => (),
1 => bail!("unknown attribute: {:?}", unknown_attrs[0]),
_ => bail!("unknown attributes: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s): #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
164 changes: 126 additions & 38 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ extern crate proc_macro;

use anyhow::{bail, Error};
use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::{Span, TokenStream};
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
use quote::quote;
use syn::{
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
Expand All @@ -19,7 +18,7 @@ mod field;
use crate::field::Field;

fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -91,16 +90,18 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
let fields = fields;

let mut tags = fields
if let Some(duplicate_tag) = fields
.iter()
.flat_map(|(_, field)| field.tags())
.collect::<Vec<_>>();
let num_tags = tags.len();
tags.sort_unstable();
tags.dedup();
if tags.len() != num_tags {
bail!("message {} has fields with duplicate tags", ident);
}
.duplicates()
.next()
{
bail!(
"message {} has multiple fields with tag {}",
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
ident,
duplicate_tag
)
};

let encoded_len = fields
.iter()
Expand Down Expand Up @@ -251,16 +252,16 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#methods
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Message, attributes(prost))]
pub fn message(input: TokenStream) -> TokenStream {
try_message(input).unwrap()
pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_message(input.into()).unwrap().into()
}

fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;
let ident = input.ident;

let generics = &input.generics;
Expand Down Expand Up @@ -359,16 +360,16 @@ fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
}
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Enumeration, attributes(prost))]
pub fn enumeration(input: TokenStream) -> TokenStream {
try_enumeration(input).unwrap()
pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_enumeration(input.into()).unwrap().into()
}

fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -412,23 +413,21 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
}

let mut tags = fields
// Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple
// tags.
assert!(fields.iter().all(|(_, field)| field.tags().len() == 1));

if let Some(duplicate_tag) = fields
.iter()
.flat_map(|(variant_ident, field)| -> Result<u32, Error> {
if field.tags().len() > 1 {
bail!(
"invalid oneof variant {}::{}: oneof variants may only have a single tag",
ident,
variant_ident
);
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
}
Ok(field.tags()[0])
})
.collect::<Vec<_>>();
tags.sort_unstable();
tags.dedup();
if tags.len() != fields.len() {
panic!("invalid oneof {}: variants have duplicate tags", ident);
.flat_map(|(_, field)| field.tags())
.duplicates()
.next()
{
bail!(
"invalid oneof {}: multiple variants have tag {}",
ident,
duplicate_tag
);
}

let encode = fields.iter().map(|(variant_ident, field)| {
Expand Down Expand Up @@ -519,10 +518,99 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Oneof, attributes(prost))]
pub fn oneof(input: TokenStream) -> TokenStream {
try_oneof(input).unwrap()
pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_oneof(input.into()).unwrap().into()
}

#[cfg(test)]
mod test {
use crate::{try_message, try_oneof};
use quote::quote;

#[test]
fn test_rejects_colliding_message_fields() {
let output = try_message(quote!(
struct Invalid {
#[prost(bool, tag = "1")]
a: bool,
#[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
b: Option<super::Whatever>,
}
));
assert_eq!(
output
.expect_err("did not reject colliding message fields")
.to_string(),
"message Invalid has multiple fields with tag 1"
);
}

#[test]
fn test_rejects_colliding_oneof_variants() {
let output = try_oneof(quote!(
pub enum Invalid {
#[prost(bool, tag = "1")]
A(bool),
#[prost(bool, tag = "3")]
B(bool),
#[prost(bool, tag = "1")]
C(bool),
}
));
assert_eq!(
output
.expect_err("did not reject colliding oneof variants")
.to_string(),
"invalid oneof Invalid: multiple variants have tag 1"
);
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
}

#[test]
fn test_rejects_multiple_tags_oneof_variant() {
let output = try_oneof(quote!(
enum What {
#[prost(bool, tag = "1", tag = "2")]
A(bool),
}
));
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"duplicate tag attributes: 1 and 2"
);

let output = try_oneof(quote!(
enum What {
#[prost(bool, tag = "3")]
#[prost(tag = "4")]
A(bool),
}
));
assert!(output.is_err());
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"duplicate tag attributes: 3 and 4"
);

let output = try_oneof(quote!(
enum What {
#[prost(bool, tags = "5,6")]
A(bool),
}
));
assert!(output.is_err());
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"unknown attribute(s): #[prost(tags = \"5,6\")]"
);
}
}