From 1cdebf663c4a22c8e1aa153798e15a0d6d5c1257 Mon Sep 17 00:00:00 2001 From: Gleb Pomykalov Date: Tue, 21 Apr 2020 21:43:04 +0300 Subject: [PATCH] Expect ExactSizeIterator on sendmmsg/recvmmsg and simplify unsafe code. --- src/sys/socket/mod.rs | 84 ++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 45 deletions(-) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 0b3283ba16..e6ec2a0477 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -50,7 +50,6 @@ pub use libc::{ // Needed by the cmsg_space macro #[doc(hidden)] pub use libc::{c_uint, CMSG_SPACE}; -use std::mem::MaybeUninit; /// These constants are used to specify the communication semantics /// when creating a socket with [`socket()`](fn.socket.html) @@ -783,7 +782,7 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], // because subsequent code will not clear the padding bytes. let mut cmsg_buffer = vec![0u8; capacity]; - pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr); + pack_mhdr_to_send(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr); let mhdr = unsafe { mhdr.assume_init() }; @@ -837,7 +836,8 @@ pub struct SendMmsgData<'a, I, C> ))] pub fn sendmmsg<'a, I, C>( fd: RawFd, - data: impl std::iter::IntoIterator>, + data: impl std::iter::IntoIterator, + IntoIter=impl ExactSizeIterator + Iterator>>, flags: MsgFlags ) -> Result> where @@ -846,18 +846,16 @@ pub fn sendmmsg<'a, I, C>( { let iter = data.into_iter(); - let (min_size, max_size) = iter.size_hint(); - let reserve_items = max_size.unwrap_or(min_size); + let num_messages = iter.len(); - let mut output: Vec> = vec![MaybeUninit::zeroed(); reserve_items]; + let mut output = Vec::::with_capacity(num_messages); + unsafe { + output.set_len(num_messages); + } let mut cmsgs_buffer = vec![0u8; 0]; iter.enumerate().for_each(|(i, d)| { - if output.len() < i { - output.resize(i, MaybeUninit::zeroed()); - } - let element = &mut output[i]; let cmsgs_start = cmsgs_buffer.len(); @@ -865,15 +863,13 @@ pub fn sendmmsg<'a, I, C>( let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity; cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0); - unsafe { - pack_mhdr( - &mut (*element.as_mut_ptr()).msg_hdr, - &mut cmsgs_buffer[cmsgs_start..], - &d.iov, - &d.cmsgs, - d.addr.as_ref() - ) - }; + pack_mhdr_to_send( + &mut element.msg_hdr, + &mut cmsgs_buffer[cmsgs_start..], + &d.iov, + &d.cmsgs, + d.addr.as_ref() + ); }); let mut initialized_data = unsafe { mem::transmute::<_, Vec>(output) }; @@ -941,7 +937,8 @@ pub struct RecvMmsgData<'a, I> ))] pub fn recvmmsg<'a, I>( fd: RawFd, - data: impl std::iter::IntoIterator>, + data: impl std::iter::IntoIterator, + IntoIter=impl ExactSizeIterator + Iterator>>, flags: MsgFlags, timeout: Option ) -> Result>> @@ -950,23 +947,22 @@ pub fn recvmmsg<'a, I>( { let iter = data.into_iter(); - let (min_size, max_size) = iter.size_hint(); - let reserve_items = max_size.unwrap_or(min_size); + let num_messages = iter.len(); - let mut output: Vec> = vec![MaybeUninit::zeroed(); reserve_items]; - let mut address: Vec> = vec![MaybeUninit::uninit(); reserve_items]; + let mut output: Vec = Vec::with_capacity(num_messages); + let mut address: Vec = Vec::with_capacity(num_messages); - let results: Vec<_> = iter.enumerate().map(|(i, d)| { - if output.len() < i { - output.resize(i, MaybeUninit::zeroed()); - address.resize(i, MaybeUninit::uninit()); - } + unsafe { + output.set_len(num_messages); + address.set_len(num_messages); + } + let results: Vec<_> = iter.enumerate().map(|(i, d)| { let element = &mut output[i]; let msg_controllen = unsafe { - recv_pack_mhdr( - &mut (*element.as_mut_ptr()).msg_hdr, + pack_mhdr_to_receive( + &mut element.msg_hdr, d.iov.as_ref(), &mut d.cmsg_buffer, &mut address[i] @@ -976,29 +972,27 @@ pub fn recvmmsg<'a, I>( (msg_controllen as usize, &mut d.cmsg_buffer) }).collect(); - let mut initialized_data = unsafe { mem::transmute::<_, Vec>(output) }; - let timeout = if let Some(mut t) = timeout { t.as_mut() as *mut libc::timespec } else { ptr::null_mut() }; - let ret = unsafe { libc::recvmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as _, flags.bits() as _, timeout) }; + let ret = unsafe { libc::recvmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _, timeout) }; let r = Errno::result(ret)?; - Ok(initialized_data + Ok(output .into_iter() .zip(address.into_iter()) .zip(results.into_iter()) - .map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| { + .map(|((mmsghdr, mut address), (msg_controllen, cmsg_buffer))| { unsafe { read_mhdr( mmsghdr.msg_hdr, r as isize, msg_controllen, - address, + &mut address, cmsg_buffer ) } @@ -1010,7 +1004,7 @@ unsafe fn read_mhdr<'a, 'b>( mhdr: msghdr, r: isize, msg_controllen: usize, - address: MaybeUninit, + address: *mut sockaddr_storage, cmsg_buffer: &'a mut Option<&'b mut Vec> ) -> RecvMsg<'b> { let cmsghdr = { @@ -1029,7 +1023,7 @@ unsafe fn read_mhdr<'a, 'b>( }; let address = sockaddr_storage_to_addr( - &address.assume_init(), + &*address , mhdr.msg_namelen as usize ).ok(); @@ -1042,11 +1036,11 @@ unsafe fn read_mhdr<'a, 'b>( } } -unsafe fn recv_pack_mhdr<'a, I>( +unsafe fn pack_mhdr_to_receive<'a, I>( out: *mut msghdr, iov: I, cmsg_buffer: &mut Option<&mut Vec>, - address: &mut mem::MaybeUninit + address: *mut sockaddr_storage, ) -> usize where I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, @@ -1055,7 +1049,7 @@ unsafe fn recv_pack_mhdr<'a, I>( .map(|v| (v.as_mut_ptr(), v.capacity())) .unwrap_or((ptr::null_mut(), 0)); - (*out).msg_name = address.as_mut_ptr() as *mut c_void; + (*out).msg_name = address as *mut c_void; (*out).msg_namelen = mem::size_of::() as socklen_t; (*out).msg_iov = iov.as_ref().as_ptr() as *mut iovec; (*out).msg_iovlen = iov.as_ref().len() as _; @@ -1067,7 +1061,7 @@ unsafe fn recv_pack_mhdr<'a, I>( } -fn pack_mhdr<'a, I, C>( +fn pack_mhdr_to_send<'a, I, C>( out: *mut msghdr, cmsg_buffer: &mut [u8], iov: I, @@ -1146,7 +1140,7 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>], let mut address = mem::MaybeUninit::uninit(); let msg_controllen = unsafe { - recv_pack_mhdr(out.as_mut_ptr(), &iov, &mut cmsg_buffer, &mut address) + pack_mhdr_to_receive(out.as_mut_ptr(), &iov, &mut cmsg_buffer, address.as_mut_ptr()) }; let mut mhdr = unsafe { out.assume_init() }; @@ -1155,7 +1149,7 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>], let r = Errno::result(ret)?; - Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address, &mut cmsg_buffer) }) + Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.as_mut_ptr(), &mut cmsg_buffer) }) }