From 06ebca013c0ce3eaa7df5cd71d787ad4e96ac2fb Mon Sep 17 00:00:00 2001 From: Jonas Schievink Date: Fri, 6 Jul 2018 02:57:17 +0200 Subject: [PATCH] Reduce duplication in `encode_into` --- src/sys/socket/mod.rs | 124 +++++++++++++++++++----------------------- 1 file changed, 57 insertions(+), 67 deletions(-) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 741779b15c..7d62fea917 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -523,6 +523,7 @@ pub enum ControlMessage<'a> { /// nix::unistd::close(in_socket).unwrap(); /// ``` ScmTimestamp(&'a TimeVal), + /// Catch-all variant for unimplemented cmsg types. #[doc(hidden)] Unknown(UnknownCmsg<'a>), } @@ -566,84 +567,73 @@ impl<'a> ControlMessage<'a> { } } + /// Returns the value to put into the `cmsg_type` field of the header. + fn type_(&self) -> libc::c_int { + match *self { + ControlMessage::ScmRights(_) => libc::SCM_RIGHTS, + ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS, + ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP, + ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type, + } + } + // Unsafe: start and end of buffer must be cmsg_align'd. Updates // the provided slice; panics if the buffer is too small. unsafe fn encode_into(&self, buf: &mut [u8]) { - match *self { - ControlMessage::ScmRights(fds) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_RIGHTS, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(fds, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - }, - ControlMessage::ScmCredentials(creds) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_CREDENTIALS, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(creds, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); + let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self { + let &UnknownCmsg(orig_cmsg, bytes) = cmsg; + + let buf = copy_bytes(orig_cmsg, buf); + + let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - + mem::size_of_val(&orig_cmsg); + let buf = pad_bytes(padlen, buf); + + copy_bytes(bytes, buf) + } else { + let cmsg = cmsghdr { + cmsg_len: self.len() as _, + cmsg_level: libc::SOL_SOCKET, + cmsg_type: self.type_(), + }; + let buf = copy_bytes(&cmsg, buf); + + match *self { + ControlMessage::ScmRights(fds) => { + let padlen = cmsg_align(mem::size_of_val(&cmsg)) - + mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); + + copy_bytes(fds, buf) + }, + ControlMessage::ScmCredentials(creds) => { + let padlen = cmsg_align(mem::size_of_val(&cmsg)) - + mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); + + copy_bytes(creds, buf) + } + ControlMessage::ScmTimestamp(t) => { + let padlen = cmsg_align(mem::size_of_val(&cmsg)) - + mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); + + copy_bytes(t, buf) + }, + ControlMessage::Unknown(_) => unreachable!(), } - ControlMessage::ScmTimestamp(t) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_TIMESTAMP, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(t, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - }, - ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => { - let buf = copy_bytes(orig_cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - - mem::size_of_val(&orig_cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(bytes, buf); + }; - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - } - } + let padlen = self.space() - self.len(); + pad_bytes(padlen, final_buf); } /// Decodes a `ControlMessage` from raw bytes. /// /// This is only safe to call if the data is correct for the message type /// specified in the header. Normally, the kernel ensures that this is the - /// case. + /// case. "Correct" in this case includes correct length, alignment and + /// actual content. unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> { match (header.cmsg_level, header.cmsg_type) { (libc::SOL_SOCKET, libc::SCM_RIGHTS) => {