Skip to content

Commit

Permalink
Add support for mbind, get_mempolicy and set_mempolicy (bytecod…
Browse files Browse the repository at this point in the history
…ealliance#937)

This adds support for the `mbind`, `set_mempolicy` and `get_mempolicy`
NUMA syscalls.  The `get_mempolicy` syscall has a few different modes
of operation, depending on the flags, which are demultiplexed into
`get_mempolicy_node` and `get_mempolicy_next_node` for now.  There are a
couple of other modes that write into the variable-length bit array,
which aren't implemented for now.
  • Loading branch information
krh committed Nov 30, 2023
1 parent 3056dec commit 4c5f6ce
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ once_cell = { version = "1.5.2", optional = true }
# libc backend can be selected via adding `--cfg=rustix_use_libc` to
# `RUSTFLAGS` or enabling the `use-libc` cargo feature.
[target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies]
linux-raw-sys = { version = "0.4.11", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
linux-raw-sys = { version = "0.6.2", default-features = false, features = ["general", "errno", "ioctl", "mempolicy", "no_std", "elf"] }
libc_errno = { package = "errno", version = "0.3.8", default-features = false, optional = true }
libc = { version = "0.2.150", default-features = false, features = ["extra_traits"], optional = true }

Expand Down Expand Up @@ -170,6 +170,9 @@ termios = []
# Enable `rustix::mm::*`.
mm = []

# Enable `rustix::numa::*`.
numa = []

# Enable `rustix::pipe::*`.
pipe = []

Expand All @@ -194,6 +197,7 @@ all-apis = [
"mm",
"mount",
"net",
"numa",
"param",
"pipe",
"process",
Expand Down
16 changes: 16 additions & 0 deletions src/backend/linux_raw/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,22 @@ impl<'a, Num: ArgNumber> From<Option<crate::net::Protocol>> for ArgReg<'a, Num>
}
}

#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::Mode> for ArgReg<'a, Num> {
    /// Converts a NUMA policy `Mode` into a syscall argument register by
    /// passing its raw bits as a `c_uint`.
    #[inline]
    fn from(mode: crate::numa::Mode) -> Self {
        let raw_bits = mode.bits();
        c_uint(raw_bits)
    }
}

#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::ModeFlags> for ArgReg<'a, Num> {
    /// Converts `ModeFlags` into a syscall argument register by passing
    /// its raw bits as a `c_uint`.
    #[inline]
    fn from(mode_flags: crate::numa::ModeFlags) -> Self {
        let raw_bits = mode_flags.bits();
        c_uint(raw_bits)
    }
}

