diff --git a/pnet_macros/src/decorator.rs b/pnet_macros/src/decorator.rs index 5e69e932..7a628b52 100644 --- a/pnet_macros/src/decorator.rs +++ b/pnet_macros/src/decorator.rs @@ -152,7 +152,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { )); } }; - let mut construct_with = Vec::new(); + let mut construct_with = None; let mut is_payload = false; let mut packet_length = None; let mut struct_length = None; @@ -228,6 +228,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { syn::Meta::List(ref l) => { if let Some(ident) = l.path.get_ident() { if ident == "construct_with" { + let mut some_construct_with = Vec::new(); if l.nested.is_empty() { return Err(Error::new( l.path.span(), @@ -239,7 +240,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { if let syn::NestedMeta::Meta(ref meta) = item { let ty_str = meta.to_token_stream().to_string(); match make_type(ty_str, false) { - Ok(ty) => construct_with.push(ty), + Ok(ty) => some_construct_with.push(ty), Err(e) => { return Err(Error::new( field.ty.span(), @@ -256,6 +257,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { )); } } + construct_with = Some(some_construct_with); } else { return Err(Error::new( ident.span(), @@ -281,7 +283,30 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { match ty { Type::Vector(_) => { - struct_length = Some(format!("_packet.{}.len()", field_name).to_owned()); + struct_length = if let Some(construct_with) = construct_with.as_ref() { + let mut inner_size = 0; + for arg in construct_with.iter() { + if let Type::Primitive(ref _ty_str, size, _endianness) = *arg { + inner_size += size; + } else { + return Err(Error::new( + field.span(), + "arguments to #[construct_with] must be primitives", + )); + } + } + if inner_size % 8 != 0 { + return Err(Error::new( + field.span(), + "types in #[construct_with] for vec must be add up to a multiple of 8 bits", + )); + } + inner_size /= 8; // bytes not bits + + Some(format!("_packet.{}.len() * {}", field_name, inner_size).to_owned()) + } else { + Some(format!("_packet.{}.len()", field_name).to_owned()) + }; if !is_payload && packet_length.is_none() { return Err(Error::new( field.ty.span(), @@ -291,7 +316,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { } } Type::Misc(_) => { - if construct_with.is_empty() { + if construct_with.is_none() { return Err(Error::new( field.ty.span(), "non-primitive field types must specify #[construct_with]", @@ -308,7 +333,7 @@ fn make_packet(s: &syn::DataStruct, name: String) -> Result { packet_length, struct_length, is_payload, - construct_with: Some(construct_with), + construct_with, }); } @@ -459,9 +484,16 @@ fn generate_packet_impl( [..]; bit_offset += size; } - Type::Vector(ref inner_ty) => { - handle_vector_field(&field, &mut accessors, &mut mutators, inner_ty, &mut co)? - } + Type::Vector(ref inner_ty) => handle_vector_field( + &field, + &mut bit_offset, + &offset_fns_packet[..], + &mut co, + &name, + &mut mutators, + &mut accessors, + inner_ty, + )?, Type::Misc(ref ty_str) => handle_misc_field( &field, &mut bit_offset, @@ -1058,10 +1090,13 @@ fn handle_vec_primitive( fn handle_vector_field( field: &Field, - accessors: &mut String, + bit_offset: &mut usize, + offset_fns: &[String], + co: &mut String, + name: &str, mutators: &mut String, + accessors: &mut String, inner_ty: &Box, - co: &mut String, ) -> Result<(), Error> { if !field.is_payload && !field.packet_length.is_some() { return Err(Error::new( @@ -1118,6 +1153,132 @@ fn handle_vector_field( )); } Type::Misc(ref inner_ty_str) => { + if let Some(construct_with) = field.construct_with.as_ref() { + let mut inner_accessors = String::new(); + let mut inner_mutators = String::new(); + let mut get_args = String::new(); + let mut set_args = String::new(); + let mut inner_size = 0; + for (i, arg) in construct_with.iter().enumerate() { + if let Type::Primitive(ref ty_str, size, endianness) = *arg { + let mut ops = operations(*bit_offset % 8, size).unwrap(); + let target_endianness = if cfg!(target_endian = "little") { + Endianness::Little + } else { + Endianness::Big + }; + + if endianness == Endianness::Little + || (target_endianness == Endianness::Little + && endianness == Endianness::Host) + { + ops = to_little_endian(ops); + } + + inner_size += size; + let arg_name = format!("arg{}", i); + inner_accessors = inner_accessors + + &generate_accessor_with_offset_str( + &arg_name[..], + &ty_str[..], + &co[..], + &ops[..], + &name[..], + )[..]; + inner_mutators = inner_mutators + + &generate_mutator_with_offset_str( + &arg_name[..], + &ty_str[..], + &co[..], + &to_mutator(&ops[..])[..], + &name[..], + )[..]; + get_args = + format!("{}get_{}(&self, additional_offset), ", get_args, arg_name); + set_args = format!( + "{}set_{}(_self, vals.{}, additional_offset);\n", + set_args, arg_name, i + ); + *bit_offset += size; + // Current offset needs to be recalculated for each arg + *co = current_offset(*bit_offset, offset_fns); + } else { + return Err(Error::new( + field.span, + "arguments to #[construct_with] must be primitives", + )); + } + } + if inner_size % 8 != 0 { + return Err(Error::new( + field.span, + "types in #[construct_with] for vec must be add up to a multiple of 8 bits", + )); + } + inner_size /= 8; // bytes not bits + *mutators = format!( + "{mutators} + /// Set the value of the {name} field. + #[inline] + #[allow(trivial_numeric_casts)] + #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] + pub fn set_{name}(&mut self, vals: &Vec<{inner_ty_str}>) {{ + use pnet_macros_support::packet::PrimitiveValues; + let _self = self; + {inner_mutators} + let mut additional_offset = 0; + + for val in vals.into_iter() {{ + let vals = val.to_primitive_values(); + + {set_args} + + additional_offset += {inner_size}; + }} + }} + ", + mutators = &mutators[..], + name = field.name, + inner_ty_str = inner_ty_str, + inner_mutators = inner_mutators, + //packet_length = field.packet_length.as_ref().unwrap(), + inner_size = inner_size, + set_args = set_args + ); + *accessors = format!( + "{accessors} + /// Get the value of the {name} field + #[inline] + #[allow(trivial_numeric_casts)] + #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] + pub fn get_{name}(&self) -> Vec<{inner_ty_str}> {{ + let _self = self; + let length = {packet_length}; + let vec_length = length.saturating_div({inner_size}); + let mut vec = Vec::with_capacity(vec_length); + + {inner_accessors} + + let mut additional_offset = 0; + + for vec_offset in 0..vec_length {{ + vec.push({inner_ty_str}::new({get_args})); + additional_offset += {inner_size}; + }} + + vec + }} + ", + accessors = accessors, + name = field.name, + inner_ty_str = inner_ty_str, + inner_accessors = inner_accessors, + packet_length = field.packet_length.as_ref().unwrap(), + inner_size = inner_size, + get_args = &get_args[..get_args.len() - 2] + ); + return Ok(()); + } *accessors = format!("{accessors} /// Get the value of the {name} field (copies contents) #[inline] @@ -1358,6 +1519,31 @@ fn generate_mutator_str( mutator } +fn generate_mutator_with_offset_str( + name: &str, + ty: &str, + offset: &str, + operations: &[SetOperation], + inner: &str, +) -> String { + let op_strings = generate_sop_strings(operations); + + format!( + "#[inline] + #[allow(trivial_numeric_casts)] + #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] + fn set_{name}(_self: &mut {struct_name}, val: {ty}, offset: usize) {{ + let co = {co} + offset; + {operations} + }}", + struct_name = inner, + name = name, + ty = ty, + co = offset, + operations = op_strings + ) +} + /// Used to turn something like a u16be into /// "let b0 = ((_self.packet[co + 0] as u16be) << 8) as u16be; /// let b1 = ((_self.packet[co + 1] as u16be) as u16be; @@ -1475,6 +1661,31 @@ fn generate_accessor_str( accessor } +fn generate_accessor_with_offset_str( + name: &str, + ty: &str, + offset: &str, + operations: &[GetOperation], + inner: &str, +) -> String { + let op_strings = generate_accessor_op_str("_self.packet", ty, operations); + + format!( + "#[inline(always)] + #[allow(trivial_numeric_casts, unused_parens)] + #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] + fn get_{name}(_self: &{struct_name}, offset: usize) -> {ty} {{ + let co = {co} + offset; + {operations} + }}", + struct_name = inner, + name = name, + ty = ty, + co = offset, + operations = op_strings + ) +} + fn current_offset(bit_offset: usize, offset_fns: &[String]) -> String { let base_offset = bit_offset / 8; diff --git a/pnet_macros/tests/run-pass/vec_construct.rs b/pnet_macros/tests/run-pass/vec_construct.rs new file mode 100644 index 00000000..bf9b46a8 --- /dev/null +++ b/pnet_macros/tests/run-pass/vec_construct.rs @@ -0,0 +1,60 @@ +extern crate pnet_macros; +extern crate pnet_macros_support; +use pnet_macros::packet; +use pnet_macros_support::packet::PrimitiveValues; + +#[packet] +pub struct PacketWithVecConstruct { + banana: u8, + #[length_fn = "length_fn"] + #[construct_with(u64, u64)] + tomatoes: Vec, + #[payload] + payload: Vec, +} + +fn length_fn(_: &PacketWithVecConstructPacket) -> usize { + 48 +} + +#[derive(PartialEq, PartialOrd, Eq, Ord, Clone, Copy, Debug)] +pub struct Identity(pub(crate) [u8; Identity::LEN]); + +impl Identity { + const LEN: usize = 16; + + pub fn new(b0: u64, b1: u64) -> Identity { + let mut buf = [0u8; 16]; + buf[0..8].copy_from_slice(&b0.to_be_bytes()); + buf[8..16].copy_from_slice(&b1.to_be_bytes()); + Identity(buf) + } +} + +impl PrimitiveValues for Identity { + type T = (u64, u64); + fn to_primitive_values(&self) -> (u64, u64) { + ( + u64::from_be_bytes(self.0[0..8].try_into().unwrap()), + u64::from_be_bytes(self.0[8..16].try_into().unwrap()), + ) + } +} + +fn main() { + let test = PacketWithVecConstruct { + banana: 1, + tomatoes: vec![ + Identity([2u8; 16]), + Identity([3u8; 16]), + Identity([4u8; 16]) + ], + payload: vec![], + }; + + let mut buf = vec![0; PacketWithVecConstructPacket::packet_size(&test)]; + let mut packet = MutablePacketWithVecConstructPacket::new(&mut buf).unwrap(); + packet.populate(&test); + assert_eq!(packet.get_banana(), test.banana); + assert_eq!(packet.get_tomatoes(), test.tomatoes); +}