diff --git a/CHANGELOG.md b/CHANGELOG.md index 14e73e42cc..b72d218bb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Derived `Ord`, `PartialOrd` for `unistd::Pid` (#[1189](https://github.com/nix-rust/nix/pull/1189)) - Added `select::FdSet::fds` method to iterate over file descriptors in a set. ([#1207](https://github.com/nix-rust/nix/pull/1207)) +- Added support for `sendmmsg` and `recvmmsg` calls + (#[1208](https://github.com/nix-rust/nix/pull/1208)) ### Changed - Changed `fallocate` return type from `c_int` to `()` (#[1201](https://github.com/nix-rust/nix/pull/1201)) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 6b2d1ad08c..8d7b3d07b5 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -50,6 +50,7 @@ pub use libc::{ // Needed by the cmsg_space macro #[doc(hidden)] pub use libc::{c_uint, CMSG_SPACE}; +use std::mem::MaybeUninit; /// These constants are used to specify the communication semantics /// when creating a socket with [`socket()`](fn.socket.html) @@ -774,11 +775,305 @@ impl<'a> ControlMessage<'a> { pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], flags: MsgFlags, addr: Option<&SockAddr>) -> Result { + let mut mhdr = mem::MaybeUninit::::zeroed(); + let capacity = cmsgs.iter().map(|c| c.space()).sum(); // First size the buffer needed to hold the cmsgs. It must be zeroed, // because subsequent code will not clear the padding bytes. - let cmsg_buffer = vec![0u8; capacity]; + let mut cmsg_buffer = vec![0u8; capacity]; + + pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr); + + let mhdr = unsafe { mhdr.assume_init() }; + + let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) }; + + Errno::result(ret).map(|r| r as usize) +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", +))] +#[derive(Debug)] +pub struct SendMmsgData<'a, I, C> + where + I: AsRef<[IoVec<&'a [u8]>]>, + C: AsRef<[ControlMessage<'a>]> +{ + pub iov: I, + pub cmsgs: C, + pub addr: Option, + pub _lt: std::marker::PhantomData<&'a I>, +} + +/// An extension of `sendmsg` that allows the caller to transmit multiple +/// messages on a socket using a single system call. This has performance +/// benefits for some applications. +/// +/// Allocations are performed for cmsgs and to build `msghdr` buffer +/// +/// # Arguments +/// +/// * `fd`: Socket file descriptor +/// * `data`: Struct that implements `IntoIterator` with `SendMmsgData` items +/// * `flags`: Optional flags passed directly to the operating system. +/// +/// # Returns +/// `Vec` with numbers of sent bytes on each sent message. +/// +/// # References +/// [`sendmsg`](fn.sendmsg.html) +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", +))] +pub fn sendmmsg<'a, I, C>( + fd: RawFd, + data: impl std::iter::IntoIterator>, + flags: MsgFlags +) -> Result> + where + I: AsRef<[IoVec<&'a [u8]>]> + 'a, + C: AsRef<[ControlMessage<'a>]> + 'a, +{ + let iter = data.into_iter(); + + let (min_size, max_size) = iter.size_hint(); + let reserve_items = max_size.unwrap_or(min_size); + + let mut output: Vec> = vec![MaybeUninit::zeroed(); reserve_items]; + + let mut cmsgs_buffer = vec![0u8; 0]; + + iter.enumerate().for_each(|(i, d)| { + if output.len() < i { + output.resize(i, MaybeUninit::zeroed()); + } + + let element = &mut output[i]; + + let cmsgs_start = cmsgs_buffer.len(); + let cmsgs_required_capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum(); + let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity; + cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0); + + unsafe { + pack_mhdr( + &mut (*element.as_mut_ptr()).msg_hdr, + &mut cmsgs_buffer[cmsgs_start..], + &d.iov, + &d.cmsgs, + d.addr.as_ref() + ) + }; + }); + + let mut initialized_data = unsafe { mem::transmute::<_, Vec>(output) }; + + let ret = unsafe { libc::sendmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as _, flags.bits() as _) }; + + let sent_messages = Errno::result(ret)? as usize; + let mut sent_bytes = Vec::with_capacity(sent_messages); + unsafe { sent_bytes.set_len(sent_messages) }; + + for i in 0..sent_messages { + sent_bytes[i] = initialized_data[i].msg_len as usize; + } + + Ok(sent_bytes) +} + + +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +#[derive(Debug)] +pub struct RecvMmsgData<'a, I> + where + I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, +{ + pub iov: I, + pub cmsg_buffer: Option<&'a mut Vec>, +} + +/// An extension of `recvmsg` that allows the caller to receive multiple +/// messages from a socket using a single system call. This has +/// performance benefits for some applications. +/// +/// `iov` and `cmsg_buffer` should be constructed similarly to `recvmsg` +/// +/// Multiple allocations are performed +/// +/// # Arguments +/// +/// * `fd`: Socket file descriptor +/// * `data`: Struct that implements `IntoIterator` with `RecvMmsgData` items +/// * `flags`: Optional flags passed directly to the operating system. +/// +/// # RecvMmsgData +/// +/// * `iov`: Scatter-gather list of buffers to receive the message +/// * `cmsg_buffer`: Space to receive ancillary data. Should be created by +/// [`cmsg_space!`](macro.cmsg_space.html) +/// +/// # Returns +/// A `Vec` with multiple `RecvMsg`, one per received message +/// +/// # References +/// - [`recvmsg`](fn.recvmsg.html) +/// - [`RecvMsg`](struct.RecvMsg.html) +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", +))] +pub fn recvmmsg<'a, I>( + fd: RawFd, + data: impl std::iter::IntoIterator>, + flags: MsgFlags, + timeout: Option +) -> Result>> + where + I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, +{ + let iter = data.into_iter(); + + let (min_size, max_size) = iter.size_hint(); + let reserve_items = max_size.unwrap_or(min_size); + + let mut output: Vec> = vec![MaybeUninit::zeroed(); reserve_items]; + let mut address: Vec> = vec![MaybeUninit::uninit(); reserve_items]; + + let results: Vec<_> = iter.enumerate().map(|(i, d)| { + if output.len() < i { + output.resize(i, MaybeUninit::zeroed()); + address.resize(i, MaybeUninit::uninit()); + } + + let element = &mut output[i]; + + let msg_controllen = unsafe { + recv_pack_mhdr( + &mut (*element.as_mut_ptr()).msg_hdr, + d.iov.as_ref(), + &mut d.cmsg_buffer, + &mut address[i] + ) + }; + + (msg_controllen as usize, &mut d.cmsg_buffer) + }).collect(); + + let mut initialized_data = unsafe { mem::transmute::<_, Vec>(output) }; + + let timeout = if let Some(mut t) = timeout { + t.as_mut() as *mut libc::timespec + } else { + ptr::null_mut() + }; + + let ret = unsafe { libc::recvmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as _, flags.bits() as _, timeout) }; + + let r = Errno::result(ret)?; + + Ok(initialized_data + .into_iter() + .zip(address.into_iter()) + .zip(results.into_iter()) + .map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| { + unsafe { + read_mhdr( + mmsghdr.msg_hdr, + r as isize, + msg_controllen, + address, + cmsg_buffer + ) + } + }) + .collect()) +} + +unsafe fn read_mhdr<'a, 'b>( + mhdr: msghdr, + r: isize, + msg_controllen: usize, + address: MaybeUninit, + cmsg_buffer: &'a mut Option<&'b mut Vec> +) -> RecvMsg<'b> { + let cmsghdr = { + if mhdr.msg_controllen > 0 { + // got control message(s) + cmsg_buffer + .as_mut() + .unwrap() + .set_len(mhdr.msg_controllen as usize); + debug_assert!(!mhdr.msg_control.is_null()); + debug_assert!(msg_controllen >= mhdr.msg_controllen as usize); + CMSG_FIRSTHDR(&mhdr as *const msghdr) + } else { + ptr::null() + }.as_ref() + }; + + let address = sockaddr_storage_to_addr( + &address.assume_init(), + mhdr.msg_namelen as usize + ).ok(); + + RecvMsg { + bytes: r as usize, + cmsghdr, + address, + flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), + mhdr, + } +} + +unsafe fn recv_pack_mhdr<'a, I>( + out: *mut msghdr, + iov: I, + cmsg_buffer: &mut Option<&mut Vec>, + address: &mut mem::MaybeUninit +) -> usize + where + I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, +{ + let (msg_control, msg_controllen) = cmsg_buffer.as_mut() + .map(|v| (v.as_mut_ptr(), v.capacity())) + .unwrap_or((ptr::null_mut(), 0)); + + (*out).msg_name = address.as_mut_ptr() as *mut c_void; + (*out).msg_namelen = mem::size_of::() as socklen_t; + (*out).msg_iov = iov.as_ref().as_ptr() as *mut iovec; + (*out).msg_iovlen = iov.as_ref().len() as _; + (*out).msg_control = msg_control as *mut c_void; + (*out).msg_controllen = msg_controllen as _; + (*out).msg_flags = 0; + + msg_controllen +} + + +fn pack_mhdr<'a, I, C>( + out: *mut msghdr, + cmsg_buffer: &mut [u8], + iov: I, + cmsgs: C, + addr: Option<&SockAddr> +) + where + I: AsRef<[IoVec<&'a [u8]>]>, + C: AsRef<[ControlMessage<'a>]> +{ + let cmsg_capacity = cmsg_buffer.len(); // Next encode the sending address, if provided let (name, namelen) = match addr { @@ -790,45 +1085,38 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], }; // The message header must be initialized before the individual cmsgs. - let cmsg_ptr = if capacity > 0 { + let cmsg_ptr = if cmsg_capacity > 0 { cmsg_buffer.as_ptr() as *mut c_void } else { ptr::null_mut() }; - let mhdr = unsafe { - // Musl's msghdr has private fields, so this is the only way to - // initialize it. - let mut mhdr = mem::MaybeUninit::::zeroed(); - let p = mhdr.as_mut_ptr(); - (*p).msg_name = name as *mut _; - (*p).msg_namelen = namelen; + // Musl's msghdr has private fields, so this is the only way to + // initialize it. + unsafe { + (*out).msg_name = name as *mut _; + (*out).msg_namelen = namelen; // transmute iov into a mutable pointer. sendmsg doesn't really mutate // the buffer, but the standard says that it takes a mutable pointer - (*p).msg_iov = iov.as_ptr() as *mut _; - (*p).msg_iovlen = iov.len() as _; - (*p).msg_control = cmsg_ptr; - (*p).msg_controllen = capacity as _; - (*p).msg_flags = 0; - mhdr.assume_init() - }; + (*out).msg_iov = iov.as_ref().as_ptr() as *mut _; + (*out).msg_iovlen = iov.as_ref().len() as _; + (*out).msg_control = cmsg_ptr; + (*out).msg_controllen = cmsg_capacity as _; + (*out).msg_flags = 0; + } // Encode each cmsg. This must happen after initializing the header because // CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields. // CMSG_FIRSTHDR is always safe - let mut pmhdr: *mut cmsghdr = unsafe{CMSG_FIRSTHDR(&mhdr as *const msghdr)}; - for cmsg in cmsgs { + let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(out) }; + for cmsg in cmsgs.as_ref() { assert_ne!(pmhdr, ptr::null_mut()); // Safe because we know that pmhdr is valid, and we initialized it with // sufficient space - unsafe { cmsg.encode_into(pmhdr) }; + unsafe { cmsg.encode_into(pmhdr); } // Safe because mhdr is valid - pmhdr = unsafe{CMSG_NXTHDR(&mhdr as *const msghdr, pmhdr)}; + pmhdr = unsafe { CMSG_NXTHDR(out, pmhdr) }; } - - let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) }; - - Errno::result(ret).map(|r| r as usize) } /// Receive message in scatter-gather vectors from a socket, and @@ -849,58 +1137,20 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>], mut cmsg_buffer: Option<&'a mut Vec>, flags: MsgFlags) -> Result> { + let mut out = mem::MaybeUninit::::zeroed(); let mut address = mem::MaybeUninit::uninit(); - let (msg_control, msg_controllen) = cmsg_buffer.as_mut() - .map(|v| (v.as_mut_ptr(), v.capacity())) - .unwrap_or((ptr::null_mut(), 0)); - let mut mhdr = { - unsafe { - // Musl's msghdr has private fields, so this is the only way to - // initialize it. - let mut mhdr = mem::MaybeUninit::::zeroed(); - let p = mhdr.as_mut_ptr(); - (*p).msg_name = address.as_mut_ptr() as *mut c_void; - (*p).msg_namelen = mem::size_of::() as socklen_t; - (*p).msg_iov = iov.as_ptr() as *mut iovec; - (*p).msg_iovlen = iov.len() as _; - (*p).msg_control = msg_control as *mut c_void; - (*p).msg_controllen = msg_controllen as _; - (*p).msg_flags = 0; - mhdr.assume_init() - } + + let msg_controllen = unsafe { + recv_pack_mhdr(out.as_mut_ptr(), &iov, &mut cmsg_buffer, &mut address) }; + let mut mhdr = unsafe { out.assume_init() }; + let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; - Errno::result(ret).map(|r| { - let cmsghdr = unsafe { - if mhdr.msg_controllen > 0 { - // got control message(s) - cmsg_buffer - .as_mut() - .unwrap() - .set_len(mhdr.msg_controllen as usize); - debug_assert!(!mhdr.msg_control.is_null()); - debug_assert!(msg_controllen >= mhdr.msg_controllen as usize); - CMSG_FIRSTHDR(&mhdr as *const msghdr) - } else { - ptr::null() - }.as_ref() - }; + let r = Errno::result(ret)?; - let address = unsafe { - sockaddr_storage_to_addr(&address.assume_init(), - mhdr.msg_namelen as usize - ).ok() - }; - RecvMsg { - bytes: r as usize, - cmsghdr, - address, - flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), - mhdr, - } - }) + Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address, &mut cmsg_buffer) }) } diff --git a/src/sys/time.rs b/src/sys/time.rs index 606bbd9d9b..06475001b9 100644 --- a/src/sys/time.rs +++ b/src/sys/time.rs @@ -67,6 +67,12 @@ impl AsRef for TimeSpec { } } +impl AsMut for TimeSpec { + fn as_mut(&mut self) -> &mut timespec { + &mut self.0 + } +} + impl Ord for TimeSpec { // The implementation of cmp is simplified by assuming that the struct is // normalized. That is, tv_nsec must always be within [0, 1_000_000_000) @@ -259,6 +265,12 @@ impl AsRef for TimeVal { } } +impl AsMut for TimeVal { + fn as_mut(&mut self) -> &mut timeval { + &mut self.0 + } +} + impl Ord for TimeVal { // The implementation of cmp is simplified by assuming that the struct is // normalized. That is, tv_usec must always be within [0, 1_000_000) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index bd5c373bc7..91f7be580f 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -226,6 +226,141 @@ mod recvfrom { // UDP sockets should set the from address assert_eq!(AddressFamily::Inet, from.unwrap().family()); } + + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + ))] + #[test] + pub fn udp_sendmmsg() { + use nix::sys::uio::IoVec; + + let std_sa = SocketAddr::from_str("127.0.0.1:6793").unwrap(); + let std_sa2 = SocketAddr::from_str("127.0.0.1:6794").unwrap(); + let inet_addr = InetAddr::from_std(&std_sa); + let inet_addr2 = InetAddr::from_std(&std_sa2); + let sock_addr = SockAddr::new_inet(inet_addr); + let sock_addr2 = SockAddr::new_inet(inet_addr2); + + let rsock = socket(AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None + ).unwrap(); + bind(rsock, &sock_addr).unwrap(); + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + ).expect("send socket failed"); + + let from = sendrecv(rsock, ssock, move |s, m, flags| { + let iov = [IoVec::from_slice(m)]; + let mut msgs = Vec::new(); + msgs.push( + SendMmsgData { + iov: &iov, + cmsgs: &[], + addr: Some(sock_addr), + _lt: Default::default(), + }); + + let batch_size = 15; + + for _ in 0..batch_size { + msgs.push( + SendMmsgData { + iov: &iov, + cmsgs: &[], + addr: Some(sock_addr2), + _lt: Default::default(), + } + ); + } + sendmmsg(s, msgs.iter(), flags) + .map(move |sent_bytes| { + assert!(sent_bytes.len() >= 1); + for sent in &sent_bytes { + assert_eq!(*sent, m.len()); + } + sent_bytes.len() + }) + }); + // UDP sockets should set the from address + assert_eq!(AddressFamily::Inet, from.unwrap().family()); + } + + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "netbsd", + ))] + #[test] + pub fn udp_recvmmsg() { + use nix::sys::uio::IoVec; + use nix::sys::socket::{MsgFlags, recvmmsg}; + + const NUM_MESSAGES_SENT: usize = 2; + const DATA: [u8; 2] = [1,2]; + + let std_sa = SocketAddr::from_str("127.0.0.1:6798").unwrap(); + let inet_addr = InetAddr::from_std(&std_sa); + let sock_addr = SockAddr::new_inet(inet_addr); + + let rsock = socket(AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None + ).unwrap(); + bind(rsock, &sock_addr).unwrap(); + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + ).expect("send socket failed"); + + + let send_thread = thread::spawn(move || { + for _ in 0..NUM_MESSAGES_SENT { + sendto(ssock, &DATA[..], &sock_addr, MsgFlags::empty()).unwrap(); + } + }); + + let mut msgs = std::collections::LinkedList::new(); + + // Buffers to receive exactly `NUM_MESSAGES_SENT` messages + let mut receive_buffers = [[0u8; 32]; NUM_MESSAGES_SENT]; + let iovs: Vec<_> = receive_buffers.iter_mut().map(|buf| { + [IoVec::from_mut_slice(&mut buf[..])] + }).collect(); + + for iov in &iovs { + msgs.push_back(RecvMmsgData { + iov: iov, + cmsg_buffer: None, + }) + }; + + let res = recvmmsg(rsock, &mut msgs, MsgFlags::empty(), None).expect("recvmmsg"); + assert_eq!(res.len(), DATA.len()); + + for RecvMsg { address, bytes, .. } in res.into_iter() { + assert_eq!(AddressFamily::Inet, address.unwrap().family()); + assert_eq!(DATA.len(), bytes); + } + + for buf in &receive_buffers { + assert_eq!(&buf[..DATA.len()], DATA); + } + + send_thread.join().unwrap(); + } } // Test error handling of our recvmsg wrapper