Skip to content

Commit

Permalink
Add support for Packet MMAP
Browse files Browse the repository at this point in the history
  • Loading branch information
DBLouis committed Mar 24, 2024
1 parent 4ed2ea2 commit d1ddfc8
Show file tree
Hide file tree
Showing 15 changed files with 1,260 additions and 14 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ once_cell = { version = "1.5.2", optional = true }
# libc backend can be selected via adding `--cfg=rustix_use_libc` to
# `RUSTFLAGS` or enabling the `use-libc` cargo feature.
[target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies]
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
linux-raw-sys = { version = "0.6.4", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
libc_errno = { package = "errno", version = "0.3.8", default-features = false, optional = true }
libc = { version = "0.2.152", default-features = false, features = ["extra_traits"], optional = true }

Expand All @@ -53,7 +53,7 @@ libc = { version = "0.2.152", default-features = false, features = ["extra_trait
# Some syscalls do not have libc wrappers, such as in `io_uring`. For these,
# the libc backend uses the linux-raw-sys ABI and `libc::syscall`.
[target.'cfg(all(any(target_os = "android", target_os = "linux"), any(rustix_use_libc, miri, not(all(target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64")))))))'.dependencies]
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "ioctl", "no_std"] }
linux-raw-sys = { version = "0.6.4", default-features = false, features = ["general", "ioctl", "no_std"] }

# For the libc backend on Windows, use the Winsock API in windows-sys.
[target.'cfg(windows)'.dependencies.windows-sys]
Expand Down Expand Up @@ -141,7 +141,7 @@ io_uring = ["event", "fs", "net", "linux-raw-sys/io_uring"]
mount = []

# Enable `rustix::net::*`.
net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/xdp"]
net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/if_packet", "linux-raw-sys/xdp"]

# Enable `rustix::thread::*`.
thread = ["linux-raw-sys/prctl"]
Expand Down
366 changes: 366 additions & 0 deletions examples/packet/inner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,366 @@
use rustix::event::{poll, PollFd, PollFlags};
use rustix::fd::OwnedFd;
use rustix::mm::{mmap, munmap, MapFlags, ProtFlags};
use rustix::net::{
bind_link, eth,
netdevice::name_to_index,
packet::{PacketHeader2, PacketReq, PacketReqAny, PacketStatus, SocketAddrLink},
send, socket_with,
sockopt::{set_packet_rx_ring, set_packet_tx_ring, set_packet_version, PacketVersion},
AddressFamily, SendFlags, SocketFlags, SocketType,
};
use std::{cell::Cell, collections::VecDeque, env, ffi::c_void, io, ptr, slice, str};

#[derive(Debug)]
pub struct Socket {
fd: OwnedFd,
block_size: usize,
block_count: usize,
frame_size: usize,
frame_count: usize,
rx: Cell<*mut c_void>,
tx: Cell<*mut c_void>,
}

impl Socket {
fn new(
name: &str,
block_size: usize,
block_count: usize,
frame_size: usize,
) -> io::Result<Self> {
let family = AddressFamily::PACKET;
let type_ = SocketType::RAW;
let flags = SocketFlags::empty();
let fd = socket_with(family, type_, flags, None)?;

let index = name_to_index(&fd, name)?;

set_packet_version(&fd, PacketVersion::V2)?;

let frame_count = (block_size * block_count) / frame_size;
let req = PacketReq {
block_size: block_size as u32,
block_nr: block_count as u32,
frame_size: frame_size as u32,
frame_nr: frame_count as u32,
};

let req = PacketReqAny::V2(req);
set_packet_rx_ring(&fd, &req)?;
set_packet_tx_ring(&fd, &req)?;

let addr = SocketAddrLink::new(eth::ALL, index);
bind_link(&fd, &addr)?;

let rx = unsafe {
mmap(
ptr::null_mut(),
block_size * block_count * 2,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::SHARED,
&fd,
0,
)
}?;
let tx = unsafe { rx.add(block_size * block_count) };

Ok(Self {
fd,
block_size,
block_count,
frame_size,
frame_count,
rx: Cell::new(rx),
tx: Cell::new(tx),
})
}

/// Returns a reader object for receiving packets.
pub fn reader(&self) -> Reader<'_> {
assert!(!self.rx.get().is_null());
Reader {
socket: self,
// Take ring pointer.
ring: self.rx.replace(ptr::null_mut()),
}
}

/// Returns a writer object for transmitting packets.
pub fn writer(&self) -> Writer<'_> {
assert!(!self.tx.get().is_null());
Writer {
socket: self,
// Take ring pointer.
ring: self.tx.replace(ptr::null_mut()),
}
}

/// Flushes the transmit buffer.
pub fn flush(&self) -> io::Result<()> {
send(&self.fd, &[], SendFlags::empty())?;
Ok(())
}
}

impl Drop for Socket {
fn drop(&mut self) {
debug_assert!(!self.rx.get().is_null());
debug_assert!(!self.tx.get().is_null());
unsafe {
let _ = munmap(self.rx.get(), self.block_size * self.block_count * 2);
}
}
}

/// TODO
#[derive(Debug)]
pub struct Packet<'r> {
header: &'r mut PacketHeader2,
}

impl<'r> Packet<'r> {
pub fn payload(&self) -> &[u8] {
let ptr = self.header.payload_rx();
let len = self.header.len as usize;
unsafe { slice::from_raw_parts(ptr, len) }
}
}

impl<'r> Drop for Packet<'r> {
fn drop(&mut self) {
self.header.status = PacketStatus::empty();
}
}

/// TODO
#[derive(Debug)]
pub struct Slot<'w> {
header: &'w mut PacketHeader2,
}

