Skip to content

Commit

Permalink
Reduce duplication in encode_into
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-schievink committed Jul 6, 2018
1 parent e06f901 commit 06ebca0
Showing 1 changed file with 57 additions and 67 deletions.
124 changes: 57 additions & 67 deletions src/sys/socket/mod.rs
Expand Up @@ -523,6 +523,7 @@ pub enum ControlMessage<'a> {
/// nix::unistd::close(in_socket).unwrap();
/// ```
ScmTimestamp(&'a TimeVal),
/// Catch-all variant for unimplemented cmsg types.
#[doc(hidden)]
Unknown(UnknownCmsg<'a>),
}
Expand Down Expand Up @@ -566,84 +567,73 @@ impl<'a> ControlMessage<'a> {
}
}

/// Returns the value to put into the `cmsg_type` field of the header.
fn type_(&self) -> libc::c_int {
match *self {
ControlMessage::ScmRights(_) => libc::SCM_RIGHTS,
ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS,
ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP,
ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type,
}
}

// Unsafe: start and end of buffer must be cmsg_align'd. Updates
// the provided slice; panics if the buffer is too small.
unsafe fn encode_into(&self, buf: &mut [u8]) {
match *self {
ControlMessage::ScmRights(fds) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_RIGHTS,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(fds, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::ScmCredentials(creds) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_CREDENTIALS,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(creds, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self {
let &UnknownCmsg(orig_cmsg, bytes) = cmsg;

let buf = copy_bytes(orig_cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);

copy_bytes(bytes, buf)
} else {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: self.type_(),
};
let buf = copy_bytes(&cmsg, buf);

match *self {
ControlMessage::ScmRights(fds) => {
let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

copy_bytes(fds, buf)
},
ControlMessage::ScmCredentials(creds) => {
let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

copy_bytes(creds, buf)
}
ControlMessage::ScmTimestamp(t) => {
let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

copy_bytes(t, buf)
},
ControlMessage::Unknown(_) => unreachable!(),
}
ControlMessage::ScmTimestamp(t) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_TIMESTAMP,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(t, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => {
let buf = copy_bytes(orig_cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(bytes, buf);
};

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
}
}
let padlen = self.space() - self.len();
pad_bytes(padlen, final_buf);
}

/// Decodes a `ControlMessage` from raw bytes.
///
/// This is only safe to call if the data is correct for the message type
/// specified in the header. Normally, the kernel ensures that this is the
/// case.
/// case. "Correct" in this case includes correct length, alignment and
/// actual content.
unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> {
match (header.cmsg_level, header.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => {
Expand Down

0 comments on commit 06ebca0

Please sign in to comment.