Skip to content

Commit

Permalink
Expect ExactSizeIterator on sendmmsg/recvmmsg and simplify unsafe code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gleb Pomykalov committed Apr 21, 2020
1 parent 77421b1 commit 1cdebf6
Showing 1 changed file with 39 additions and 45 deletions.
84 changes: 39 additions & 45 deletions src/sys/socket/mod.rs
Expand Up @@ -50,7 +50,6 @@ pub use libc::{
// Needed by the cmsg_space macro
#[doc(hidden)]
pub use libc::{c_uint, CMSG_SPACE};
use std::mem::MaybeUninit;

/// These constants are used to specify the communication semantics
/// when creating a socket with [`socket()`](fn.socket.html)
Expand Down Expand Up @@ -783,7 +782,7 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage],
// because subsequent code will not clear the padding bytes.
let mut cmsg_buffer = vec![0u8; capacity];

pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr);
pack_mhdr_to_send(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr);

let mhdr = unsafe { mhdr.assume_init() };

Expand Down Expand Up @@ -837,7 +836,8 @@ pub struct SendMmsgData<'a, I, C>
))]
pub fn sendmmsg<'a, I, C>(
fd: RawFd,
data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C>>,
data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C>,
IntoIter=impl ExactSizeIterator + Iterator<Item=&'a SendMmsgData<'a, I, C>>>,
flags: MsgFlags
) -> Result<Vec<usize>>
where
Expand All @@ -846,34 +846,30 @@ pub fn sendmmsg<'a, I, C>(
{
let iter = data.into_iter();

let (min_size, max_size) = iter.size_hint();
let reserve_items = max_size.unwrap_or(min_size);
let num_messages = iter.len();

let mut output: Vec<MaybeUninit<libc::mmsghdr>> = vec![MaybeUninit::zeroed(); reserve_items];
let mut output = Vec::<libc::mmsghdr>::with_capacity(num_messages);
unsafe {
output.set_len(num_messages);
}

let mut cmsgs_buffer = vec![0u8; 0];

iter.enumerate().for_each(|(i, d)| {
if output.len() < i {
output.resize(i, MaybeUninit::zeroed());
}

let element = &mut output[i];

let cmsgs_start = cmsgs_buffer.len();
let cmsgs_required_capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum();
let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity;
cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0);

unsafe {
pack_mhdr(
&mut (*element.as_mut_ptr()).msg_hdr,
&mut cmsgs_buffer[cmsgs_start..],
&d.iov,
&d.cmsgs,
d.addr.as_ref()
)
};
pack_mhdr_to_send(
&mut element.msg_hdr,
&mut cmsgs_buffer[cmsgs_start..],
&d.iov,
&d.cmsgs,
d.addr.as_ref()
);
});

