From b5e28b72790ca83bca4b4dd08490d2b3e07f1c3e Mon Sep 17 00:00:00 2001 From: Gleb Pomykalov Date: Wed, 8 Apr 2020 12:41:56 +0300 Subject: [PATCH] Support sendmmsg/recvmmsg --- CHANGELOG.md | 2 + src/sys/socket/mod.rs | 368 ++++++++++++++++++++++++++++++++-------- src/sys/time.rs | 12 ++ test/sys/test_socket.rs | 118 +++++++++++++ 4 files changed, 426 insertions(+), 74 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14e73e42cc..9828d7a46c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Derived `Ord`, `PartialOrd` for `unistd::Pid` (#[1189](https://github.com/nix-rust/nix/pull/1189)) - Added `select::FdSet::fds` method to iterate over file descriptors in a set. ([#1207](https://github.com/nix-rust/nix/pull/1207)) +- Added support for `sendmmsg` and `recvmmsg` calls on Linux and FreeBSD + (#[1208](https://github.com/nix-rust/nix/pull/1208)) ### Changed - Changed `fallocate` return type from `c_int` to `()` (#[1201](https://github.com/nix-rust/nix/pull/1201)) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 6b2d1ad08c..a5b89a9139 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -50,6 +50,7 @@ pub use libc::{ // Needed by the cmsg_space macro #[doc(hidden)] pub use libc::{c_uint, CMSG_SPACE}; +use std::mem::MaybeUninit; /// These constants are used to specify the communication semantics /// when creating a socket with [`socket()`](fn.socket.html) @@ -774,61 +775,318 @@ impl<'a> ControlMessage<'a> { pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], flags: MsgFlags, addr: Option<&SockAddr>) -> Result { + let mut mhdr = mem::MaybeUninit::::zeroed(); + let capacity = cmsgs.iter().map(|c| c.space()).sum(); // First size the buffer needed to hold the cmsgs. It must be zeroed, // because subsequent code will not clear the padding bytes. 
- let cmsg_buffer = vec![0u8; capacity]; + let mut cmsg_buffer = vec![0u8; capacity]; + + unsafe { send_pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr) }; + + let mhdr = unsafe { mhdr.assume_init() }; + + let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) }; + + Errno::result(ret).map(|r| r as usize) +} + +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +#[derive(Debug)] +pub struct SendMmsgData<'a, I, C> + where + I: AsRef<[IoVec<&'a [u8]>]>, + C: AsRef<[ControlMessage<'a>]> +{ + pub iov: I, + pub cmsgs: C, + pub addr: Option, + pub _lt: std::marker::PhantomData<&'a I>, +} + +/// An extension of `sendmsg` that allows the caller to transmit multiple +/// messages on a socket using a single system call. (This has performance +/// benefits for some applications.) Supported on Linux and FreeBSD. +/// +/// Allocations are performed for cmsgs and to build `msghdr` buffer +/// +/// # Arguments +/// +/// * `fd`: Socket file descriptor +/// * `data`: Struct that implements `IntoIterator` with `SendMmsgData` items +/// * `flags`: Optional flags passed directly to the operating system. 
+/// +/// # References +/// [sendmmsg(2)](http://man7.org/linux/man-pages/man2/sendmmsg.2.html) +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +pub fn sendmmsg<'a, I, C>(fd: RawFd, data: impl std::iter::IntoIterator>, flags: MsgFlags) + -> Result<(usize, Vec)> + where + I: AsRef<[IoVec<&'a [u8]>]> + 'a, + C: AsRef<[ControlMessage<'a>]> + 'a, +{ + let iter = data.into_iter(); + + let (min_size, max_size) = iter.size_hint(); + let reserve_items = max_size.unwrap_or(min_size); + + let mut output: Vec> = vec![MaybeUninit::zeroed(); reserve_items]; + + let mut cmsgs_buffer = vec![0u8; 0]; + + iter.enumerate().for_each(|(i, d)| { + if output.len() < i { + output.resize(i, MaybeUninit::zeroed()); + } + + let element = &mut output[i]; + + let cmsgs_start = cmsgs_buffer.len(); + let cmsgs_required_capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum(); + let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity; + cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0); + + unsafe { + send_pack_mhdr( + &mut (*element.as_mut_ptr()).msg_hdr, + &mut cmsgs_buffer[cmsgs_start..], + &d.iov, + &d.cmsgs, + d.addr.as_ref() + ) + }; + }); + + let mut initialized_data = unsafe { mem::transmute::<_, Vec>(output) }; + + let ret = unsafe { libc::sendmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as _, flags.bits() as _) }; + + let sent_messages = Errno::result(ret)? 
as usize; + let mut sent_bytes = Vec::with_capacity(sent_messages); + unsafe { sent_bytes.set_len(sent_messages) }; + + for i in 0..sent_messages { + sent_bytes[i] = initialized_data[i].msg_len as usize; + } + + Ok((sent_messages, sent_bytes)) +} + + +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +#[derive(Debug)] +pub struct RecvMmsgData<'a, I> + where + I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, +{ + pub iov: I, + pub cmsg_buffer: Option<&'a mut Vec>, +} + +/// An extension of recvmsg(2) that allows the caller to receive multiple +/// messages from a socket using a single system call. (This has +/// performance benefits for some applications.) +/// +/// `iov` and `cmsg_buffer` should be constructed similarly to recvmsg +/// +/// Multiple allocations are performed +/// +/// # Arguments +/// +/// * `fd`: Socket file descriptor +/// * `data`: Struct that implements `IntoIterator` with `RecvMmsgData` items +/// * `flags`: Optional flags passed directly to the operating system. +/// +/// # RecvMmsgData +/// +/// * `iov`: Scatter-gather list of buffers to receive the message +/// * `cmsg_buffer`: Space to receive ancillary data. 
Should be created by +/// [`cmsg_space!`](macro.cmsg_space.html) +/// +/// # References +/// [recvmmsg(2)](http://man7.org/linux/man-pages/man2/recvmmsg.2.html) +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +pub fn recvmmsg<'a, I>( + fd: RawFd, + data: impl std::iter::IntoIterator>, + flags: MsgFlags, timeout: Option +) -> Result>> + where + I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, +{ + let iter = data.into_iter(); + + let (min_size, max_size) = iter.size_hint(); + let reserve_items = max_size.unwrap_or(min_size); + + let mut output: Vec> = vec![MaybeUninit::zeroed(); reserve_items]; + let mut address: Vec> = vec![MaybeUninit::uninit(); reserve_items]; + + let results: Vec<_> = iter.enumerate().map(|(i, d)| { + if output.len() < i { + output.resize(i, MaybeUninit::zeroed()); + address.resize(i, MaybeUninit::uninit()); + } + + let element = &mut output[i]; + + let msg_controllen = unsafe { + recv_pack_mhdr( + &mut (*element.as_mut_ptr()).msg_hdr, + d.iov.as_ref(), + &mut d.cmsg_buffer, + &mut address[i] + ) + }; + + (msg_controllen as usize, &mut d.cmsg_buffer) + }).collect(); + + let mut initialized_data = unsafe { mem::transmute::<_, Vec>(output) }; + + let timeout = if let Some(mut t) = timeout { + t.as_mut() as *mut libc::timespec + } else { + ptr::null_mut() + }; + + let ret = unsafe { libc::recvmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as _, flags.bits() as _, timeout) }; + + let r = Errno::result(ret)?; + + Ok(initialized_data + .into_iter() + .zip(address.into_iter()) + .zip(results.into_iter()) + .map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| { + unsafe { + recv_read_mhdr( + mmsghdr.msg_hdr, + r as isize, + msg_controllen, + address, + cmsg_buffer + ) + } + }) + .collect()) +} + +unsafe fn recv_read_mhdr<'a, 'b>( + mhdr: msghdr, + r: isize, + msg_controllen: usize, + address: MaybeUninit, + cmsg_buffer: &'a mut Option<&'b mut Vec> +) -> RecvMsg<'b> { + let cmsghdr = { + if mhdr.msg_controllen > 0 { + // got control 
message(s) + cmsg_buffer + .as_mut() + .unwrap() + .set_len(mhdr.msg_controllen as usize); + debug_assert!(!mhdr.msg_control.is_null()); + debug_assert!(msg_controllen >= mhdr.msg_controllen as usize); + CMSG_FIRSTHDR(&mhdr as *const msghdr) + } else { + ptr::null() + }.as_ref() + }; + + let address = sockaddr_storage_to_addr( + &address.assume_init(), + mhdr.msg_namelen as usize + ).ok(); + + RecvMsg { + bytes: r as usize, + cmsghdr, + address, + flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), + mhdr, + } +} + +unsafe fn recv_pack_mhdr<'a, I>( + out: *mut msghdr, + iov: I, + cmsg_buffer: &mut Option<&mut Vec>, + address: &mut mem::MaybeUninit +) -> usize + where + I: AsRef<[IoVec<&'a mut [u8]>]> + 'a, +{ + let (msg_control, msg_controllen) = cmsg_buffer.as_mut() + .map(|v| (v.as_mut_ptr(), v.capacity())) + .unwrap_or((ptr::null_mut(), 0)); + + (*out).msg_name = address.as_mut_ptr() as *mut c_void; + (*out).msg_namelen = mem::size_of::() as socklen_t; + (*out).msg_iov = iov.as_ref().as_ptr() as *mut iovec; + (*out).msg_iovlen = iov.as_ref().len() as _; + (*out).msg_control = msg_control as *mut c_void; + (*out).msg_controllen = msg_controllen as _; + (*out).msg_flags = 0; + + msg_controllen +} + + +unsafe fn send_pack_mhdr<'a, I, C>( + out: *mut msghdr, + cmsg_buffer: &mut [u8], + iov: I, + cmsgs: C, + addr: Option<&SockAddr> +) + where + I: AsRef<[IoVec<&'a [u8]>]>, + C: AsRef<[ControlMessage<'a>]> +{ + let cmsg_capacity = cmsg_buffer.len(); // Next encode the sending address, if provided let (name, namelen) = match addr { Some(addr) => { - let (x, y) = unsafe { addr.as_ffi_pair() }; + let (x, y) = addr.as_ffi_pair(); (x as *const _, y) }, None => (ptr::null(), 0), }; // The message header must be initialized before the individual cmsgs. 
- let cmsg_ptr = if capacity > 0 { + let cmsg_ptr = if cmsg_capacity > 0 { cmsg_buffer.as_ptr() as *mut c_void } else { ptr::null_mut() }; - let mhdr = unsafe { - // Musl's msghdr has private fields, so this is the only way to - // initialize it. - let mut mhdr = mem::MaybeUninit::::zeroed(); - let p = mhdr.as_mut_ptr(); - (*p).msg_name = name as *mut _; - (*p).msg_namelen = namelen; - // transmute iov into a mutable pointer. sendmsg doesn't really mutate - // the buffer, but the standard says that it takes a mutable pointer - (*p).msg_iov = iov.as_ptr() as *mut _; - (*p).msg_iovlen = iov.len() as _; - (*p).msg_control = cmsg_ptr; - (*p).msg_controllen = capacity as _; - (*p).msg_flags = 0; - mhdr.assume_init() - }; + // Musl's msghdr has private fields, so this is the only way to + // initialize it. + (*out).msg_name = name as *mut _; + (*out).msg_namelen = namelen; + // transmute iov into a mutable pointer. sendmsg doesn't really mutate + // the buffer, but the standard says that it takes a mutable pointer + (*out).msg_iov = iov.as_ref().as_ptr() as *mut _; + (*out).msg_iovlen = iov.as_ref().len() as _; + (*out).msg_control = cmsg_ptr; + (*out).msg_controllen = cmsg_capacity as _; + (*out).msg_flags = 0; // Encode each cmsg. This must happen after initializing the header because // CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields. 
// CMSG_FIRSTHDR is always safe - let mut pmhdr: *mut cmsghdr = unsafe{CMSG_FIRSTHDR(&mhdr as *const msghdr)}; - for cmsg in cmsgs { + let mut pmhdr: *mut cmsghdr = CMSG_FIRSTHDR(out); + for cmsg in cmsgs.as_ref() { assert_ne!(pmhdr, ptr::null_mut()); // Safe because we know that pmhdr is valid, and we initialized it with // sufficient space - unsafe { cmsg.encode_into(pmhdr) }; + cmsg.encode_into(pmhdr); // Safe because mhdr is valid - pmhdr = unsafe{CMSG_NXTHDR(&mhdr as *const msghdr, pmhdr)}; + pmhdr = CMSG_NXTHDR(out, pmhdr); } - - let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) }; - - Errno::result(ret).map(|r| r as usize) } /// Receive message in scatter-gather vectors from a socket, and @@ -845,62 +1103,24 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], /// /// # References /// [recvmsg(2)](http://pubs.opengroup.org/onlinepubs/9699919799/functions/recvmsg.html) -pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>], +pub fn recvmsg<'a>(fd: RawFd, iov: &'a [IoVec<&'a mut [u8]>], mut cmsg_buffer: Option<&'a mut Vec>, flags: MsgFlags) -> Result> { + let mut out = mem::MaybeUninit::::zeroed(); let mut address = mem::MaybeUninit::uninit(); - let (msg_control, msg_controllen) = cmsg_buffer.as_mut() - .map(|v| (v.as_mut_ptr(), v.capacity())) - .unwrap_or((ptr::null_mut(), 0)); - let mut mhdr = { - unsafe { - // Musl's msghdr has private fields, so this is the only way to - // initialize it. 
- let mut mhdr = mem::MaybeUninit::::zeroed(); - let p = mhdr.as_mut_ptr(); - (*p).msg_name = address.as_mut_ptr() as *mut c_void; - (*p).msg_namelen = mem::size_of::() as socklen_t; - (*p).msg_iov = iov.as_ptr() as *mut iovec; - (*p).msg_iovlen = iov.len() as _; - (*p).msg_control = msg_control as *mut c_void; - (*p).msg_controllen = msg_controllen as _; - (*p).msg_flags = 0; - mhdr.assume_init() - } + + let msg_controllen = unsafe { + recv_pack_mhdr(out.as_mut_ptr(), &iov, &mut cmsg_buffer, &mut address) }; + let mut mhdr = unsafe { out.assume_init() }; + let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; - Errno::result(ret).map(|r| { - let cmsghdr = unsafe { - if mhdr.msg_controllen > 0 { - // got control message(s) - cmsg_buffer - .as_mut() - .unwrap() - .set_len(mhdr.msg_controllen as usize); - debug_assert!(!mhdr.msg_control.is_null()); - debug_assert!(msg_controllen >= mhdr.msg_controllen as usize); - CMSG_FIRSTHDR(&mhdr as *const msghdr) - } else { - ptr::null() - }.as_ref() - }; + let r = Errno::result(ret)?; - let address = unsafe { - sockaddr_storage_to_addr(&address.assume_init(), - mhdr.msg_namelen as usize - ).ok() - }; - RecvMsg { - bytes: r as usize, - cmsghdr, - address, - flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), - mhdr, - } - }) + Ok(unsafe { recv_read_mhdr(mhdr, r, msg_controllen, address, &mut cmsg_buffer) }) } diff --git a/src/sys/time.rs b/src/sys/time.rs index 606bbd9d9b..06475001b9 100644 --- a/src/sys/time.rs +++ b/src/sys/time.rs @@ -67,6 +67,12 @@ impl AsRef for TimeSpec { } } +impl AsMut for TimeSpec { + fn as_mut(&mut self) -> &mut timespec { + &mut self.0 + } +} + impl Ord for TimeSpec { // The implementation of cmp is simplified by assuming that the struct is // normalized. 
That is, tv_nsec must always be within [0, 1_000_000_000) @@ -259,6 +265,12 @@ impl AsRef for TimeVal { } } +impl AsMut for TimeVal { + fn as_mut(&mut self) -> &mut timeval { + &mut self.0 + } +} + impl Ord for TimeVal { // The implementation of cmp is simplified by assuming that the struct is // normalized. That is, tv_usec must always be within [0, 1_000_000) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index bd5c373bc7..673ef98628 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -226,6 +226,124 @@ mod recvfrom { // UDP sockets should set the from address assert_eq!(AddressFamily::Inet, from.unwrap().family()); } + + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + #[test] + pub fn udp_sendmmsg() { + use nix::sys::uio::IoVec; + + let std_sa = SocketAddr::from_str("127.0.0.1:6793").unwrap(); + let std_sa2 = SocketAddr::from_str("127.0.0.1:6794").unwrap(); + let inet_addr = InetAddr::from_std(&std_sa); + let inet_addr2 = InetAddr::from_std(&std_sa2); + let sock_addr = SockAddr::new_inet(inet_addr); + let sock_addr2 = SockAddr::new_inet(inet_addr2); + + let rsock = socket(AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None + ).unwrap(); + bind(rsock, &sock_addr).unwrap(); + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + ).expect("send socket failed"); + + let from = sendrecv(rsock, ssock, move |s, m, flags| { + let iov = [IoVec::from_slice(m)]; + let mut msgs = Vec::new(); + msgs.push( + SendMmsgData { + iov: &iov, + cmsgs: &[], + addr: Some(sock_addr), + _lt: Default::default(), + }); + for _ in 0..15 { + msgs.push( + SendMmsgData { + iov: &iov, + cmsgs: &[], + addr: Some(sock_addr2), + _lt: Default::default(), + } + ); + } + sendmmsg(s, msgs.iter(), flags) + .map(move |(sent_messages, sent_bytes)| { + assert!(sent_messages >= 1); + assert_eq!(sent_bytes.len(), sent_messages); + for sent in &sent_bytes { + assert_eq!(*sent, m.len()); + } + 
sent_messages + }) + }); + // UDP sockets should set the from address + assert_eq!(AddressFamily::Inet, from.unwrap().family()); + } + + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + #[test] + pub fn udp_recvmmsg() { + use nix::sys::uio::IoVec; + use nix::sys::socket::{MsgFlags, recvmmsg}; + + let std_sa = SocketAddr::from_str("127.0.0.1:6798").unwrap(); + let inet_addr = InetAddr::from_std(&std_sa); + let sock_addr = SockAddr::new_inet(inet_addr); + + let rsock = socket(AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None + ).unwrap(); + bind(rsock, &sock_addr).unwrap(); + let ssock = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + ).expect("send socket failed"); + + let send_thread = thread::spawn(move || { + for _ in 0..2 { + sendto(ssock, &b"12"[..], &sock_addr, MsgFlags::empty()).unwrap(); + } + }); + + let mut msgs = std::collections::LinkedList::new(); + let mut rec_buf1 = [0u8; 32]; + let mut rec_buf2 = [0u8; 32]; + let iov1 = [IoVec::from_mut_slice(&mut rec_buf1[..])]; + let iov2 = [IoVec::from_mut_slice(&mut rec_buf2[..])]; + msgs.push_back( + RecvMmsgData { + iov: &iov1, + cmsg_buffer: None, + } + ); + msgs.push_back( + RecvMmsgData { + iov: &iov2, + cmsg_buffer: None, + } + ); + let res = recvmmsg(rsock, &mut msgs, MsgFlags::empty(), None).expect("recvmmsg"); + assert_eq!(res.len(), 2); + + for RecvMsg { address, bytes, .. } in res.into_iter() { + assert_eq!(AddressFamily::Inet, address.unwrap().family()); + assert_eq!(2, bytes); + } + assert_eq!(&rec_buf1[..2], b"12"); + assert_eq!(&rec_buf2[..2], b"12"); + + send_thread.join().unwrap(); + } } // Test error handling of our recvmsg wrapper