Skip to content

Commit

Permalink
Fix *encoding* of cmsgs and add ScmCredentials.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-schievink committed Jul 6, 2018
1 parent c1ee0d0 commit e06f901
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 45 deletions.
134 changes: 89 additions & 45 deletions src/sys/socket/mod.rs
Expand Up @@ -205,6 +205,18 @@ cfg_if! {
}
impl Eq for UnixCredentials {}

impl From<libc::ucred> for UnixCredentials {
fn from(cred: libc::ucred) -> Self {
UnixCredentials(cred)
}
}

impl Into<libc::ucred> for UnixCredentials {
fn into(self) -> libc::ucred {
self.0
}
}

impl fmt::Debug for UnixCredentials {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("UnixCredentials")
Expand Down Expand Up @@ -359,7 +371,7 @@ impl<T> CmsgSpace<T> {
}
}

#[allow(missing_debug_implementations)]
#[derive(Debug)]
pub struct RecvMsg<'a> {
// The number of bytes received.
pub bytes: usize,
Expand All @@ -374,15 +386,14 @@ impl<'a> RecvMsg<'a> {
pub fn cmsgs(&self) -> CmsgIterator {
CmsgIterator {
buf: self.cmsg_buffer,
next: 0
}
}
}

#[allow(missing_debug_implementations)]
#[derive(Debug)]
pub struct CmsgIterator<'a> {
/// Control message buffer to decode from. Must adhere to cmsg alignment.
buf: &'a [u8],
next: usize,
}

