Skip to content

Commit

Permalink
Initial basic cmsg support for unix
Browse files Browse the repository at this point in the history
  • Loading branch information
brunowonka committed Jun 23, 2022
1 parent f9c1aef commit e7dbdbe
Show file tree
Hide file tree
Showing 6 changed files with 540 additions and 21 deletions.
365 changes: 365 additions & 0 deletions src/cmsg.rs
@@ -0,0 +1,365 @@
use crate::sys;
use std::borrow::Borrow;
use std::convert::TryInto as _;
use std::io::IoSlice;
use std::iter::FromIterator;

#[derive(Debug, Clone)]
struct MsgHdrWalker<B> {
buffer: B,
position: Option<usize>,
}

impl<B: AsRef<[u8]>> MsgHdrWalker<B> {
fn next_ptr(&mut self) -> Option<*const libc::cmsghdr> {
// Build a msghdr so we can use the functionality in libc.
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
let buffer = self.buffer.as_ref();
// SAFETY: We're giving msghdr a mutable pointer to comply with the C
// API. We'll only allow mutation of `cmsghdr`, however if `B` is
// AsMut<[u8]>.
msghdr.msg_control = buffer.as_ptr() as *mut _;
msghdr.msg_controllen = buffer.len().try_into().expect("buffer is too long");

let nxt_hdr = if let Some(position) = self.position {
if position >= buffer.len() {
return None;
}
let cur_hdr = &buffer[position] as *const u8 as *const _;
// Safety: msghdr is a valid pointer and cur_hdr is not null.
unsafe { libc::CMSG_NXTHDR(&msghdr, cur_hdr) }
} else {
// Safety: msghdr is a valid pointer.
unsafe { libc::CMSG_FIRSTHDR(&msghdr) }
};

if nxt_hdr.is_null() {
self.position = Some(buffer.len());
return None;
}

// SAFETY: nxt_hdr always points to data within the buffer, they must be
// part of the same allocation.
let distance = unsafe { (nxt_hdr as *const u8).offset_from(buffer.as_ptr()) };
// nxt_hdr is always ahead of the buffer and not null if we're here,
// meaning the distance is always positive.
self.position = Some(distance.try_into().unwrap());
Some(nxt_hdr)
}

fn next(&mut self) -> Option<(&libc::cmsghdr, &[u8])> {
self.next_ptr().map(|cmsghdr| {
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`.
let data = unsafe { libc::CMSG_DATA(cmsghdr) };
let cmsghdr = unsafe { &*cmsghdr };
// SAFETY: data points to buffer and is controlled by control
// message length.
let data = unsafe {
std::slice::from_raw_parts(
data,
(cmsghdr.cmsg_len as usize)
.saturating_sub(std::mem::size_of::<libc::cmsghdr>()),
)
};
(cmsghdr, data)
})
}
}

impl<B: AsRef<[u8]> + AsMut<[u8]>> MsgHdrWalker<B> {
fn next_mut(&mut self) -> Option<(&mut libc::cmsghdr, &mut [u8])> {
match self.next_ptr() {
Some(cmsghdr) => {
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`.
let data = unsafe { libc::CMSG_DATA(cmsghdr) };
// SAFETY: The mutable pointer is safe because we're not going to
// vend any concurrent access to the same memory region and B is
// AsMut<[u8]> guaranteeing we have exclusive access to the buffer.
let cmsghdr = cmsghdr as *mut libc::cmsghdr;
let cmsghdr = unsafe { &mut *cmsghdr };

// We'll always yield the entirety of the rest of the buffer.
let distance = unsafe { data.offset_from(self.buffer.as_ref().as_ptr()) };
// The data pointer is always part of the buffer, can't be before
// it.
let distance: usize = distance.try_into().unwrap();
Some((cmsghdr, &mut self.buffer.as_mut()[distance..]))
}
None => None,
}
}
}

/// A wrapper around a buffer that can be used to write ancillary control
/// messages.
#[derive(Debug)]
pub struct CmsgWriter<B> {
walker: MsgHdrWalker<B>,
last_push: usize,
}