impl<'a, Num: ArgNumber, T> From<&'a mut MaybeUninit<T>> for ArgReg<'a, Num> {
#[inline]
fn from(t: &'a mut MaybeUninit<T>) -> Self {
Expand Down
2 changes: 2 additions & 0 deletions src/backend/linux_raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub(crate) mod mount;
pub(crate) mod mount; // for deprecated mount functions in "fs"
#[cfg(feature = "net")]
pub(crate) mod net;
#[cfg(feature = "numa")]
pub(crate) mod numa;
#[cfg(any(
feature = "param",
feature = "process",
Expand Down
2 changes: 2 additions & 0 deletions src/backend/linux_raw/numa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod syscalls;
pub(crate) mod types;
92 changes: 92 additions & 0 deletions src/backend/linux_raw/numa/syscalls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//! linux_raw syscalls supporting `rustix::numa`.
//!
//! # Safety
//!
//! See the `rustix::backend` module documentation for details.

#![allow(unsafe_code)]
#![allow(clippy::undocumented_unsafe_blocks)]

use super::types::{Mode, ModeFlags};

use crate::backend::c;
use crate::backend::conv::{c_uint, pass_usize, ret, zero};
use crate::io;
use core::mem::MaybeUninit;

/// Issues the raw `mbind` syscall for the range starting at `addr` and
/// spanning `length` bytes.
///
/// The kernel receives a pointer to `nodemask` together with the mask's
/// size expressed in bits (`nodemask.len() * 64`).
///
/// # Safety
///
/// `mbind` is primarily unsafe due to the `addr` parameter, as anything
/// working with memory pointed to by raw pointers is unsafe.
#[inline]
pub(crate) unsafe fn mbind(
    addr: *mut c::c_void,
    length: usize,
    mode: Mode,
    nodemask: &[u64],
    flags: ModeFlags,
) -> io::Result<()> {
    // The mask size is communicated as a bit count, not a word count.
    // NOTE(review): the kernel's treatment of this count is reportedly
    // off by one (libnuma passes size + 1) — confirm the top bit of the
    // mask is honored.
    let mask_bits = nodemask.len() * u64::BITS as usize;
    let mask_ptr = nodemask.as_ptr();

    ret(syscall!(
        __NR_mbind,
        addr,
        pass_usize(length),
        mode,
        mask_ptr,
        pass_usize(mask_bits),
        flags
    ))
}

/// Issues the raw `set_mempolicy` syscall with `mode` and `nodemask`.
///
/// The kernel receives a pointer to `nodemask` together with the mask's
/// size expressed in bits (`nodemask.len() * 64`).
///
/// # Safety
///
/// Unlike `mbind`, this function takes no raw address; the only pointer
/// handed to the kernel is derived from the borrowed `nodemask` slice.
/// NOTE(review): nothing visible here requires `unsafe`; it is kept so
/// the signature stays unchanged for existing callers.
#[inline]
pub(crate) unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
    // The mask size is communicated as a bit count, not a word count.
    let mask_bits = nodemask.len() * u64::BITS as usize;

    ret(syscall!(
        __NR_set_mempolicy,
        mode,
        nodemask.as_ptr(),
        pass_usize(mask_bits)
    ))
}

/// Issues `get_mempolicy` with `MPOL_F_NODE | MPOL_F_ADDR` for `addr` and
/// returns the value the kernel stores through the first argument.
///
/// # Safety
///
/// `get_mempolicy` is primarily unsafe due to the `addr` parameter,
/// as anything working with memory pointed to by raw pointers is
/// unsafe.
#[inline]
pub(crate) unsafe fn get_mempolicy_node(addr: *mut c::c_void) -> io::Result<usize> {
    // The kernel's first argument is a C `int *` (`i32` on Linux), so only
    // four bytes are written. A `usize` slot would leave its upper bytes
    // uninitialized on 64-bit targets and yield a garbage node ID.
    let mut node = MaybeUninit::<i32>::uninit();

    ret(syscall!(
        __NR_get_mempolicy,
        &mut node,
        zero(),
        zero(),
        addr,
        c_uint(linux_raw_sys::mempolicy::MPOL_F_NODE | linux_raw_sys::mempolicy::MPOL_F_ADDR)
    ))?;

    // Reached only on success, when the kernel has filled in `node`. Node
    // IDs are non-negative, so widening to `usize` is lossless.
    Ok(node.assume_init() as usize)
}

/// Issues `get_mempolicy` with only `MPOL_F_NODE` set (no address) and
/// returns the value the kernel stores through the first argument.
#[inline]
pub(crate) fn get_mempolicy_next_node() -> io::Result<usize> {
    // The kernel's first argument is a C `int *` (`i32` on Linux), so only
    // four bytes are written. A `usize` slot would leave its upper bytes
    // uninitialized on 64-bit targets and yield a garbage node ID.
    let mut node = MaybeUninit::<i32>::uninit();

    // SAFETY: the only pointer passed to the kernel is `&mut node`, a live
    // local large enough for the `int` the kernel writes. It is read back
    // only after `ret` has reported success.
    unsafe {
        ret(syscall!(
            __NR_get_mempolicy,
            &mut node,
            zero(),
            zero(),
            zero(),
            c_uint(linux_raw_sys::mempolicy::MPOL_F_NODE)
        ))?;

        // Node IDs are non-negative, so widening to `usize` is lossless.
        Ok(node.assume_init() as usize)
    }
}
52 changes: 52 additions & 0 deletions src/backend/linux_raw/numa/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use bitflags::bitflags;

