Skip to content

Commit

Permalink
linux: use poll api instead of select inorder to support fd > 1024
Browse files Browse the repository at this point in the history
  • Loading branch information
nemosupremo committed Apr 19, 2024
1 parent 87f362d commit 48f2109
Showing 1 changed file with 76 additions and 52 deletions.
128 changes: 76 additions & 52 deletions pnet_datalink/src/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ use crate::{DataLinkReceiver, DataLinkSender, MacAddr, NetworkInterface};

use pnet_sys;

use std::cmp;
use std::io;
use std::mem;
use std::ptr;
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -200,7 +198,6 @@ pub fn channel(network_interface: &NetworkInterface, config: Config) -> io::Resu
let fd = Arc::new(pnet_sys::FileDesc { fd: socket });
let sender = Box::new(DataLinkSenderImpl {
socket: fd.clone(),
fd_set: unsafe { mem::zeroed() },
write_buffer: vec![0; config.write_buffer_size],
_channel_type: config.channel_type,
send_addr: unsafe { *(send_addr as *const libc::sockaddr_ll) },
Expand All @@ -211,7 +208,6 @@ pub fn channel(network_interface: &NetworkInterface, config: Config) -> io::Resu
});
let receiver = Box::new(DataLinkReceiverImpl {
socket: fd.clone(),
fd_set: unsafe { mem::zeroed() },
read_buffer: vec![0; config.read_buffer_size],
_channel_type: config.channel_type,
timeout: config
Expand All @@ -224,7 +220,6 @@ pub fn channel(network_interface: &NetworkInterface, config: Config) -> io::Resu

struct DataLinkSenderImpl {
socket: Arc<pnet_sys::FileDesc>,
fd_set: libc::fd_set,
write_buffer: Vec<u8>,
_channel_type: super::ChannelType,
send_addr: libc::sockaddr_ll,
Expand All @@ -243,35 +238,40 @@ impl DataLinkSender for DataLinkSenderImpl {
) -> Option<io::Result<()>> {
let len = num_packets * packet_size;
if len <= self.write_buffer.len() {
let min = cmp::min(self.write_buffer[..].len(), len);
let min = std::cmp::min(self.write_buffer.len(), len);
let mut_slice = &mut self.write_buffer;

let mut pollfd = libc::pollfd {
fd: self.socket.fd,
events: libc::POLLOUT, // Monitoring for write ability
revents: 0, // Will be filled by poll to indicate the events that occurred
};

// Convert timeout to milliseconds as required by poll
let timeout_ms = self
.timeout
.as_ref()
.map(|to| (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000))
.unwrap_or(-1); // -1 means wait indefinitely

for chunk in mut_slice[..min].chunks_mut(packet_size) {
func(chunk);
let send_addr =
(&self.send_addr as *const libc::sockaddr_ll) as *const libc::sockaddr;

unsafe {
libc::FD_ZERO(&mut self.fd_set as *mut libc::fd_set);
libc::FD_SET(self.socket.fd, &mut self.fd_set as *mut libc::fd_set);
}
let ret = unsafe {
libc::pselect(
self.socket.fd + 1,
ptr::null_mut(),
&mut self.fd_set as *mut libc::fd_set,
ptr::null_mut(),
self.timeout
.as_ref()
.map(|to| to as *const libc::timespec)
.unwrap_or(ptr::null()),
ptr::null(),
libc::poll(
&mut pollfd as *mut libc::pollfd,
1,
timeout_ms as libc::c_int,
)
};

if ret == -1 {
return Some(Err(io::Error::last_os_error()));
} else if ret == 0 {
return Some(Err(io::Error::new(io::ErrorKind::TimedOut, "Timed out")));
} else {
} else if pollfd.revents & libc::POLLOUT != 0 {
if let Err(e) = pnet_sys::send_to(
self.socket.fd,
chunk,
Expand All @@ -280,6 +280,11 @@ impl DataLinkSender for DataLinkSenderImpl {
) {
return Some(Err(e));
}
} else {
return Some(Err(io::Error::new(
io::ErrorKind::Other,
"Unexpected poll event",
)));
}
}

Expand All @@ -291,28 +296,33 @@ impl DataLinkSender for DataLinkSenderImpl {

#[inline]
fn send_to(&mut self, packet: &[u8], _dst: Option<NetworkInterface>) -> Option<io::Result<()>> {
unsafe {
libc::FD_ZERO(&mut self.fd_set as *mut libc::fd_set);
libc::FD_SET(self.socket.fd, &mut self.fd_set as *mut libc::fd_set);
}
let mut pollfd = libc::pollfd {
fd: self.socket.fd,
events: libc::POLLOUT, // Monitoring for write ability
revents: 0, // Will be filled by poll to indicate the events that occurred
};

// Convert timeout to milliseconds as required by poll
let timeout_ms = self
.timeout
.as_ref()
.map(|to| (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000))
.unwrap_or(-1); // -1 means wait indefinitely

let ret = unsafe {
libc::pselect(
self.socket.fd + 1,
ptr::null_mut(),
&mut self.fd_set as *mut libc::fd_set,
ptr::null_mut(),
self.timeout
.as_ref()
.map(|to| to as *const libc::timespec)
.unwrap_or(ptr::null()),
ptr::null(),
libc::poll(
&mut pollfd as *mut libc::pollfd,
1,
timeout_ms as libc::c_int,
)
};

if ret == -1 {
Some(Err(io::Error::last_os_error()))
} else if ret == 0 {
Some(Err(io::Error::new(io::ErrorKind::TimedOut, "Timed out")))
} else {
} else if pollfd.revents & libc::POLLOUT != 0 {
// POLLOUT is set, meaning the socket is ready for writing
match pnet_sys::send_to(
self.socket.fd,
packet,
Expand All @@ -322,13 +332,17 @@ impl DataLinkSender for DataLinkSenderImpl {
Err(e) => Some(Err(e)),
Ok(_) => Some(Ok(())),
}
} else {
Some(Err(io::Error::new(
io::ErrorKind::Other,
"Unexpected poll event",
)))
}
}
}

struct DataLinkReceiverImpl {
socket: Arc<pnet_sys::FileDesc>,
fd_set: libc::fd_set,
read_buffer: Vec<u8>,
_channel_type: super::ChannelType,
timeout: Option<libc::timespec>,
Expand All @@ -337,33 +351,43 @@ struct DataLinkReceiverImpl {
impl DataLinkReceiver for DataLinkReceiverImpl {
fn next(&mut self) -> io::Result<&[u8]> {
let mut caddr: libc::sockaddr_storage = unsafe { mem::zeroed() };
unsafe {
libc::FD_ZERO(&mut self.fd_set as *mut libc::fd_set);
libc::FD_SET(self.socket.fd, &mut self.fd_set as *mut libc::fd_set);
}
let mut pollfd = libc::pollfd {
fd: self.socket.fd,
events: libc::POLLIN, // Monitoring for read availability
revents: 0,
};

// Convert timeout to milliseconds as required by poll
let timeout_ms = self
.timeout
.as_ref()
.map(|to| (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000))
.unwrap_or(-1); // -1 means wait indefinitely

let ret = unsafe {
libc::pselect(
self.socket.fd + 1,
&mut self.fd_set as *mut libc::fd_set,
ptr::null_mut(),
ptr::null_mut(),
self.timeout
.as_ref()
.map(|to| to as *const libc::timespec)
.unwrap_or(ptr::null()),
ptr::null(),
libc::poll(
&mut pollfd as *mut libc::pollfd,
1,
timeout_ms as libc::c_int,
)
};

if ret == -1 {
Err(io::Error::last_os_error())
} else if ret == 0 {
Err(io::Error::new(io::ErrorKind::TimedOut, "Timed out"))
} else {
} else if pollfd.revents & libc::POLLIN != 0 {
// POLLIN is set, meaning the socket has data to be read
let res = pnet_sys::recv_from(self.socket.fd, &mut self.read_buffer, &mut caddr);
match res {
Ok(len) => Ok(&self.read_buffer[0..len]),
Err(e) => Err(e),
}
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Unexpected poll event",
))
}
}
}
Expand Down

0 comments on commit 48f2109

Please sign in to comment.