diff --git a/src/ensure.rs b/src/ensure.rs index fb45baa..4f05a5c 100644 --- a/src/ensure.rs +++ b/src/ensure.rs @@ -47,20 +47,24 @@ impl Buf { impl Write for Buf { fn write_str(&mut self, s: &str) -> fmt::Result { + if s.bytes().any(|b| b == b' ' || b == b'\n') { + return Err(fmt::Error); + } + let remaining = self.bytes.len() - self.written; - if s.len() <= remaining { - unsafe { - ptr::copy_nonoverlapping( - s.as_ptr(), - self.bytes.as_mut_ptr().add(self.written).cast::(), - s.len(), - ); - } - self.written += s.len(); - Ok(()) - } else { - Err(fmt::Error) + if s.len() > remaining { + return Err(fmt::Error); + } + + unsafe { + ptr::copy_nonoverlapping( + s.as_ptr(), + self.bytes.as_mut_ptr().add(self.written).cast::(), + s.len(), + ); } + self.written += s.len(); + Ok(()) } } diff --git a/tests/test_ensure.rs b/tests/test_ensure.rs index ac69ce9..37d56cb 100644 --- a/tests/test_ensure.rs +++ b/tests/test_ensure.rs @@ -366,6 +366,19 @@ fn test_trailer() { ); } +#[test] +fn test_whitespace() { + #[derive(Debug)] + pub struct Point { + pub x: i32, + pub y: i32, + } + + let point = Point { x: 0, y: 0 }; + let test = || Ok(ensure!("" == format!("{:#?}", point))); + assert_err(test, "Condition failed: `\"\" == format!(\"{:#?}\", point)`"); +} + #[test] fn test_too_long() { let test = || Ok(ensure!("" == "x".repeat(10)));