impl<B: AsMut<[u8]> + AsRef<[u8]>> CmsgWriter<B> {
/// Creates a new [`CmsgBuffer`] backed by the bytes in `buffer`.
pub fn new(buffer: B) -> Self {
Self {
walker: MsgHdrWalker {
buffer,
position: None,
},
last_push: 0,
}
}

/// Pushes a new control message `m` to the buffer.
///
/// # Panics
///
/// Panics if the contained buffer does not have enough space to fit `m`.
pub fn push(&mut self, m: &Cmsg) {
let (cmsg_level, cmsg_type, size) = m.level_type_size();
let (nxt_hdr, data) = self
.walker
.next_mut()
.unwrap_or_else(|| panic!("can't fit message {:?}", m));
// Safety: All values are passed by copy.
let cmsg_len = unsafe { libc::CMSG_LEN(size) }.try_into().unwrap();
*nxt_hdr = libc::cmsghdr {
cmsg_len,
cmsg_level,
cmsg_type,
};
m.write(&mut data[..size as usize]);
// Always store the space required for the last push because the walker
// maintains its position cursor at the currently written option, we
// must always add the space for the last control message when returning
// the consolidated buffer.
self.last_push = unsafe { libc::CMSG_SPACE(size) } as usize;
}
}

impl<B: AsMut<[u8]> + AsRef<[u8]>> Extend<Cmsg> for CmsgWriter<B> {
fn extend<T: IntoIterator<Item = Cmsg>>(&mut self, iter: T) {
for cmsg in iter {
self.push(&cmsg)
}
}
}

impl<C: Borrow<Cmsg>> FromIterator<C> for CmsgWriter<Vec<u8>> {
fn from_iter<T: IntoIterator<Item = C>>(iter: T) -> Self {
let mut buff = CmsgWriter::new(vec![]);
for cmsg in iter {
let cmsg = cmsg.borrow();
buff.walker
.buffer
.resize(buff.walker.buffer.len() + cmsg.space(), 0);
buff.push(&cmsg)
}
buff
}
}

impl<B: AsRef<[u8]>> CmsgWriter<B> {
pub(crate) fn io_slice(&self) -> IoSlice<'_> {
IoSlice::new(self.buffer())
}

pub(crate) fn buffer(&self) -> &[u8] {
if let Some(position) = self.walker.position {
&self.walker.buffer.as_ref()[..position + self.last_push]
} else {
&[]
}
}
}

/// An iterator over received control messages.
#[derive(Debug, Clone)]
pub struct CmsgIter<'a> {
walker: MsgHdrWalker<&'a [u8]>,
}

impl<'a> CmsgIter<'a> {
pub(crate) fn new(buffer: &'a [u8]) -> Self {
Self {
walker: MsgHdrWalker {
buffer,
position: None,
},
}
}
}

impl<'a> Iterator for CmsgIter<'a> {
type Item = Cmsg;

fn next(&mut self) -> Option<Self::Item> {
self.walker.next().map(
|(
libc::cmsghdr {
cmsg_len: _,
cmsg_level,
cmsg_type,
},
data,
)| Cmsg::from_raw(*cmsg_level, *cmsg_type, data),
)
}
}

/// An unknown control message.
#[derive(Debug, Eq, PartialEq)]
pub struct UnknownCmsg {
cmsg_level: libc::c_int,
cmsg_type: libc::c_int,
}

/// Control messages.
#[derive(Debug, Eq, PartialEq)]
pub enum Cmsg {
/// The `IP_TTL` control message.
IpTtl(u8),
/// The `IPV6_PKTINFO` control message.
Ipv6PktInfo {
/// The address the packet is destined to/received from. Equivalent to
/// `in6_pktinfo.ipi6_addr`.
addr: std::net::Ipv6Addr,
/// The interface index the packet is destined to/received from.
/// Equivalent to `in6_pktinfo.ipi6_ifindex`.
ifindex: u32,
},
/// An unrecognized control message.
Unknown(UnknownCmsg),
}

