Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support Vec with construct_with #550

Merged
merged 2 commits into from Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
231 changes: 221 additions & 10 deletions pnet_macros/src/decorator.rs
Expand Up @@ -152,7 +152,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {
));
}
};
let mut construct_with = Vec::new();
let mut construct_with = None;
let mut is_payload = false;
let mut packet_length = None;
let mut struct_length = None;
Expand Down Expand Up @@ -228,6 +228,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {
syn::Meta::List(ref l) => {
if let Some(ident) = l.path.get_ident() {
if ident == "construct_with" {
let mut some_construct_with = Vec::new();
if l.nested.is_empty() {
return Err(Error::new(
l.path.span(),
Expand All @@ -239,7 +240,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {
if let syn::NestedMeta::Meta(ref meta) = item {
let ty_str = meta.to_token_stream().to_string();
match make_type(ty_str, false) {
Ok(ty) => construct_with.push(ty),
Ok(ty) => some_construct_with.push(ty),
Err(e) => {
return Err(Error::new(
field.ty.span(),
Expand All @@ -256,6 +257,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {
));
}
}
construct_with = Some(some_construct_with);
} else {
return Err(Error::new(
ident.span(),
Expand All @@ -281,7 +283,30 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {

match ty {
Type::Vector(_) => {
struct_length = Some(format!("_packet.{}.len()", field_name).to_owned());
struct_length = if let Some(construct_with) = construct_with.as_ref() {
let mut inner_size = 0;
for arg in construct_with.iter() {
if let Type::Primitive(ref _ty_str, size, _endianness) = *arg {
inner_size += size;
} else {
return Err(Error::new(
field.span(),
"arguments to #[construct_with] must be primitives",
));
}
}
if inner_size % 8 != 0 {
return Err(Error::new(
field.span(),
"types in #[construct_with] for vec must be add up to a multiple of 8 bits",
));
}
inner_size /= 8; // bytes not bits

Some(format!("_packet.{}.len() * {}", field_name, inner_size).to_owned())
} else {
Some(format!("_packet.{}.len()", field_name).to_owned())
};
if !is_payload && packet_length.is_none() {
return Err(Error::new(
field.ty.span(),
Expand All @@ -291,7 +316,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {
}
}
Type::Misc(_) => {
if construct_with.is_empty() {
if construct_with.is_none() {
return Err(Error::new(
field.ty.span(),
"non-primitive field types must specify #[construct_with]",
Expand All @@ -308,7 +333,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result<Packet, Error> {
packet_length,
struct_length,
is_payload,
construct_with: Some(construct_with),
construct_with,
});
}

Expand Down Expand Up @@ -459,9 +484,16 @@ fn generate_packet_impl(
[..];
bit_offset += size;
}
Type::Vector(ref inner_ty) => {
handle_vector_field(&field, &mut accessors, &mut mutators, inner_ty, &mut co)?
}
Type::Vector(ref inner_ty) => handle_vector_field(
&field,
&mut bit_offset,
&offset_fns_packet[..],
&mut co,
&name,
&mut mutators,
&mut accessors,
inner_ty,
)?,
Type::Misc(ref ty_str) => handle_misc_field(
&field,
&mut bit_offset,
Expand Down Expand Up @@ -1058,10 +1090,13 @@ fn handle_vec_primitive(

fn handle_vector_field(
field: &Field,
accessors: &mut String,
bit_offset: &mut usize,
offset_fns: &[String],
co: &mut String,
name: &str,
mutators: &mut String,
accessors: &mut String,
inner_ty: &Box<Type>,
co: &mut String,
) -> Result<(), Error> {
if !field.is_payload && !field.packet_length.is_some() {
return Err(Error::new(
Expand Down Expand Up @@ -1118,6 +1153,132 @@ fn handle_vector_field(
));
}
Type::Misc(ref inner_ty_str) => {
if let Some(construct_with) = field.construct_with.as_ref() {
let mut inner_accessors = String::new();
let mut inner_mutators = String::new();
let mut get_args = String::new();
let mut set_args = String::new();
let mut inner_size = 0;
for (i, arg) in construct_with.iter().enumerate() {
if let Type::Primitive(ref ty_str, size, endianness) = *arg {
let mut ops = operations(*bit_offset % 8, size).unwrap();
let target_endianness = if cfg!(target_endian = "little") {
Endianness::Little
} else {
Endianness::Big
};

if endianness == Endianness::Little
|| (target_endianness == Endianness::Little
&& endianness == Endianness::Host)
{
ops = to_little_endian(ops);
}

inner_size += size;
let arg_name = format!("arg{}", i);
inner_accessors = inner_accessors
+ &generate_accessor_with_offset_str(
&arg_name[..],
&ty_str[..],
&co[..],
&ops[..],
&name[..],
)[..];
inner_mutators = inner_mutators
+ &generate_mutator_with_offset_str(
&arg_name[..],
&ty_str[..],
&co[..],
&to_mutator(&ops[..])[..],
&name[..],
)[..];
get_args =
format!("{}get_{}(&self, additional_offset), ", get_args, arg_name);
set_args = format!(
"{}set_{}(_self, vals.{}, additional_offset);\n",
set_args, arg_name, i
);
*bit_offset += size;
// Current offset needs to be recalculated for each arg
*co = current_offset(*bit_offset, offset_fns);
} else {
return Err(Error::new(
field.span,
"arguments to #[construct_with] must be primitives",
));
}
}
if inner_size % 8 != 0 {
return Err(Error::new(
field.span,
"types in #[construct_with] for vec must be add up to a multiple of 8 bits",
));
}
inner_size /= 8; // bytes not bits
*mutators = format!(
"{mutators}
/// Set the value of the {name} field.
#[inline]
#[allow(trivial_numeric_casts)]
#[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))]
pub fn set_{name}(&mut self, vals: &Vec<{inner_ty_str}>) {{
use pnet_macros_support::packet::PrimitiveValues;
let _self = self;
{inner_mutators}
let mut additional_offset = 0;

for val in vals.into_iter() {{
let vals = val.to_primitive_values();

{set_args}

additional_offset += {inner_size};
}}
}}
",
mutators = &mutators[..],
name = field.name,
inner_ty_str = inner_ty_str,
inner_mutators = inner_mutators,
//packet_length = field.packet_length.as_ref().unwrap(),
inner_size = inner_size,
set_args = set_args
);
*accessors = format!(
"{accessors}
/// Get the value of the {name} field
#[inline]
#[allow(trivial_numeric_casts)]
#[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))]
pub fn get_{name}(&self) -> Vec<{inner_ty_str}> {{
let _self = self;
let length = {packet_length};
let vec_length = length.saturating_div({inner_size});
let mut vec = Vec::with_capacity(vec_length);

{inner_accessors}

let mut additional_offset = 0;

for vec_offset in 0..vec_length {{
vec.push({inner_ty_str}::new({get_args}));
additional_offset += {inner_size};
}}

vec
}}
",
accessors = accessors,
name = field.name,
inner_ty_str = inner_ty_str,
inner_accessors = inner_accessors,
packet_length = field.packet_length.as_ref().unwrap(),
inner_size = inner_size,
get_args = &get_args[..get_args.len() - 2]
);
return Ok(());
}
*accessors = format!("{accessors}
/// Get the value of the {name} field (copies contents)
#[inline]
Expand Down Expand Up @@ -1358,6 +1519,31 @@ fn generate_mutator_str(
mutator
}

fn generate_mutator_with_offset_str(
name: &str,
ty: &str,
offset: &str,
operations: &[SetOperation],
inner: &str,
) -> String {
let op_strings = generate_sop_strings(operations);

format!(
"#[inline]
#[allow(trivial_numeric_casts)]
#[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))]
fn set_{name}(_self: &mut {struct_name}, val: {ty}, offset: usize) {{
let co = {co} + offset;
{operations}
}}",
struct_name = inner,
name = name,
ty = ty,
co = offset,
operations = op_strings
)
}

/// Used to turn something like a u16be into
/// "let b0 = ((_self.packet[co + 0] as u16be) << 8) as u16be;
/// let b1 = ((_self.packet[co + 1] as u16be) as u16be;
Expand Down Expand Up @@ -1475,6 +1661,31 @@ fn generate_accessor_str(
accessor
}

fn generate_accessor_with_offset_str(
name: &str,
ty: &str,
offset: &str,
operations: &[GetOperation],
inner: &str,
) -> String {
let op_strings = generate_accessor_op_str("_self.packet", ty, operations);

format!(
"#[inline(always)]
#[allow(trivial_numeric_casts, unused_parens)]
#[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))]
fn get_{name}(_self: &{struct_name}, offset: usize) -> {ty} {{
let co = {co} + offset;
{operations}
}}",
struct_name = inner,
name = name,
ty = ty,
co = offset,
operations = op_strings
)
}

fn current_offset(bit_offset: usize, offset_fns: &[String]) -> String {
let base_offset = bit_offset / 8;

Expand Down
60 changes: 60 additions & 0 deletions pnet_macros/tests/run-pass/vec_construct.rs
@@ -0,0 +1,60 @@
extern crate pnet_macros;
extern crate pnet_macros_support;
use pnet_macros::packet;
use pnet_macros_support::packet::PrimitiveValues;

#[packet]
pub struct PacketWithVecConstruct {
banana: u8,
#[length_fn = "length_fn"]
#[construct_with(u64, u64)]
tomatoes: Vec<Identity>,
#[payload]
payload: Vec<u8>,
}

fn length_fn(_: &PacketWithVecConstructPacket) -> usize {
48
}

#[derive(PartialEq, PartialOrd, Eq, Ord, Clone, Copy, Debug)]
pub struct Identity(pub(crate) [u8; Identity::LEN]);

impl Identity {
const LEN: usize = 16;

pub fn new(b0: u64, b1: u64) -> Identity {
let mut buf = [0u8; 16];
buf[0..8].copy_from_slice(&b0.to_be_bytes());
buf[8..16].copy_from_slice(&b1.to_be_bytes());
Identity(buf)
}
}

impl PrimitiveValues for Identity {
type T = (u64, u64);
fn to_primitive_values(&self) -> (u64, u64) {
(
u64::from_be_bytes(self.0[0..8].try_into().unwrap()),
u64::from_be_bytes(self.0[8..16].try_into().unwrap()),
)
}
}

fn main() {
let test = PacketWithVecConstruct {
banana: 1,
tomatoes: vec![
Identity([2u8; 16]),
Identity([3u8; 16]),
Identity([4u8; 16])
],
payload: vec![],
};

let mut buf = vec![0; PacketWithVecConstructPacket::packet_size(&test)];
let mut packet = MutablePacketWithVecConstructPacket::new(&mut buf).unwrap();
packet.populate(&test);
assert_eq!(packet.get_banana(), test.banana);
assert_eq!(packet.get_tomatoes(), test.tomatoes);
}