diff --git a/build.rs b/build.rs index 657b8c2..656c569 100644 --- a/build.rs +++ b/build.rs @@ -65,6 +65,7 @@ fn main() { if rustc < 79 { println!("cargo:rustc-cfg=no_literal_byte_character"); + println!("cargo:rustc-cfg=no_literal_c_string"); } if !cfg!(feature = "proc-macro") { diff --git a/src/fallback.rs b/src/fallback.rs index 3aa3c46..b42537b 100644 --- a/src/fallback.rs +++ b/src/fallback.rs @@ -15,7 +15,8 @@ use core::mem::ManuallyDrop; use core::ops::Range; use core::ops::RangeBounds; use core::ptr; -use core::str::FromStr; +use core::str::{self, FromStr}; +use std::ffi::CStr; #[cfg(procmacro2_semver_exempt)] use std::path::PathBuf; @@ -1010,27 +1011,7 @@ impl Literal { pub fn string(string: &str) -> Literal { let mut repr = String::with_capacity(string.len() + 2); repr.push('"'); - let mut chars = string.chars(); - while let Some(ch) = chars.next() { - if ch == '\0' { - repr.push_str( - if chars - .as_str() - .starts_with(|next| '0' <= next && next <= '7') - { - // circumvent clippy::octal_escapes lint - r"\x00" - } else { - r"\0" - }, - ); - } else if ch == '\'' { - // escape_debug turns this into "\'" which is unnecessary. - repr.push(ch); - } else { - repr.extend(ch.escape_debug()); - } - } + escape_utf8(string, &mut repr); repr.push('"'); Literal::_new(repr) } @@ -1093,6 +1074,34 @@ impl Literal { Literal::_new(repr) } + pub fn c_string(string: &CStr) -> Literal { + let mut repr = "c\"".to_string(); + let mut bytes = string.to_bytes(); + while !bytes.is_empty() { + let (valid, invalid) = match str::from_utf8(bytes) { + Ok(all_valid) => { + bytes = b""; + (all_valid, bytes) + } + Err(utf8_error) => { + let (valid, rest) = bytes.split_at(utf8_error.valid_up_to()); + let valid = str::from_utf8(valid).unwrap(); + let invalid = utf8_error + .error_len() + .map_or(rest, |error_len| &rest[..error_len]); + bytes = &bytes[valid.len() + invalid.len()..]; + (valid, invalid) + } + }; + escape_utf8(valid, &mut repr); + for &byte in invalid { + let _ = write!(repr, r"\x{:02X}", byte); + } + } + repr.push('"'); + Literal::_new(repr) + } + pub fn span(&self) -> Span { self.span } @@ -1191,3 +1200,27 @@ impl Debug for Literal { debug.finish() } } + +fn escape_utf8(string: &str, repr: &mut String) { + let mut chars = string.chars(); + while let Some(ch) = chars.next() { + if ch == '\0' { + repr.push_str( + if chars + .as_str() + .starts_with(|next| '0' <= next && next <= '7') + { + // circumvent clippy::octal_escapes lint + r"\x00" + } else { + r"\0" + }, + ); + } else if ch == '\'' { + // escape_debug turns this into "\'" which is unnecessary. + repr.push(ch); + } else { + repr.extend(ch.escape_debug()); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 3d4ed0d..07344af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,6 +170,7 @@ use core::ops::Range; use core::ops::RangeBounds; use core::str::FromStr; use std::error::Error; +use std::ffi::CStr; #[cfg(procmacro2_semver_exempt)] use std::path::PathBuf; @@ -1244,6 +1245,11 @@ impl Literal { Literal::_new(imp::Literal::byte_string(bytes)) } + /// C string literal. + pub fn c_string(string: &CStr) -> Literal { + Literal::_new(imp::Literal::c_string(string)) + } + /// Returns the span encompassing this literal. pub fn span(&self) -> Span { Span::_new(self.inner.span()) diff --git a/src/wrapper.rs b/src/wrapper.rs index f7c0377..87e348d 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -7,6 +7,7 @@ use core::fmt::{self, Debug, Display}; use core::ops::Range; use core::ops::RangeBounds; use core::str::FromStr; +use std::ffi::CStr; use std::panic; #[cfg(super_unstable)] use std::path::PathBuf; @@ -889,6 +890,25 @@ impl Literal { } } + pub fn c_string(string: &CStr) -> Literal { + if inside_proc_macro() { + Literal::Compiler({ + #[cfg(not(no_literal_c_string))] + { + proc_macro::Literal::c_string(string) + } + + #[cfg(no_literal_c_string)] + { + let fallback = fallback::Literal::c_string(string); + fallback.repr.parse::().unwrap() + } + }) + } else { + Literal::Fallback(fallback::Literal::c_string(string)) + } + } + pub fn span(&self) -> Span { match self { Literal::Compiler(lit) => Span::Compiler(lit.span()), diff --git a/tests/test.rs b/tests/test.rs index 62ee09e..97a0f9b 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -6,6 +6,7 @@ )] use proc_macro2::{Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree}; +use std::ffi::CStr; use std::iter; use std::str::{self, FromStr}; @@ -164,6 +165,29 @@ fn literal_byte_string() { #[test] fn literal_c_string() { + assert_eq!(Literal::c_string(<&CStr>::default()).to_string(), "c\"\""); + + let cstr = CStr::from_bytes_with_nul(b"aA\0").unwrap(); + assert_eq!(Literal::c_string(cstr).to_string(), r#" c"aA" "#.trim()); + + let cstr = CStr::from_bytes_with_nul(b"\t\0").unwrap(); + assert_eq!(Literal::c_string(cstr).to_string(), r#" c"\t" "#.trim()); + + let cstr = CStr::from_bytes_with_nul(b"\xE2\x9D\xA4\0").unwrap(); + assert_eq!(Literal::c_string(cstr).to_string(), r#" c"❤" "#.trim()); + + let cstr = CStr::from_bytes_with_nul(b"'\0").unwrap(); + assert_eq!(Literal::c_string(cstr).to_string(), r#" c"'" "#.trim()); + + let cstr = CStr::from_bytes_with_nul(b"\"\0").unwrap(); + assert_eq!(Literal::c_string(cstr).to_string(), r#" c"\"" "#.trim()); + + let cstr = CStr::from_bytes_with_nul(b"\x7F\xFF\xFE\xCC\xB3\0").unwrap(); + assert_eq!( + Literal::c_string(cstr).to_string(), + r#" c"\u{7f}\xFF\xFE\u{333}" "#.trim(), + ); + let strings = r###" c"hello\x80我叫\u{1F980}" // from the RFC cr"\"