diff --git a/Cargo.toml b/Cargo.toml index a55c548d..5d0325f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ itertools = "0.10" libc = "0.2" maplit = "1.0" md5 = "0.7" -nix = "0.25" +nix = "0.26" num-derive = "0.3" num-traits = "0.2" prost = "0.11" diff --git a/holo-bfd/src/error.rs b/holo-bfd/src/error.rs index be4a5b5b..75717af6 100644 --- a/holo-bfd/src/error.rs +++ b/holo-bfd/src/error.rs @@ -37,6 +37,8 @@ pub enum IoError { UdpSocketError(std::io::Error), UdpRecvError(std::io::Error), UdpSendError(std::io::Error), + UdpRecvMissingSourceAddr, + UdpRecvMissingAncillaryData, } // ===== impl Error ===== @@ -145,6 +147,10 @@ impl IoError { | IoError::UdpSendError(error) => { warn!(error = %with_source(error), "{}", self); } + IoError::UdpRecvMissingSourceAddr + | IoError::UdpRecvMissingAncillaryData => { + warn!("{}", self); + } } } } @@ -161,6 +167,18 @@ impl std::fmt::Display for IoError { IoError::UdpSendError(..) => { write!(f, "failed to send UDP packet") } + IoError::UdpRecvMissingSourceAddr => { + write!( + f, + "failed to retrieve source address from received packet" + ) + } + IoError::UdpRecvMissingAncillaryData => { + write!( + f, + "failed to retrieve ancillary data from received packet" + ) + } } } } @@ -171,6 +189,7 @@ impl std::error::Error for IoError { IoError::UdpSocketError(error) | IoError::UdpRecvError(error) | IoError::UdpSendError(error) => Some(error), + _ => None, } } } diff --git a/holo-bfd/src/network.rs b/holo-bfd/src/network.rs index e9c1ef07..8e6d87d0 100644 --- a/holo-bfd/src/network.rs +++ b/holo-bfd/src/network.rs @@ -4,7 +4,12 @@ // See LICENSE for license details. // -use std::net::{IpAddr, SocketAddr}; +use std::io::IoSliceMut; +use std::net::{ + IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, +}; +use std::ops::Deref; +use std::os::fd::AsRawFd; use std::sync::atomic::{self, AtomicU64}; use std::sync::Arc; @@ -12,6 +17,7 @@ use holo_utils::bfd::PathType; use holo_utils::ip::{AddressFamily, IpAddrExt}; use holo_utils::socket::{UdpSocket, UdpSocketExt}; use holo_utils::{capabilities, Sender}; +use nix::sys::socket::{self, ControlMessageOwned}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::error::SendError; @@ -52,9 +58,11 @@ pub(crate) fn socket_rx( match path_type { PathType::IpSingleHop => match af { AddressFamily::Ipv4 => { + socket.set_ipv4_pktinfo(true)?; socket.set_ipv4_minttl(255)?; } AddressFamily::Ipv6 => { + socket.set_ipv6_pktinfo(true)?; socket.set_min_hopcount_v6(255)?; } }, @@ -63,6 +71,14 @@ pub(crate) fn socket_rx( // sessions, incoming TTL checking should be done in the // userspace given that different peers might have different TTL // settings. + match af { + AddressFamily::Ipv4 => { + socket.set_ipv4_pktinfo(true)?; + } + AddressFamily::Ipv6 => { + socket.set_ipv6_pktinfo(true)?; + } + } } } @@ -144,6 +160,38 @@ pub(crate) async fn send_packet( } } +#[cfg(not(feature = "testing"))] +fn get_packet_src(sa: Option<&socket::SockaddrStorage>) -> Option { + sa.and_then(|sa| { + if let Some(sa) = sa.as_sockaddr_in() { + Some(SocketAddrV4::from(*sa).into()) + } else if let Some(sa) = sa.as_sockaddr_in6() { + Some(SocketAddrV6::from(*sa).into()) + } else { + None + } + }) +} + +#[cfg(not(feature = "testing"))] +fn get_packet_dst(cmsgs: socket::CmsgIterator<'_>) -> Option { + for cmsg in cmsgs { + match cmsg { + ControlMessageOwned::Ipv4PacketInfo(pktinfo) => { + return Some( + Ipv4Addr::from(pktinfo.ipi_spec_dst.s_addr.to_be()).into(), + ); + } + ControlMessageOwned::Ipv6PacketInfo(pktinfo) => { + return Some(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr).into()); + } + _ => {} + } + } + + None +} + #[cfg(not(feature = "testing"))] pub(crate) async fn read_loop( socket: Arc, @@ -151,47 +199,81 @@ pub(crate) async fn read_loop( udp_packet_rxp: Sender, ) -> Result<(), SendError> { let mut buf = [0; 1024]; + let mut iov = [IoSliceMut::new(&mut buf)]; + let mut cmsgspace = nix::cmsg_space!(libc::in6_pktinfo); loop { // Receive data from the network. - let (_, src) = match socket.recv_from(&mut buf).await { - Ok((num_bytes, src)) => (num_bytes, src), + match socket + .async_io(tokio::io::Interest::READABLE, || { + match socket::recvmsg::( + socket.as_raw_fd(), + &mut iov, + Some(&mut cmsgspace), + socket::MsgFlags::empty(), + ) { + Ok(msg) => { + // Retrieve source and destination addresses. + let src = get_packet_src(msg.address.as_ref()); + let dst = get_packet_dst(msg.cmsgs()); + Ok((src, dst, msg.bytes)) + } + Err(errno) => Err(errno.into()), + } + }) + .await + { + Ok((src, dst, bytes)) => { + let src = match src { + Some(addr) => addr, + None => { + IoError::UdpRecvMissingSourceAddr.log(); + return Ok(()); + } + }; + let dst = match dst { + Some(addr) => addr, + None => { + IoError::UdpRecvMissingAncillaryData.log(); + return Ok(()); + } + }; + + // Validate packet's source address. + if !src.ip().is_usable() { + Error::UdpInvalidSourceAddr(src.ip()).log(); + continue; + } + + // Decode packet, discarding malformed ones. + let packet = match Packet::decode(&iov[0].deref()[0..bytes]) { + Ok(packet) => packet, + Err(_) => continue, + }; + + // Notify the BFD main task about the received packet. + let packet_info = match path_type { + PathType::IpSingleHop => PacketInfo::IpSingleHop { src }, + PathType::IpMultihop => { + let src = src.ip(); + // TODO: get packet's TTL using IP_RECVTTL/IPV6_HOPLIMIT + let ttl = 255; + PacketInfo::IpMultihop { src, dst, ttl } + } + }; + let msg = UdpRxPacketMsg { + packet_info, + packet, + }; + udp_packet_rxp.send(msg).await?; + } + Err(error) if error.kind() == std::io::ErrorKind::Interrupted => { + // Retry if the syscall was interrupted (EINTR). + continue; + } Err(error) => { IoError::UdpRecvError(error).log(); - continue; } - }; - - // Validate packet's source address. - if !src.ip().is_usable() { - Error::UdpInvalidSourceAddr(src.ip()).log(); - continue; } - - // Get packet's ancillary data. - let packet_info = match path_type { - PathType::IpSingleHop => PacketInfo::IpSingleHop { src }, - PathType::IpMultihop => { - let src = src.ip(); - // TODO: get packet's destination using IP_PKTINFO/IPV6_PKTINFO. - let dst = src; - // TODO: get packet's TTL using IP_RECVTTL/IPV6_HOPLIMIT. - let ttl = 255; - PacketInfo::IpMultihop { src, dst, ttl } - } - }; - - // Decode packet, dropping malformed ones. - let packet = match Packet::decode(&buf) { - Ok(packet) => packet, - Err(_) => continue, - }; - - // Notify the BFD main task about the received packet. - let msg = UdpRxPacketMsg { - packet_info, - packet, - }; - udp_packet_rxp.send(msg).await?; } } diff --git a/holo-utils/src/socket.rs b/holo-utils/src/socket.rs index ac9c26e9..473cd26c 100644 --- a/holo-utils/src/socket.rs +++ b/holo-utils/src/socket.rs @@ -92,6 +92,12 @@ pub trait UdpSocketExt { // Sets the value of the IPV6_MINHOPCOUNT option for this socket. fn set_min_hopcount_v6(&self, hopcount: u8) -> Result<()>; + + // Sets the value of the IP_PKTINFO option for this socket. + fn set_ipv4_pktinfo(&self, value: bool) -> Result<()>; + + // Sets the value of the IPV6_RECVPKTINFO option for this socket. + fn set_ipv6_pktinfo(&self, value: bool) -> Result<()>; } // Extension methods for TcpSocket. @@ -331,6 +337,30 @@ impl UdpSocketExt for UdpSocket { std::mem::size_of::() as libc::socklen_t, ) } + + fn set_ipv4_pktinfo(&self, value: bool) -> Result<()> { + let optval = value as c_int; + + setsockopt( + self, + libc::IPPROTO_IP, + libc::IP_PKTINFO, + &optval as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ) + } + + fn set_ipv6_pktinfo(&self, value: bool) -> Result<()> { + let optval = value as c_int; + + setsockopt( + self, + libc::IPPROTO_IPV6, + libc::IPV6_RECVPKTINFO, + &optval as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ) + } } // ===== impl TcpSocket =====