Skip to content

Commit

Permalink
feat: Use type parameters to allow {get,set}regset to use different…
Browse files Browse the repository at this point in the history
… register set structs (#2373)
  • Loading branch information
hack3ric committed Apr 25, 2024
1 parent 395906e commit 213127b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 38 deletions.
111 changes: 85 additions & 26 deletions src/sys/ptrace/linux.rs
Expand Up @@ -172,21 +172,21 @@ libc_enum! {
}
}

#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
libc_enum! {
#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
#[repr(i32)]
/// Defining a specific register set, as used in [`getregset`] and [`setregset`].
/// Defines a specific register set, as used in `PTRACE_GETREGSET` and `PTRACE_SETREGSET`.
#[non_exhaustive]
pub enum RegisterSet {
pub enum RegisterSetValue {
NT_PRSTATUS,
NT_PRFPREG,
NT_PRPSINFO,
Expand All @@ -195,6 +195,69 @@ libc_enum! {
}
}

#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
/// Represents register set areas, such as general-purpose registers or
/// floating-point registers.
///
/// # Safety
///
/// This trait is marked unsafe, since implementation of the trait must match
/// ptrace's request `VALUE` and return data type `Regs`.
pub unsafe trait RegisterSet {
/// Corresponding type of registers in the kernel.
const VALUE: RegisterSetValue;

/// Struct representing the register space.
type Regs;
}

#[cfg(all(
target_os = "linux",
target_env = "gnu",
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "riscv64",
)
))]
/// Register sets used in [`getregset`] and [`setregset`]
pub mod regset {
use super::*;

#[derive(Debug, Clone, Copy)]
/// General-purpose registers.
pub struct NT_PRSTATUS;

unsafe impl RegisterSet for NT_PRSTATUS {
const VALUE: RegisterSetValue = RegisterSetValue::NT_PRSTATUS;
type Regs = user_regs_struct;
}

#[derive(Debug, Clone, Copy)]
/// Floating-point registers.
pub struct NT_PRFPREG;

unsafe impl RegisterSet for NT_PRFPREG {
const VALUE: RegisterSetValue = RegisterSetValue::NT_PRFPREG;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
type Regs = libc::user_fpregs_struct;
#[cfg(target_arch = "aarch64")]
type Regs = libc::user_fpsimd_struct;
#[cfg(target_arch = "riscv64")]
type Regs = libc::__riscv_mc_d_ext_state;
}
}

libc_bitflags! {
/// Ptrace options used in conjunction with the PTRACE_SETOPTIONS request.
/// See `man ptrace` for more details.
Expand Down Expand Up @@ -275,7 +338,7 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
any(target_arch = "aarch64", target_arch = "riscv64")
))]
pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
getregset(pid, RegisterSet::NT_PRSTATUS)
getregset::<regset::NT_PRSTATUS>(pid)
}

/// Get a particular set of user registers, as with `ptrace(PTRACE_GETREGSET, ...)`
Expand All @@ -289,18 +352,18 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
target_arch = "riscv64",
)
))]
pub fn getregset(pid: Pid, set: RegisterSet) -> Result<user_regs_struct> {
pub fn getregset<S: RegisterSet>(pid: Pid) -> Result<S::Regs> {
let request = Request::PTRACE_GETREGSET;
let mut data = mem::MaybeUninit::<user_regs_struct>::uninit();
let mut data = mem::MaybeUninit::<S::Regs>::uninit();
let mut iov = libc::iovec {
iov_base: data.as_mut_ptr().cast(),
iov_len: mem::size_of::<user_regs_struct>(),
iov_len: mem::size_of::<S::Regs>(),
};
unsafe {
ptrace_other(
request,
pid,
set as i32 as AddressType,
S::VALUE as i32 as AddressType,
(&mut iov as *mut libc::iovec).cast(),
)?;
};
Expand Down Expand Up @@ -349,7 +412,7 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
any(target_arch = "aarch64", target_arch = "riscv64")
))]
pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
setregset(pid, RegisterSet::NT_PRSTATUS, regs)
setregset::<regset::NT_PRSTATUS>(pid, regs)
}