impl Cmsg {
/// Returns the amount of buffer space required to hold this option.
pub fn space(&self) -> usize {
let (_, _, size) = self.level_type_size();
// Safety: All values are passed by copy.
let size = unsafe { libc::CMSG_SPACE(size) };
size as usize
}

fn level_type_size(&self) -> (libc::c_int, libc::c_int, libc::c_uint) {
match self {
Cmsg::IpTtl(_) => (
libc::IPPROTO_IP,
libc::IP_TTL,
// TTL is encoded as a u32.
std::mem::size_of::<u32>() as libc::c_uint,
),
Cmsg::Ipv6PktInfo { .. } => (
libc::IPPROTO_IPV6,
libc::IPV6_PKTINFO,
std::mem::size_of::<libc::in6_pktinfo>() as libc::c_uint,
),
Cmsg::Unknown(UnknownCmsg {
cmsg_level,
cmsg_type,
}) => (*cmsg_level, *cmsg_type, 0),
}
}

fn write(&self, buffer: &mut [u8]) {
match self {
Cmsg::IpTtl(ttl) => {
let value: u32 = (*ttl).into();
let value = value.to_ne_bytes();
(&mut buffer[..value.len()]).copy_from_slice(&value[..]);
}
Cmsg::Ipv6PktInfo { addr, ifindex } => {
let pktinfo = libc::in6_pktinfo {
ipi6_addr: sys::to_in6_addr(addr),
ipi6_ifindex: *ifindex,
};
let size = std::mem::size_of::<libc::in6_pktinfo>();
assert_eq!(buffer.len(), size);
// Safety: `pktinfo` is valid for reads for its size in bytes.
// `buffer` is valid for write for the same length, as
// guaranteed by the assertion above. Copy unit is byte, so
// alignment is okay. The two regions do not overlap.
unsafe {
std::ptr::copy_nonoverlapping(
&pktinfo as *const libc::in6_pktinfo as *const _,
buffer.as_mut_ptr(),
size,
)
}
}
Cmsg::Unknown(_) => {
// NOTE: We don't actually allow users of the public API
// serialize unknown control messages, but we use this code path
// for testing.
}
}
}

fn from_raw(cmsg_level: libc::c_int, cmsg_type: libc::c_int, bytes: &[u8]) -> Self {
match (cmsg_level, cmsg_type) {
(libc::IPPROTO_IP, libc::IP_TTL) => {
assert!(bytes.len() >= std::mem::size_of::<u32>(), "{:?}", bytes);
Cmsg::IpTtl(bytes[0])
}
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
let mut pktinfo = unsafe { std::mem::zeroed::<libc::in6_pktinfo>() };
let size = std::mem::size_of::<libc::in6_pktinfo>();
assert!(bytes.len() >= size, "{:?}", bytes);
// Safety: `pktinfo` is valid for writes for its size in bytes.
// `buffer` is valid for read for the same length, as
// guaranteed by the assertion above. Copy unit is byte, so
// alignment is okay. The two regions do not overlap.
unsafe {
std::ptr::copy_nonoverlapping(
bytes.as_ptr(),
&mut pktinfo as *mut libc::in6_pktinfo as *mut _,
size,
)
}
Cmsg::Ipv6PktInfo {
addr: sys::from_in6_addr(pktinfo.ipi6_addr),
ifindex: pktinfo.ipi6_ifindex,
}
}
(cmsg_level, cmsg_type) => Cmsg::Unknown(UnknownCmsg {
cmsg_level,
cmsg_type,
}),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn ser_deser() {
let cmsgs = [
Cmsg::IpTtl(2),
Cmsg::Ipv6PktInfo {
addr: std::net::Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
ifindex: 13,
},
Cmsg::Unknown(UnknownCmsg {
cmsg_level: 12345678,
cmsg_type: 87654321,
}),
];
let buffer: CmsgWriter<_> = cmsgs.iter().collect();
let deser = CmsgIter::new(buffer.buffer()).collect::<Vec<_>>();
assert_eq!(&cmsgs[..], &deser[..]);
}

#[test]
#[should_panic]
fn ser_insufficient_space_panics() {
let mut buffer = CmsgWriter::new([0; 3]);
buffer.push(&Cmsg::IpTtl(2));
}

#[test]
fn empty_deser() {
assert_eq!(CmsgIter::new(&[]).next(), None);
}
}
5 changes: 5 additions & 0 deletions src/lib.rs
Expand Up @@ -115,6 +115,8 @@ macro_rules! from {
};
}

#[cfg(unix)]
mod cmsg;
mod sockaddr;
mod socket;
mod sockref;
Expand All @@ -141,6 +143,9 @@ pub use sockref::SockRef;
)))]
pub use socket::InterfaceIndexOrAddress;

#[cfg(unix)]
pub use cmsg::{Cmsg, CmsgIter, CmsgWriter};

/// Specification of the communication domain for a socket.
///
/// This is a newtype wrapper around an integer which provides a nicer API in
Expand Down

0 comments on commit e7dbdbe

Please sign in to comment.