let mut initialized_data = unsafe { mem::transmute::<_, Vec<libc::mmsghdr>>(output) };
Expand Down Expand Up @@ -941,7 +937,8 @@ pub struct RecvMmsgData<'a, I>
))]
pub fn recvmmsg<'a, I>(
fd: RawFd,
data: impl std::iter::IntoIterator<Item=&'a mut RecvMmsgData<'a, I>>,
data: impl std::iter::IntoIterator<Item=&'a mut RecvMmsgData<'a, I>,
IntoIter=impl ExactSizeIterator + Iterator<Item=&'a mut RecvMmsgData<'a, I>>>,
flags: MsgFlags,
timeout: Option<crate::sys::time::TimeSpec>
) -> Result<Vec<RecvMsg<'a>>>
Expand All @@ -950,23 +947,22 @@ pub fn recvmmsg<'a, I>(
{
let iter = data.into_iter();

let (min_size, max_size) = iter.size_hint();
let reserve_items = max_size.unwrap_or(min_size);
let num_messages = iter.len();

let mut output: Vec<MaybeUninit<libc::mmsghdr>> = vec![MaybeUninit::zeroed(); reserve_items];
let mut address: Vec<MaybeUninit<sockaddr_storage>> = vec![MaybeUninit::uninit(); reserve_items];
let mut output: Vec<libc::mmsghdr> = Vec::with_capacity(num_messages);
let mut address: Vec<sockaddr_storage> = Vec::with_capacity(num_messages);

let results: Vec<_> = iter.enumerate().map(|(i, d)| {
if output.len() < i {
output.resize(i, MaybeUninit::zeroed());
address.resize(i, MaybeUninit::uninit());
}
unsafe {
output.set_len(num_messages);
address.set_len(num_messages);
}

let results: Vec<_> = iter.enumerate().map(|(i, d)| {
let element = &mut output[i];

let msg_controllen = unsafe {
recv_pack_mhdr(
&mut (*element.as_mut_ptr()).msg_hdr,
pack_mhdr_to_receive(
&mut element.msg_hdr,
d.iov.as_ref(),
&mut d.cmsg_buffer,
&mut address[i]
Expand All @@ -976,29 +972,27 @@ pub fn recvmmsg<'a, I>(
(msg_controllen as usize, &mut d.cmsg_buffer)
}).collect();

let mut initialized_data = unsafe { mem::transmute::<_, Vec<libc::mmsghdr>>(output) };

let timeout = if let Some(mut t) = timeout {
t.as_mut() as *mut libc::timespec
} else {
ptr::null_mut()
};

let ret = unsafe { libc::recvmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as _, flags.bits() as _, timeout) };
let ret = unsafe { libc::recvmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _, timeout) };

let r = Errno::result(ret)?;

Ok(initialized_data
Ok(output
.into_iter()
.zip(address.into_iter())
.zip(results.into_iter())
.map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| {
.map(|((mmsghdr, mut address), (msg_controllen, cmsg_buffer))| {
unsafe {
read_mhdr(
mmsghdr.msg_hdr,
r as isize,
msg_controllen,
address,
&mut address,
cmsg_buffer
)
}
Expand All @@ -1010,7 +1004,7 @@ unsafe fn read_mhdr<'a, 'b>(
mhdr: msghdr,
r: isize,
msg_controllen: usize,
address: MaybeUninit<sockaddr_storage>,
address: *mut sockaddr_storage,
cmsg_buffer: &'a mut Option<&'b mut Vec<u8>>
) -> RecvMsg<'b> {
let cmsghdr = {
Expand All @@ -1029,7 +1023,7 @@ unsafe fn read_mhdr<'a, 'b>(
};

let address = sockaddr_storage_to_addr(
&address.assume_init(),
&*address ,
mhdr.msg_namelen as usize
).ok();

Expand All @@ -1042,11 +1036,11 @@ unsafe fn read_mhdr<'a, 'b>(
}
}

unsafe fn recv_pack_mhdr<'a, I>(
unsafe fn pack_mhdr_to_receive<'a, I>(
out: *mut msghdr,
iov: I,
cmsg_buffer: &mut Option<&mut Vec<u8>>,
address: &mut mem::MaybeUninit<sockaddr_storage>
address: *mut sockaddr_storage,
) -> usize
where
I: AsRef<[IoVec<&'a mut [u8]>]> + 'a,
Expand All @@ -1055,7 +1049,7 @@ unsafe fn recv_pack_mhdr<'a, I>(
.map(|v| (v.as_mut_ptr(), v.capacity()))
.unwrap_or((ptr::null_mut(), 0));

(*out).msg_name = address.as_mut_ptr() as *mut c_void;
(*out).msg_name = address as *mut c_void;
(*out).msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t;
(*out).msg_iov = iov.as_ref().as_ptr() as *mut iovec;
(*out).msg_iovlen = iov.as_ref().len() as _;
Expand All @@ -1067,7 +1061,7 @@ unsafe fn recv_pack_mhdr<'a, I>(
}


fn pack_mhdr<'a, I, C>(
fn pack_mhdr_to_send<'a, I, C>(
out: *mut msghdr,
cmsg_buffer: &mut [u8],
iov: I,
Expand Down Expand Up @@ -1146,7 +1140,7 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>],
let mut address = mem::MaybeUninit::uninit();

let msg_controllen = unsafe {
recv_pack_mhdr(out.as_mut_ptr(), &iov, &mut cmsg_buffer, &mut address)
pack_mhdr_to_receive(out.as_mut_ptr(), &iov, &mut cmsg_buffer, address.as_mut_ptr())
};

let mut mhdr = unsafe { out.assume_init() };
Expand All @@ -1155,7 +1149,7 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>],

let r = Errno::result(ret)?;

Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address, &mut cmsg_buffer) })
Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.as_mut_ptr(), &mut cmsg_buffer) })
}


Expand Down

0 comments on commit 1cdebf6

Please sign in to comment.