impl<'a> Iterator for CmsgIterator<'a> {
Expand All @@ -392,53 +403,25 @@ impl<'a> Iterator for CmsgIterator<'a> {
// although we handle the invariants in slightly different places to
// get a better iterator interface.
fn next(&mut self) -> Option<ControlMessage<'a>> {
let sizeof_cmsghdr = mem::size_of::<cmsghdr>();
if self.buf.len() < sizeof_cmsghdr {
if self.buf.len() == 0 {
// The iterator assumes that `self.buf` always contains exactly the
// bytes we need, so we're at the end when the buffer is empty.
return None;
}
let cmsg: &'a cmsghdr = unsafe { &*(self.buf.as_ptr() as *const cmsghdr) };

// This check is only in the glibc implementation of CMSG_NXTHDR
// (although it claims the kernel header checks this), but such
// a structure is clearly invalid, either way.
// Safe if: `self.buf` is `cmsghdr`-aligned.
let cmsg: &'a cmsghdr = unsafe { &*(self.buf[..mem::size_of::<cmsghdr>()].as_ptr() as *const cmsghdr) };

let cmsg_len = cmsg.cmsg_len as usize;
if cmsg_len < sizeof_cmsghdr {
return None;
}
let len = cmsg_len - sizeof_cmsghdr;
let aligned_cmsg_len = if self.next == 0 {
// CMSG_FIRSTHDR
cmsg_len
} else {
// CMSG_NXTHDR
cmsg_align(cmsg_len)
};

// Advance our internal pointer.
if aligned_cmsg_len > self.buf.len() {
return None;
}
let cmsg_data = &self.buf[cmsg_align(sizeof_cmsghdr)..cmsg_len];
self.buf = &self.buf[aligned_cmsg_len..];
self.next += 1;

match (cmsg.cmsg_level, cmsg.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => unsafe {
Some(ControlMessage::ScmRights(
slice::from_raw_parts(cmsg_data.as_ptr() as *const _,
cmsg_data.len() / mem::size_of::<RawFd>())))
},
(libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => unsafe {
Some(ControlMessage::ScmTimestamp(
&*(cmsg_data.as_ptr() as *const _)))
},
(_, _) => unsafe {
Some(ControlMessage::Unknown(UnknownCmsg(
cmsg,
slice::from_raw_parts(
cmsg_data.as_ptr() as *const _,
len))))
}
let cmsg_data = &self.buf[cmsg_align(mem::size_of::<cmsghdr>())..cmsg_len];
self.buf = &self.buf[cmsg_align(cmsg_len)..];

// Safe if: `cmsg_data` contains the expected (amount of) content. This
// is verified by the kernel.
unsafe {
Some(ControlMessage::decode_from(cmsg, cmsg_data))
}
}
}
Expand All @@ -459,6 +442,19 @@ pub enum ControlMessage<'a> {
/// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights`
/// message.
ScmRights(&'a [RawFd]),
/// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of
/// a process connected to the socket.
///
/// This is similar to the socket option `SO_PEERCRED`, but requires a
/// process to explicitly send its credentials. A process running as root is
/// allowed to specify any credentials, while credentials sent by other
/// processes are verified by the kernel.
///
/// For further information, please refer to the
/// [`unix(7)`](http://man7.org/linux/man-pages/man7/unix.7.html) man page.
// FIXME: When `#[repr(transparent)]` is stable, use it on `UnixCredentials`
// and put that in here instead of a raw ucred.
ScmCredentials(&'a libc::ucred),
/// A message of type `SCM_TIMESTAMP`, containing the time the
/// packet was received by the kernel.
///
Expand Down Expand Up @@ -558,6 +554,9 @@ impl<'a> ControlMessage<'a> {
ControlMessage::ScmRights(fds) => {
mem::size_of_val(fds)
},
ControlMessage::ScmCredentials(creds) => {
mem::size_of_val(creds)
}
ControlMessage::ScmTimestamp(t) => {
mem::size_of_val(t)
},
Expand Down Expand Up @@ -589,6 +588,24 @@ impl<'a> ControlMessage<'a> {
let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::ScmCredentials(creds) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_CREDENTIALS,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(creds, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
}
ControlMessage::ScmTimestamp(t) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
Expand Down Expand Up @@ -621,6 +638,33 @@ impl<'a> ControlMessage<'a> {
}
}
}

/// Decodes a `ControlMessage` from raw bytes.
///
/// This is only safe to call if the data is correct for the message type
/// specified in the header. Normally, the kernel ensures that this is the
/// case.
unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> {
match (header.cmsg_level, header.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => {
ControlMessage::ScmRights(
slice::from_raw_parts(data.as_ptr() as *const _,
data.len() / mem::size_of::<RawFd>()))
},
(libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => {
ControlMessage::ScmCredentials(
&*(data.as_ptr() as *const _)
)
}
(libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => {
ControlMessage::ScmTimestamp(
&*(data.as_ptr() as *const _))
},
(_, _) => {
ControlMessage::Unknown(UnknownCmsg(header, data))
}
}
}
}


Expand Down
1 change: 1 addition & 0 deletions src/sys/socket/sockopt.rs
Expand Up @@ -255,6 +255,7 @@ sockopt_impl!(Both, BindAny, libc::SOL_SOCKET, libc::SO_BINDANY, bool);
sockopt_impl!(Both, BindAny, libc::IPPROTO_IP, libc::IP_BINDANY, bool);
#[cfg(target_os = "linux")]
sockopt_impl!(Both, Mark, libc::SOL_SOCKET, libc::SO_MARK, u32);
sockopt_impl!(Both, PassCred, libc::SOL_SOCKET, libc::SO_PASSCRED, bool);

/*
*
Expand Down
15 changes: 15 additions & 0 deletions src/unistd.rs
Expand Up @@ -48,6 +48,11 @@ impl Uid {
pub fn is_root(&self) -> bool {
*self == ROOT
}

/// Get the raw `uid_t` wrapped by `self`.
pub fn as_raw(&self) -> uid_t {
self.0
}
}

impl From<Uid> for uid_t {
Expand Down Expand Up @@ -87,6 +92,11 @@ impl Gid {
pub fn effective() -> Self {
getegid()
}

/// Get the raw `gid_t` wrapped by `self`.
pub fn as_raw(&self) -> gid_t {
self.0
}
}

impl From<Gid> for gid_t {
Expand Down Expand Up @@ -123,6 +133,11 @@ impl Pid {
pub fn parent() -> Self {
getppid()
}

/// Get the raw `pid_t` wrapped by `self`.
pub fn as_raw(&self) -> pid_t {
self.0
}
}

impl From<Pid> for pid_t {
Expand Down
127 changes: 127 additions & 0 deletions test/sys/test_socket.rs
Expand Up @@ -247,6 +247,133 @@ pub fn test_sendmsg_empty_cmsgs() {
}
}

#[test]
fn test_scm_credentials() {
use libc;
use nix::sys::uio::IoVec;
use nix::unistd::{close, getpid, getuid, getgid};
use nix::sys::socket::{socketpair, sendmsg, recvmsg, setsockopt,
AddressFamily, SockType, SockFlag,
ControlMessage, CmsgSpace, MsgFlags};
use nix::sys::socket::sockopt::PassCred;

let (send, recv) = socketpair(AddressFamily::Unix, SockType::Stream, None, SockFlag::empty())
.unwrap();
setsockopt(recv, PassCred, &true).unwrap();

{
let iov = [IoVec::from_slice(b"hello")];
let cred = libc::ucred {
pid: getpid().as_raw(),
uid: getuid().as_raw(),
gid: getgid().as_raw(),
};
let cmsg = ControlMessage::ScmCredentials(&cred);
assert_eq!(sendmsg(send, &iov, &[cmsg], MsgFlags::empty(), None).unwrap(), 5);
close(send).unwrap();
}

{
let mut buf = [0u8; 5];
let iov = [IoVec::from_mut_slice(&mut buf[..])];
let mut cmsgspace: CmsgSpace<libc::ucred> = CmsgSpace::new();
let msg = recvmsg(recv, &iov, Some(&mut cmsgspace), MsgFlags::empty()).unwrap();
let mut received_cred = None;

for cmsg in msg.cmsgs() {
if let ControlMessage::ScmCredentials(cred) = cmsg {
assert!(received_cred.is_none());
assert_eq!(cred.pid, getpid().as_raw());
assert_eq!(cred.uid, getuid().as_raw());
assert_eq!(cred.gid, getgid().as_raw());
received_cred = Some(*cred);
} else {
panic!("unexpected cmsg");
}
}
received_cred.expect("no creds received");
assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
close(recv).unwrap();
}
}

/// Ensure that we can send `SCM_CREDENTIALS` and `SCM_RIGHTS` wit ha single
/// `sendmsg` call.
#[test]
fn test_scm_credentials_and_rights() {
use libc;
use nix::sys::uio::IoVec;
use nix::unistd::{pipe, read, write, close, getpid, getuid, getgid};
use nix::sys::socket::{socketpair, sendmsg, recvmsg, setsockopt,
AddressFamily, SockType, SockFlag,
ControlMessage, CmsgSpace, MsgFlags};
use nix::sys::socket::sockopt::PassCred;

let (send, recv) = socketpair(AddressFamily::Unix, SockType::Stream, None, SockFlag::empty())
.unwrap();
setsockopt(recv, PassCred, &true).unwrap();

let (r, w) = pipe().unwrap();
let mut received_r: Option<RawFd> = None;

{
let iov = [IoVec::from_slice(b"hello")];
let cred = libc::ucred {
pid: getpid().as_raw(),
uid: getuid().as_raw(),
gid: getgid().as_raw(),
};
let fds = [r];
let cmsgs = [
ControlMessage::ScmCredentials(&cred),
ControlMessage::ScmRights(&fds),
];
assert_eq!(sendmsg(send, &iov, &cmsgs, MsgFlags::empty(), None).unwrap(), 5);
close(r).unwrap();
close(send).unwrap();
}

{
let mut buf = [0u8; 5];
let iov = [IoVec::from_mut_slice(&mut buf[..])];
let mut cmsgspace: CmsgSpace<(libc::ucred, CmsgSpace<RawFd>)> = CmsgSpace::new();
let msg = recvmsg(recv, &iov, Some(&mut cmsgspace), MsgFlags::empty()).unwrap();
let mut received_cred = None;

assert_eq!(msg.cmsgs().count(), 2);

for cmsg in msg.cmsgs() {
match cmsg {
ControlMessage::ScmRights(fds) => {
assert_eq!(received_r, None);
assert_eq!(fds.len(), 1);
received_r = Some(fds[0]);
}
ControlMessage::ScmCredentials(cred) => {
assert!(received_cred.is_none());
assert_eq!(cred.pid, getpid().as_raw());
assert_eq!(cred.uid, getuid().as_raw());
assert_eq!(cred.gid, getgid().as_raw());
received_cred = Some(*cred);
}
_ => panic!("unexpected cmsg"),
}
}
received_cred.expect("no creds received");
assert!(!msg.flags.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
close(recv).unwrap();
}

let received_r = received_r.expect("Did not receive passed fd");
// Ensure that the received file descriptor works
write(w, b"world").unwrap();
let mut buf = [0u8; 5];
read(received_r, &mut buf).unwrap();
assert_eq!(&buf[..], b"world");
close(received_r).unwrap();
close(w).unwrap();
}

// Test creating and using named unix domain sockets
#[test]
pub fn test_unixdomain() {
Expand Down

0 comments on commit e06f901

Please sign in to comment.