diff --git a/src/buf/buf_mut.rs b/src/buf/buf_mut.rs index 026bcad02..fb3623d25 100644 --- a/src/buf/buf_mut.rs +++ b/src/buf/buf_mut.rs @@ -1,12 +1,8 @@ -use crate::buf::{limit, Chain, Limit}; +use crate::buf::{limit, Chain, Limit, UninitSlice}; #[cfg(feature = "std")] use crate::buf::{writer, Writer}; -use core::{ - cmp, - mem::{self, MaybeUninit}, - ptr, usize, -}; +use core::{cmp, mem, ptr, usize}; use alloc::{boxed::Box, vec::Vec}; @@ -73,19 +69,14 @@ pub unsafe trait BufMut { /// /// let mut buf = Vec::with_capacity(16); /// - /// unsafe { - /// // MaybeUninit::as_mut_ptr - /// buf.bytes_mut()[0].as_mut_ptr().write(b'h'); - /// buf.bytes_mut()[1].as_mut_ptr().write(b'e'); + /// // Write some data + /// buf.bytes_mut()[0..2].copy_from_slice(b"he"); + /// unsafe { buf.advance_mut(2) }; /// - /// buf.advance_mut(2); + /// // write more bytes + /// buf.bytes_mut()[0..3].copy_from_slice(b"llo"); /// - /// buf.bytes_mut()[0].as_mut_ptr().write(b'l'); - /// buf.bytes_mut()[1].as_mut_ptr().write(b'l'); - /// buf.bytes_mut()[2].as_mut_ptr().write(b'o'); - /// - /// buf.advance_mut(3); - /// } + /// unsafe { buf.advance_mut(3); } /// /// assert_eq!(5, buf.len()); /// assert_eq!(buf, b"hello"); @@ -144,14 +135,14 @@ pub unsafe trait BufMut { /// /// unsafe { /// // MaybeUninit::as_mut_ptr - /// buf.bytes_mut()[0].as_mut_ptr().write(b'h'); - /// buf.bytes_mut()[1].as_mut_ptr().write(b'e'); + /// buf.bytes_mut()[0..].as_mut_ptr().write(b'h'); + /// buf.bytes_mut()[1..].as_mut_ptr().write(b'e'); /// /// buf.advance_mut(2); /// - /// buf.bytes_mut()[0].as_mut_ptr().write(b'l'); - /// buf.bytes_mut()[1].as_mut_ptr().write(b'l'); - /// buf.bytes_mut()[2].as_mut_ptr().write(b'o'); + /// buf.bytes_mut()[0..].as_mut_ptr().write(b'l'); + /// buf.bytes_mut()[1..].as_mut_ptr().write(b'l'); + /// buf.bytes_mut()[2..].as_mut_ptr().write(b'o'); /// /// buf.advance_mut(3); /// } @@ -167,7 +158,7 @@ pub unsafe trait BufMut { /// `bytes_mut` returning an empty slice implies that `remaining_mut` will /// return 0 and `remaining_mut` returning 0 implies that `bytes_mut` will /// return an empty slice. - fn bytes_mut(&mut self) -> &mut [MaybeUninit]; + fn bytes_mut(&mut self) -> &mut UninitSlice; /// Transfer bytes into `self` from `src` and advance the cursor by the /// number of bytes written. @@ -922,7 +913,7 @@ macro_rules! deref_forward_bufmut { (**self).remaining_mut() } - fn bytes_mut(&mut self) -> &mut [MaybeUninit] { + fn bytes_mut(&mut self) -> &mut UninitSlice { (**self).bytes_mut() } @@ -1007,9 +998,9 @@ unsafe impl BufMut for &mut [u8] { } #[inline] - fn bytes_mut(&mut self) -> &mut [MaybeUninit] { - // MaybeUninit is repr(transparent), so safe to transmute - unsafe { mem::transmute(&mut **self) } + fn bytes_mut(&mut self) -> &mut UninitSlice { + // UninitSlice is repr(transparent), so safe to transmute + unsafe { &mut *(*self as *mut [u8] as *mut _) } } #[inline] @@ -1042,9 +1033,7 @@ unsafe impl BufMut for Vec { } #[inline] - fn bytes_mut(&mut self) -> &mut [MaybeUninit] { - use core::slice; - + fn bytes_mut(&mut self) -> &mut UninitSlice { if self.capacity() == self.len() { self.reserve(64); // Grow the vec } @@ -1052,8 +1041,8 @@ unsafe impl BufMut for Vec { let cap = self.capacity(); let len = self.len(); - let ptr = self.as_mut_ptr() as *mut MaybeUninit; - unsafe { &mut slice::from_raw_parts_mut(ptr, cap)[len..] } + let ptr = self.as_mut_ptr(); + unsafe { &mut UninitSlice::from_raw_parts_mut(ptr, cap)[len..] } } // Specialize these methods so they can skip checking `remaining_mut` diff --git a/src/buf/chain.rs b/src/buf/chain.rs index cc2c944b7..e59667dff 100644 --- a/src/buf/chain.rs +++ b/src/buf/chain.rs @@ -1,8 +1,6 @@ -use crate::buf::IntoIter; +use crate::buf::{IntoIter, UninitSlice}; use crate::{Buf, BufMut}; -use core::mem::MaybeUninit; - #[cfg(feature = "std")] use std::io::IoSlice; @@ -183,7 +181,7 @@ where self.a.remaining_mut() + self.b.remaining_mut() } - fn bytes_mut(&mut self) -> &mut [MaybeUninit] { + fn bytes_mut(&mut self) -> &mut UninitSlice { if self.a.has_remaining_mut() { self.a.bytes_mut() } else { diff --git a/src/buf/limit.rs b/src/buf/limit.rs index c6ed3c7b1..5cbbbfe6b 100644 --- a/src/buf/limit.rs +++ b/src/buf/limit.rs @@ -1,6 +1,7 @@ +use crate::buf::UninitSlice; use crate::BufMut; -use core::{cmp, mem::MaybeUninit}; +use core::cmp; /// A `BufMut` adapter which limits the amount of bytes that can be written /// to an underlying buffer. @@ -60,7 +61,7 @@ unsafe impl BufMut for Limit { cmp::min(self.inner.remaining_mut(), self.limit) } - fn bytes_mut(&mut self) -> &mut [MaybeUninit] { + fn bytes_mut(&mut self) -> &mut UninitSlice { let bytes = self.inner.bytes_mut(); let end = cmp::min(bytes.len(), self.limit); &mut bytes[..end] diff --git a/src/buf/mod.rs b/src/buf/mod.rs index 5c6d5f9d5..c4c0a5724 100644 --- a/src/buf/mod.rs +++ b/src/buf/mod.rs @@ -24,6 +24,7 @@ mod limit; #[cfg(feature = "std")] mod reader; mod take; +mod uninit_slice; mod vec_deque; #[cfg(feature = "std")] mod writer; @@ -34,6 +35,7 @@ pub use self::chain::Chain; pub use self::iter::IntoIter; pub use self::limit::Limit; pub use self::take::Take; +pub use self::uninit_slice::UninitSlice; #[cfg(feature = "std")] pub use self::{reader::Reader, writer::Writer}; diff --git a/src/buf/uninit_slice.rs b/src/buf/uninit_slice.rs new file mode 100644 index 000000000..32ebde4c5 --- /dev/null +++ b/src/buf/uninit_slice.rs @@ -0,0 +1,176 @@ +use core::fmt; +use core::mem::MaybeUninit; +use core::ops::{ + Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, +}; + +/// Uninitialized byte slice. +/// +/// Returned by `BufMut::bytes_mut()`, the referenced byte slice may be +/// uninitialized. The wrapper provides safe access without introducing +/// undefined behavior. +/// +/// The safety invariants of this wrapper are: +/// +/// 1. Reading from an `UninitSlice` is undefined behavior. +/// 2. Writing uninitialized bytes to an `UninitSlice` is undefined behavior. +/// +/// The difference between `&mut UninitSlice` and `&mut [MaybeUninit]` is +/// that it is possible in safe code to write uninitialized bytes to an +/// `&mut [MaybeUninit]`, which this type prohibits. +#[repr(transparent)] +pub struct UninitSlice([MaybeUninit]); + +impl UninitSlice { + /// Create a `&mut UninitSlice` from a pointer and a length. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` references a valid memory region owned + /// by the caller representing a byte slice for the duration of `'a`. + /// + /// # Examples + /// + /// ``` + /// use bytes::buf::UninitSlice; + /// + /// let bytes = b"hello world".to_vec(); + /// let ptr = bytes.as_ptr() as *mut _; + /// let len = bytes.len(); + /// + /// let slice = unsafe { UninitSlice::from_raw_parts_mut(ptr, len) }; + /// ``` + pub unsafe fn from_raw_parts_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut UninitSlice { + let maybe_init: &mut [MaybeUninit] = + core::slice::from_raw_parts_mut(ptr as *mut _, len); + &mut *(maybe_init as *mut [MaybeUninit] as *mut UninitSlice) + } + + /// Write a single byte at the specified offset. + /// + /// # Panics + /// + /// The function panics if `index` is out of bounds. + /// + /// # Examples + /// + /// ``` + /// use bytes::buf::UninitSlice; + /// + /// let mut data = [b'f', b'o', b'o']; + /// let slice = unsafe { UninitSlice::from_raw_parts_mut(data.as_mut_ptr(), 3) }; + /// + /// slice.write_byte(0, b'b'); + /// + /// assert_eq!(b"boo", &data[..]); + /// ``` + pub fn write_byte(&mut self, index: usize, byte: u8) { + assert!(index < self.len()); + + unsafe { self[index..].as_mut_ptr().write(byte) } + } + + /// Copies bytes from `src` into `self`. + /// + /// The length of `src` must be the same as `self`. + /// + /// # Panics + /// + /// The function panics if `src` has a different length than `self`. + /// + /// # Examples + /// + /// ``` + /// use bytes::buf::UninitSlice; + /// + /// let mut data = [b'f', b'o', b'o']; + /// let slice = unsafe { UninitSlice::from_raw_parts_mut(data.as_mut_ptr(), 3) }; + /// + /// slice.copy_from_slice(b"bar"); + /// + /// assert_eq!(b"bar", &data[..]); + /// ``` + pub fn copy_from_slice(&mut self, src: &[u8]) { + use core::ptr; + + assert_eq!(self.len(), src.len()); + + unsafe { + ptr::copy_nonoverlapping(src.as_ptr(), self.as_mut_ptr(), self.len()); + } + } + + /// Return a raw pointer to the slice's buffer. + /// + /// # Safety + /// + /// The caller **must not** read from the referenced memory and **must not** + /// write **uninitialized** bytes to the slice either. + /// + /// # Examples + /// + /// ``` + /// use bytes::BufMut; + /// + /// let mut data = [0, 1, 2]; + /// let mut slice = &mut data[..]; + /// let ptr = BufMut::bytes_mut(&mut slice).as_mut_ptr(); + /// ``` + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() as *mut _ + } + + /// Returns the number of bytes in the slice. + /// + /// # Examples + /// + /// ``` + /// use bytes::BufMut; + /// + /// let mut data = [0, 1, 2]; + /// let mut slice = &mut data[..]; + /// let len = BufMut::bytes_mut(&mut slice).len(); + /// + /// assert_eq!(len, 3); + /// ``` + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl fmt::Debug for UninitSlice { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UninitSlice[...]").finish() + } +} + +macro_rules! impl_index { + ($($t:ty),*) => { + $( + impl Index<$t> for UninitSlice { + type Output = UninitSlice; + + fn index(&self, index: $t) -> &UninitSlice { + let maybe_uninit: &[MaybeUninit] = &self.0[index]; + unsafe { &*(maybe_uninit as *const [MaybeUninit] as *const UninitSlice) } + } + } + + impl IndexMut<$t> for UninitSlice { + fn index_mut(&mut self, index: $t) -> &mut UninitSlice { + let maybe_uninit: &mut [MaybeUninit] = &mut self.0[index]; + unsafe { &mut *(maybe_uninit as *mut [MaybeUninit] as *mut UninitSlice) } + } + } + )* + }; +} + +impl_index!( + Range, + RangeFrom, + RangeFull, + RangeInclusive, + RangeTo, + RangeToInclusive +); diff --git a/src/bytes_mut.rs b/src/bytes_mut.rs index 16cb72c2b..38f1ed53d 100644 --- a/src/bytes_mut.rs +++ b/src/bytes_mut.rs @@ -11,7 +11,7 @@ use alloc::{ vec::Vec, }; -use crate::buf::IntoIter; +use crate::buf::{IntoIter, UninitSlice}; use crate::bytes::Vtable; #[allow(unused)] use crate::loom::sync::atomic::AtomicMut; @@ -684,7 +684,7 @@ impl BytesMut { self.reserve(cnt); unsafe { - let dst = self.maybe_uninit_bytes(); + let dst = self.uninit_slice(); // Reserved above debug_assert!(dst.len() >= cnt); @@ -910,12 +910,12 @@ impl BytesMut { } #[inline] - fn maybe_uninit_bytes(&mut self) -> &mut [mem::MaybeUninit] { + fn uninit_slice(&mut self) -> &mut UninitSlice { unsafe { let ptr = self.ptr.as_ptr().offset(self.len as isize); let len = self.cap - self.len; - slice::from_raw_parts_mut(ptr as *mut mem::MaybeUninit, len) + UninitSlice::from_raw_parts_mut(ptr, len) } } } @@ -985,11 +985,11 @@ unsafe impl BufMut for BytesMut { } #[inline] - fn bytes_mut(&mut self) -> &mut [mem::MaybeUninit] { + fn bytes_mut(&mut self) -> &mut UninitSlice { if self.capacity() == self.len() { self.reserve(64); } - self.maybe_uninit_bytes() + self.uninit_slice() } // Specialize these methods so they can skip checking `remaining_mut` diff --git a/tests/test_buf_mut.rs b/tests/test_buf_mut.rs index e9948839a..406ec510c 100644 --- a/tests/test_buf_mut.rs +++ b/tests/test_buf_mut.rs @@ -1,5 +1,6 @@ #![warn(rust_2018_idioms)] +use bytes::buf::UninitSlice; use bytes::{BufMut, BytesMut}; use core::fmt::Write; use core::usize; @@ -80,7 +81,7 @@ fn test_deref_bufmut_forwards() { unreachable!("remaining_mut"); } - fn bytes_mut(&mut self) -> &mut [std::mem::MaybeUninit] { + fn bytes_mut(&mut self) -> &mut UninitSlice { unreachable!("bytes_mut"); } @@ -99,3 +100,30 @@ fn test_deref_bufmut_forwards() { (Box::new(Special) as Box).put_u8(b'x'); Box::new(Special).put_u8(b'x'); } + +#[test] +#[should_panic] +fn write_byte_panics_if_out_of_bounds() { + let mut data = [b'b', b'a', b'r']; + + let slice = unsafe { UninitSlice::from_raw_parts_mut(data.as_mut_ptr(), 3) }; + slice.write_byte(4, b'f'); +} + +#[test] +#[should_panic] +fn copy_from_slice_panics_if_different_length_1() { + let mut data = [b'b', b'a', b'r']; + + let slice = unsafe { UninitSlice::from_raw_parts_mut(data.as_mut_ptr(), 3) }; + slice.copy_from_slice(b"a"); +} + +#[test] +#[should_panic] +fn copy_from_slice_panics_if_different_length_2() { + let mut data = [b'b', b'a', b'r']; + + let slice = unsafe { UninitSlice::from_raw_parts_mut(data.as_mut_ptr(), 3) }; + slice.copy_from_slice(b"abcd"); +} diff --git a/tests/test_bytes.rs b/tests/test_bytes.rs index 6b106a6bc..b97cce6cb 100644 --- a/tests/test_bytes.rs +++ b/tests/test_bytes.rs @@ -912,12 +912,12 @@ fn bytes_buf_mut_advance() { let mut bytes = BytesMut::with_capacity(1024); unsafe { - let ptr = bytes.bytes_mut().as_ptr(); + let ptr = bytes.bytes_mut().as_mut_ptr(); assert_eq!(1024, bytes.bytes_mut().len()); bytes.advance_mut(10); - let next = bytes.bytes_mut().as_ptr(); + let next = bytes.bytes_mut().as_mut_ptr(); assert_eq!(1024 - 10, bytes.bytes_mut().len()); assert_eq!(ptr.offset(10), next);