bitflags! {
/// `MPOL_*` and `MPOL_F_*` flags for use with [`mbind`].
///
/// This single type carries both the policy mode constants (`MPOL_DEFAULT`,
/// `MPOL_BIND`, …) and the optional `MPOL_F_*` mode flags, so that they can
/// be OR-ed together and passed as one syscall argument.
///
/// NOTE(review): the `MPOL_*` mode constants are presumably small
/// enumerated values rather than independent bits (e.g. `MPOL_DEFAULT` is
/// likely zero), so `contains` and `|` between two modes may not behave
/// like true flags — confirm against the kernel headers.
///
/// [`mbind`]: crate::numa::mbind
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Mode: u32 {
/// `MPOL_F_STATIC_NODES`
const STATIC_NODES = linux_raw_sys::mempolicy::MPOL_F_STATIC_NODES;
/// `MPOL_F_RELATIVE_NODES`
const RELATIVE_NODES = linux_raw_sys::mempolicy::MPOL_F_RELATIVE_NODES;
/// `MPOL_F_NUMA_BALANCING`
const NUMA_BALANCING = linux_raw_sys::mempolicy::MPOL_F_NUMA_BALANCING;

/// `MPOL_DEFAULT`
const DEFAULT = linux_raw_sys::mempolicy::MPOL_DEFAULT as u32;
/// `MPOL_PREFERRED`
const PREFERRED = linux_raw_sys::mempolicy::MPOL_PREFERRED as u32;
/// `MPOL_BIND`
const BIND = linux_raw_sys::mempolicy::MPOL_BIND as u32;
/// `MPOL_INTERLEAVE`
const INTERLEAVE = linux_raw_sys::mempolicy::MPOL_INTERLEAVE as u32;
/// `MPOL_LOCAL`
const LOCAL = linux_raw_sys::mempolicy::MPOL_LOCAL as u32;
/// `MPOL_PREFERRED_MANY`
const PREFERRED_MANY = linux_raw_sys::mempolicy::MPOL_PREFERRED_MANY as u32;

/// Accept any bit pattern, including bits not named above.
/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}

bitflags! {
/// `MPOL_MF_*` flags for use with [`mbind`].
///
/// These are passed as the final `flags` argument of [`mbind`].
///
/// [`mbind`]: crate::numa::mbind
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ModeFlags: u32 {
/// `MPOL_MF_STRICT`
const STRICT = linux_raw_sys::mempolicy::MPOL_MF_STRICT;
/// `MPOL_MF_MOVE`
const MOVE = linux_raw_sys::mempolicy::MPOL_MF_MOVE;
/// `MPOL_MF_MOVE_ALL`
const MOVE_ALL = linux_raw_sys::mempolicy::MPOL_MF_MOVE_ALL;

/// Accept any bit pattern, including bits not named above.
/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ pub mod mount;
#[cfg(feature = "net")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "net")))]
pub mod net;
#[cfg(linux_kernel)]
#[cfg(feature = "numa")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "numa")))]
pub mod numa;
#[cfg(not(any(windows, target_os = "espidf")))]
#[cfg(feature = "param")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "param")))]
Expand Down
108 changes: 108 additions & 0 deletions src/numa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
//! The `numa` API.
//!
//! # Safety
//!
//! `mbind` and related functions manipulate raw pointers and have special
//! semantics and are wildly unsafe.
#![allow(unsafe_code)]

use crate::{backend, io};
use core::ffi::c_void;

pub use backend::numa::types::{Mode, ModeFlags};

/// `mbind(addr, len, mode, nodemask, flags)`—Set memory policy for a memory
/// range.
///
/// `nodemask` is handed to the kernel as a bit array; its length in bits is
/// `nodemask.len() * 64`.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/mbind.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn mbind(
addr: *mut c_void,
len: usize,
mode: Mode,
nodemask: &[u64],
flags: ModeFlags,
) -> io::Result<()> {
backend::numa::syscalls::mbind(addr, len, mode, nodemask, flags)
}

/// `set_mempolicy(mode, nodemask)`—Set default NUMA memory policy for
/// a thread and its children.
///
/// `nodemask` is handed to the kernel as a bit array; its length in bits is
/// `nodemask.len() * 64`.
///
/// # Safety
///
/// This function takes no raw pointers; the node mask is a borrowed slice.
/// NOTE(review): it is unclear what invariant the caller must uphold here;
/// the function is declared `unsafe` alongside [`mbind`] — confirm whether
/// that is still required.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
backend::numa::syscalls::set_mempolicy(mode, nodemask)
}

/// `get_mempolicy_node(addr)`—Return the node ID of the node on which
/// the address `addr` is allocated.
///
/// This is the `get_mempolicy` mode selected by passing both `MPOL_F_NODE`
/// and `MPOL_F_ADDR`: the kernel returns the node ID of the node on which
/// the address `addr` is allocated.
/// If no page has yet been allocated for the specified address,
/// `get_mempolicy` will allocate a page as if the thread had
/// performed a read (load) access to that address, and return the ID
/// of the node where that page was allocated.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_node(addr: *mut c_void) -> io::Result<usize> {
backend::numa::syscalls::get_mempolicy_node(addr)
}

