Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix control message *decoding* and add support for ScmCredentials #923

Merged
merged 1 commit into from Aug 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
([#922](https://github.com/nix-rust/nix/pull/922))
- Support the `SO_PEERCRED` socket option and the `UnixCredentials` type on all Linux and Android targets.
([#921](https://github.com/nix-rust/nix/pull/921))
- Added support for `SCM_CREDENTIALS`, allowing to send process credentials over Unix sockets.
([#923](https://github.com/nix-rust/nix/pull/923))

### Changed

Expand Down
214 changes: 124 additions & 90 deletions src/sys/socket/mod.rs
Expand Up @@ -205,6 +205,18 @@ cfg_if! {
}
impl Eq for UnixCredentials {}

impl From<libc::ucred> for UnixCredentials {
fn from(cred: libc::ucred) -> Self {
UnixCredentials(cred)
}
}

impl Into<libc::ucred> for UnixCredentials {
fn into(self) -> libc::ucred {
self.0
}
}

impl fmt::Debug for UnixCredentials {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("UnixCredentials")
Expand Down Expand Up @@ -359,7 +371,7 @@ impl<T> CmsgSpace<T> {
}
}

#[allow(missing_debug_implementations)]
#[derive(Debug)]
pub struct RecvMsg<'a> {
// The number of bytes received.
pub bytes: usize,
Expand All @@ -374,15 +386,14 @@ impl<'a> RecvMsg<'a> {
pub fn cmsgs(&self) -> CmsgIterator {
CmsgIterator {
buf: self.cmsg_buffer,
next: 0
}
}
}

#[allow(missing_debug_implementations)]
#[derive(Debug)]
pub struct CmsgIterator<'a> {
/// Control message buffer to decode from. Must adhere to cmsg alignment.
buf: &'a [u8],
next: usize,
}

impl<'a> Iterator for CmsgIterator<'a> {
Expand All @@ -392,53 +403,27 @@ impl<'a> Iterator for CmsgIterator<'a> {
// although we handle the invariants in slightly different places to
// get a better iterator interface.
fn next(&mut self) -> Option<ControlMessage<'a>> {
let sizeof_cmsghdr = mem::size_of::<cmsghdr>();
if self.buf.len() < sizeof_cmsghdr {
if self.buf.len() == 0 {
// The iterator assumes that `self.buf` always contains exactly the
// bytes we need, so we're at the end when the buffer is empty.
return None;
}
let cmsg: &'a cmsghdr = unsafe { &*(self.buf.as_ptr() as *const cmsghdr) };

// This check is only in the glibc implementation of CMSG_NXTHDR
// (although it claims the kernel header checks this), but such
// a structure is clearly invalid, either way.
let cmsg_len = cmsg.cmsg_len as usize;
if cmsg_len < sizeof_cmsghdr {
return None;
}
let len = cmsg_len - sizeof_cmsghdr;
let aligned_cmsg_len = if self.next == 0 {
// CMSG_FIRSTHDR
cmsg_len
} else {
// CMSG_NXTHDR
cmsg_align(cmsg_len)
// Safe if: `self.buf` is `cmsghdr`-aligned.
let cmsg: &'a cmsghdr = unsafe {
&*(self.buf[..mem::size_of::<cmsghdr>()].as_ptr() as *const cmsghdr)
};

let cmsg_len = cmsg.cmsg_len as usize;

// Advance our internal pointer.
if aligned_cmsg_len > self.buf.len() {
return None;
}
let cmsg_data = &self.buf[cmsg_align(sizeof_cmsghdr)..cmsg_len];
self.buf = &self.buf[aligned_cmsg_len..];
self.next += 1;

match (cmsg.cmsg_level, cmsg.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => unsafe {
Some(ControlMessage::ScmRights(
slice::from_raw_parts(cmsg_data.as_ptr() as *const _,
cmsg_data.len() / mem::size_of::<RawFd>())))
},
(libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => unsafe {
Some(ControlMessage::ScmTimestamp(
&*(cmsg_data.as_ptr() as *const _)))
},
(_, _) => unsafe {
Some(ControlMessage::Unknown(UnknownCmsg(
cmsg,
slice::from_raw_parts(
cmsg_data.as_ptr() as *const _,
len))))
}
let cmsg_data = &self.buf[cmsg_align(mem::size_of::<cmsghdr>())..cmsg_len];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up above you requird self.buf to be aligned. Here, you seem to be handling the case where it isn't. Or am I misunderstanding?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cmsg_align just rounds up the size of cmsghdr to the right alignment. This is needed to include the right padding to reach the payload.

self.buf = &self.buf[cmsg_align(cmsg_len)..];

// Safe if: `cmsg_data` contains the expected (amount of) content. This
// is verified by the kernel.
unsafe {
Some(ControlMessage::decode_from(cmsg, cmsg_data))
}
}
}
Expand All @@ -459,6 +444,20 @@ pub enum ControlMessage<'a> {
/// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights`
/// message.
ScmRights(&'a [RawFd]),
/// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of
/// a process connected to the socket.
///
/// This is similar to the socket option `SO_PEERCRED`, but requires a
/// process to explicitly send its credentials. A process running as root is
/// allowed to specify any credentials, while credentials sent by other
/// processes are verified by the kernel.
///
/// For further information, please refer to the
/// [`unix(7)`](http://man7.org/linux/man-pages/man7/unix.7.html) man page.
// FIXME: When `#[repr(transparent)]` is stable, use it on `UnixCredentials`
// and put that in here instead of a raw ucred.
#[cfg(any(target_os = "android", target_os = "linux"))]
ScmCredentials(&'a libc::ucred),
/// A message of type `SCM_TIMESTAMP`, containing the time the
/// packet was received by the kernel.
///
Expand Down Expand Up @@ -527,6 +526,7 @@ pub enum ControlMessage<'a> {
/// nix::unistd::close(in_socket).unwrap();
/// ```
ScmTimestamp(&'a TimeVal),
/// Catch-all variant for unimplemented cmsg types.
#[doc(hidden)]
Unknown(UnknownCmsg<'a>),
}
Expand Down Expand Up @@ -558,6 +558,10 @@ impl<'a> ControlMessage<'a> {
ControlMessage::ScmRights(fds) => {
mem::size_of_val(fds)
},
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessage::ScmCredentials(creds) => {
mem::size_of_val(creds)
}
ControlMessage::ScmTimestamp(t) => {
mem::size_of_val(t)
},
Expand All @@ -567,57 +571,87 @@ impl<'a> ControlMessage<'a> {
}
}

/// Returns the value to put into the `cmsg_type` field of the header.
fn cmsg_type(&self) -> libc::c_int {
match *self {
ControlMessage::ScmRights(_) => libc::SCM_RIGHTS,
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS,
ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP,
ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type,
}
}

// Unsafe: start and end of buffer must be cmsg_align'd. Updates
// the provided slice; panics if the buffer is too small.
unsafe fn encode_into(&self, buf: &mut [u8]) {
match *self {
ControlMessage::ScmRights(fds) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_RIGHTS,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(fds, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::ScmTimestamp(t) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_TIMESTAMP,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(t, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => {
let buf = copy_bytes(orig_cmsg, buf);
let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self {
let &UnknownCmsg(orig_cmsg, bytes) = cmsg;

let buf = copy_bytes(orig_cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);
let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(bytes, buf);
copy_bytes(bytes, buf)
} else {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: self.cmsg_type(),
..mem::zeroed() // zero out platform-dependent padding fields
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

match *self {
ControlMessage::ScmRights(fds) => {
copy_bytes(fds, buf)
},
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessage::ScmCredentials(creds) => {
copy_bytes(creds, buf)
}
ControlMessage::ScmTimestamp(t) => {
copy_bytes(t, buf)
},
ControlMessage::Unknown(_) => unreachable!(),
}
};

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
let padlen = self.space() - self.len();
pad_bytes(padlen, final_buf);
}

/// Decodes a `ControlMessage` from raw bytes.
///
/// This is only safe to call if the data is correct for the message type
/// specified in the header. Normally, the kernel ensures that this is the
/// case. "Correct" in this case includes correct length, alignment and
/// actual content.
unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> {
match (header.cmsg_level, header.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => {
ControlMessage::ScmRights(
slice::from_raw_parts(data.as_ptr() as *const _,
data.len() / mem::size_of::<RawFd>()))
},
#[cfg(any(target_os = "android", target_os = "linux"))]
(libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => {
ControlMessage::ScmCredentials(
&*(data.as_ptr() as *const _)
)
}
(libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => {
ControlMessage::ScmTimestamp(
&*(data.as_ptr() as *const _))
},
(_, _) => {
ControlMessage::Unknown(UnknownCmsg(header, data))
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/sys/socket/sockopt.rs
Expand Up @@ -255,6 +255,8 @@ sockopt_impl!(Both, BindAny, libc::SOL_SOCKET, libc::SO_BINDANY, bool);
sockopt_impl!(Both, BindAny, libc::IPPROTO_IP, libc::IP_BINDANY, bool);
#[cfg(target_os = "linux")]
sockopt_impl!(Both, Mark, libc::SOL_SOCKET, libc::SO_MARK, u32);
#[cfg(any(target_os = "android", target_os = "linux"))]
sockopt_impl!(Both, PassCred, libc::SOL_SOCKET, libc::SO_PASSCRED, bool);

/*
*
Expand Down
15 changes: 15 additions & 0 deletions src/unistd.rs
Expand Up @@ -48,6 +48,11 @@ impl Uid {
pub fn is_root(&self) -> bool {
*self == ROOT
}

/// Get the raw `uid_t` wrapped by `self`.
pub fn as_raw(&self) -> uid_t {
self.0
}
}

impl From<Uid> for uid_t {
Expand Down Expand Up @@ -87,6 +92,11 @@ impl Gid {
pub fn effective() -> Self {
getegid()
}

/// Get the raw `gid_t` wrapped by `self`.
pub fn as_raw(&self) -> gid_t {
self.0
}
}

impl From<Gid> for gid_t {
Expand Down Expand Up @@ -123,6 +133,11 @@ impl Pid {
pub fn parent() -> Self {
getppid()
}

/// Get the raw `pid_t` wrapped by `self`.
pub fn as_raw(&self) -> pid_t {
self.0
}
}

impl From<Pid> for pid_t {
Expand Down