impl<'w> Slot<'w> {
pub fn write(&mut self, payload: &[u8]) {
let ptr = self.header.payload_tx();
// TODO verify length
let len = payload.len();
unsafe {
ptr.copy_from_nonoverlapping(payload.as_ptr(), len);
self.header.len = len as u32;
}
}
}

impl<'w> Drop for Slot<'w> {
fn drop(&mut self) {
self.header.status = PacketStatus::SEND_REQUEST;
}
}

/// A reader object for receiving packets.
#[derive(Debug)]
pub struct Reader<'s> {
socket: &'s Socket,
ring: *mut c_void, // Owned
}

impl<'s> Reader<'s> {
/// Returns an iterator over received packets.
/// The iterator blocks until at least one packet is received.
///
/// # Lifetimes
///
/// - `'s`: The lifetime of the socket.
/// - `'r`: The lifetime of the received packets.
pub fn wait<'r>(&'r mut self) -> io::Result<ReadIter<'s, 'r>>
where
's: 'r,
{
let flags = PollFlags::IN | PollFlags::RDNORM | PollFlags::ERR;
let pfd = PollFd::new(&self.socket.fd, flags);
let pfd = &mut [pfd];
let n = poll(pfd, -1)?;
assert_eq!(n, 1);
Ok(ReadIter {
reader: self,
index: 0,
})
}
}

impl<'s> Drop for Reader<'s> {
fn drop(&mut self) {
// Give back ring pointer.
self.socket.rx.set(self.ring);
}
}

/// A writer object for transmitting packets.
#[derive(Debug)]
pub struct Writer<'s> {
socket: &'s Socket,
ring: *mut c_void, // Owned
}

impl<'s> Writer<'s> {
/// Returns an iterator over available slots for transmitting packets.
/// The iterator blocks until at least one slot is available.
///
/// # Lifetimes
///
/// - `'s`: The lifetime of the socket.
/// - `'w`: The lifetime of the slots.
pub fn wait<'w>(&'w mut self) -> io::Result<WriteIter<'s, 'w>>
where
's: 'w,
{
let flags = PollFlags::OUT | PollFlags::WRNORM | PollFlags::ERR;
let pfd = PollFd::new(&self.socket.fd, flags);
let pfd = &mut [pfd];
let n = poll(pfd, -1)?;
assert_eq!(n, 1);
Ok(WriteIter {
writer: self,
index: 0,
})
}
}

impl<'s> Drop for Writer<'s> {
fn drop(&mut self) {
// Give back ring pointer.
self.socket.tx.set(self.ring);
}
}

/// An iterator over received packets.
#[derive(Debug)]
pub struct ReadIter<'s, 'r> {
reader: &'r mut Reader<'s>,
index: usize,
}

impl<'s, 'r> Iterator for ReadIter<'s, 'r> {
type Item = Packet<'r>;

fn next(&mut self) -> Option<Self::Item> {
while self.index < self.reader.socket.frame_count {
let base = unsafe {
self.reader
.ring
.add(self.index * self.reader.socket.frame_size)
};
self.index += 1;

if let Some(header) = unsafe { PacketHeader2::from_rx_ptr(base) } {
return Some(Packet { header });
}
}
None
}
}

/// An iterator over available slots for transmitting packets.
#[derive(Debug)]
pub struct WriteIter<'s, 'w> {
writer: &'w mut Writer<'s>,
index: usize,
}

impl<'s, 'w> Iterator for WriteIter<'s, 'w> {
type Item = Slot<'w>;

fn next(&mut self) -> Option<Self::Item> {
while self.index < self.writer.socket.frame_count {
let base = unsafe {
self.writer
.ring
.add(self.index * self.writer.socket.frame_size)
};
self.index += 1;

if let Some(header) = unsafe { PacketHeader2::from_tx_ptr(base) } {
return Some(Slot { header });
}
}
None
}
}

// ECHO server
fn server(socket: Socket, mut count: usize) -> io::Result<()> {
let mut reader = socket.reader();
let mut writer = socket.writer();

while count > 0 {
let mut queue = VecDeque::new();

for packet in reader.wait()? {
queue.push_back(packet);
}

while let Some(packet) = queue.pop_front() {
let mut iter = writer.wait()?.take(count);
while let Some(mut slot) = iter.next() {
let mut payload = packet.payload().to_vec();
assert_eq!(payload[12..14], [0x08, 0x00]);
payload.swap(14, 15);

slot.write(&payload);
drop(slot);
count -= 1;
}
drop(packet);
}

socket.flush()?;
}

Ok(())
}

// ECHO client
fn client(socket: Socket, mut count: usize) -> io::Result<()> {
let mut reader = socket.reader();
let mut writer = socket.writer();

while count > 0 {
let mut iter = writer.wait()?.take(count);
while let Some(mut slot) = iter.next() {
let payload = &[
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Destination
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Source
0x08, 0x00, // Type (IPv4, but not really)
0x13, 0x37, // Payload (some value)
];

slot.write(payload);
drop(slot);
count -= 1;
}

socket.flush()?;

for packet in reader.wait()? {
assert_eq!(packet.payload()[14..16], [0x37, 0x13]);
}
}

Ok(())
}

pub fn main() -> io::Result<()> {
let mut args = env::args().skip(1);
let name = args.next().expect("name");
let mode = args.next().expect("mode");
let count = args.next().expect("count");

let socket = Socket::new(&name, 4096, 4, 2048)?;
let count = count.parse().unwrap();

match mode.as_str() {
"server" => server(socket, count),
"client" => client(socket, count),
_ => panic!("invalid mode"),
}
}

0 comments on commit d1ddfc8

Please sign in to comment.