/// `get_mempolicy_next_node()`—Return node ID of the next node
/// that will be used for interleaving of internal kernel pages
/// allocated on behalf of the thread.
///
/// This is the `get_mempolicy` mode selected by passing `MPOL_F_NODE`
/// without `MPOL_F_ADDR`. If the thread's current policy is
/// `MPOL_INTERLEAVE`, the kernel returns the node ID of the next node
/// that will be used for interleaving of internal kernel pages allocated
/// on behalf of the thread. These allocations include pages for
/// memory-mapped files in process memory ranges mapped using the
/// `mmap(2)` call with the `MAP_PRIVATE` flag for read accesses, and in
/// memory ranges mapped with the `MAP_SHARED` flag for all accesses.
///
/// # Safety
///
/// This function takes no arguments and no raw pointers.
/// NOTE(review): it is unclear what invariant the caller must uphold here;
/// the backend function it calls is not `unsafe` — confirm whether this
/// wrapper needs to be.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_next_node() -> io::Result<usize> {
backend::numa::syscalls::get_mempolicy_next_node()
}
40 changes: 40 additions & 0 deletions tests/numa/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/// Exercises `mbind`, `get_mempolicy_node`, `get_mempolicy_next_node` and
/// `set_mempolicy` against a freshly mapped anonymous region.
///
/// Gated on the features it actually uses: `mm` for the mapping calls and
/// `numa` for the functions under test.
#[cfg(all(feature = "mm", feature = "numa"))]
#[test]
fn test_mbind() {
    let size = 8192;
    // A one-word node mask with only bit 0 set.
    let mask: &[u64] = &[1];

    unsafe {
        let vaddr = rustix::mm::mmap_anonymous(
            std::ptr::null_mut(),
            size,
            rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE,
            rustix::mm::MapFlags::PRIVATE,
        )
        .unwrap();

        // Write to the mapping before binding and querying it.
        vaddr.cast::<usize>().write(100);

        rustix::numa::mbind(
            vaddr,
            size,
            rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
            mask,
            rustix::numa::ModeFlags::empty(),
        )
        .unwrap();

        rustix::numa::get_mempolicy_node(vaddr).unwrap();

        // Before an interleave policy is installed, asking for the next
        // interleave node is expected to fail with `EINVAL`.
        assert!(
            matches!(
                rustix::numa::get_mempolicy_next_node(),
                Err(rustix::io::Errno::INVAL)
            ),
            "rustix::numa::get_mempolicy_next_node() should return EINVAL for MPOL_DEFAULT"
        );

        rustix::numa::set_mempolicy(rustix::numa::Mode::INTERLEAVE, mask).unwrap();

        rustix::numa::get_mempolicy_next_node().unwrap();

        // Release the mapping so the test does not leak it.
        rustix::mm::munmap(vaddr, size).unwrap();
    }
}

0 comments on commit 4c5f6ce

Please sign in to comment.