/// Set a particular set of user registers, as with `ptrace(PTRACE_SETREGSET, ...)`
Expand All @@ -363,20 +426,16 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
target_arch = "riscv64",
)
))]
pub fn setregset(
pid: Pid,
set: RegisterSet,
mut regs: user_regs_struct,
) -> Result<()> {
pub fn setregset<S: RegisterSet>(pid: Pid, mut regs: S::Regs) -> Result<()> {
let mut iov = libc::iovec {
iov_base: (&mut regs as *mut user_regs_struct).cast(),
iov_len: mem::size_of::<user_regs_struct>(),
iov_base: (&mut regs as *mut S::Regs).cast(),
iov_len: mem::size_of::<S::Regs>(),
};
unsafe {
ptrace_other(
Request::PTRACE_SETREGSET,
pid,
set as i32 as AddressType,
S::VALUE as i32 as AddressType,
(&mut iov as *mut libc::iovec).cast(),
)?;
}
Expand Down
33 changes: 21 additions & 12 deletions test/sys/test_ptrace.rs
Expand Up @@ -302,7 +302,7 @@ fn test_ptrace_syscall() {
))]
#[test]
fn test_ptrace_regsets() {
use nix::sys::ptrace::{self, getregset, setregset, RegisterSet};
use nix::sys::ptrace::{self, getregset, regset, setregset};
use nix::sys::signal::*;
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::fork;
Expand All @@ -328,30 +328,39 @@ fn test_ptrace_regsets() {
Ok(WaitStatus::Stopped(child, Signal::SIGTRAP))
);
let mut regstruct =
getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
getregset::<regset::NT_PRSTATUS>(child).unwrap();
let mut fpregstruct =
getregset::<regset::NT_PRFPREG>(child).unwrap();

#[cfg(target_arch = "x86_64")]
let reg = &mut regstruct.r15;
let (reg, fpreg) =
(&mut regstruct.r15, &mut fpregstruct.st_space[5]);
#[cfg(target_arch = "x86")]
let reg = &mut regstruct.edx;
let (reg, fpreg) =
(&mut regstruct.edx, &mut fpregstruct.st_space[5]);
#[cfg(target_arch = "aarch64")]
let reg = &mut regstruct.regs[16];
let (reg, fpreg) =
(&mut regstruct.regs[16], &mut fpregstruct.vregs[5]);
#[cfg(target_arch = "riscv64")]
let reg = &mut regstruct.regs[16];
let (reg, fpreg) = (&mut regstruct.t1, &mut fpregstruct.__f[5]);

*reg = 0xdeadbeefu32 as _;
let _ = setregset(child, RegisterSet::NT_PRSTATUS, regstruct);
regstruct = getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
*fpreg = 0xfeedfaceu32 as _;
let _ = setregset::<regset::NT_PRSTATUS>(child, regstruct);
regstruct = getregset::<regset::NT_PRSTATUS>(child).unwrap();
let _ = setregset::<regset::NT_PRFPREG>(child, fpregstruct);
fpregstruct = getregset::<regset::NT_PRFPREG>(child).unwrap();

#[cfg(target_arch = "x86_64")]
let reg = regstruct.r15;
let (reg, fpreg) = (regstruct.r15, fpregstruct.st_space[5]);
#[cfg(target_arch = "x86")]
let reg = regstruct.edx;
let (reg, fpreg) = (regstruct.edx, fpregstruct.st_space[5]);
#[cfg(target_arch = "aarch64")]
let reg = regstruct.regs[16];
let (reg, fpreg) = (regstruct.regs[16], fpregstruct.vregs[5]);
#[cfg(target_arch = "riscv64")]
let reg = regstruct.regs[16];
let (reg, fpreg) = (regstruct.t1, fpregstruct.__f[5]);
assert_eq!(reg, 0xdeadbeefu32 as _);
assert_eq!(fpreg, 0xfeedfaceu32 as _);

ptrace::cont(child, Some(Signal::SIGKILL)).unwrap();
match waitpid(child, None) {
Expand Down

0 comments on commit 213127b

Please sign in to comment.