From 041113c1a784ba8f45eac4ecc5598dd49e44d02f Mon Sep 17 00:00:00 2001 From: Zachary Dremann Date: Tue, 10 Oct 2023 01:29:04 -0400 Subject: [PATCH 1/3] Use memchr to search for characters to escape --- src/escapei.rs | 39 ++++--- src/se/simple_type.rs | 237 ++++++++++++++++++++++++------------------ src/utils.rs | 57 ++++++++++ 3 files changed, 218 insertions(+), 115 deletions(-) diff --git a/src/escapei.rs b/src/escapei.rs index 46b75f50..107d6871 100644 --- a/src/escapei.rs +++ b/src/escapei.rs @@ -1,9 +1,10 @@ //! Manage xml character escapes -use memchr::memchr2_iter; +use memchr::{memchr2_iter, memchr3_iter}; use std::borrow::Cow; use std::ops::Range; +use crate::utils::MergeIter; #[cfg(test)] use pretty_assertions::assert_eq; @@ -72,7 +73,14 @@ impl std::error::Error for EscapeError {} /// | `'` | `'` /// | `"` | `"` pub fn escape(raw: &str) -> Cow { - _escape(raw, |ch| matches!(ch, b'<' | b'>' | b'&' | b'\'' | b'\"')) + let bytes = raw.as_bytes(); + _escape( + raw, + MergeIter::new( + memchr3_iter(b'<', b'>', b'&', bytes), + memchr2_iter(b'\'', b'"', bytes), + ), + ) } /// Escapes an `&str` and replaces xml special characters (`<`, `>`, `&`) @@ -89,24 +97,23 @@ pub fn escape(raw: &str) -> Cow { /// | `>` | `>` /// | `&` | `&` pub fn partial_escape(raw: &str) -> Cow { - _escape(raw, |ch| matches!(ch, b'<' | b'>' | b'&')) + _escape(raw, memchr3_iter(b'<', b'>', b'&', raw.as_bytes())) } /// Escapes an `&str` and replaces a subset of xml special characters (`<`, `>`, /// `&`, `'`, `"`) with their corresponding xml escaped value. -pub(crate) fn _escape bool>(raw: &str, escape_chars: F) -> Cow { +pub(crate) fn _escape(raw: &str, escapes: It) -> Cow +where + It: Iterator, +{ let bytes = raw.as_bytes(); let mut escaped = None; - let mut iter = bytes.iter(); - let mut pos = 0; - while let Some(i) = iter.position(|&b| escape_chars(b)) { - if escaped.is_none() { - escaped = Some(Vec::with_capacity(raw.len())); - } - let escaped = escaped.as_mut().expect("initialized"); - let new_pos = pos + i; - escaped.extend_from_slice(&bytes[pos..new_pos]); - match bytes[new_pos] { + let mut last_pos = 0; + for i in escapes { + let escaped = escaped.get_or_insert_with(|| Vec::with_capacity(raw.len())); + let byte = bytes[i]; + escaped.extend_from_slice(&bytes[last_pos..i]); + match byte { b'<' => escaped.extend_from_slice(b"<"), b'>' => escaped.extend_from_slice(b">"), b'\'' => escaped.extend_from_slice(b"'"), @@ -124,11 +131,11 @@ pub(crate) fn _escape bool>(raw: &str, escape_chars: F) -> Cow "Only '<', '>','\', '&', '\"', '\\t', '\\r', '\\n', and ' ' are escaped" ), } - pos = new_pos + 1; + last_pos = i + 1; } if let Some(mut escaped) = escaped { - if let Some(raw) = bytes.get(pos..) { + if let Some(raw) = bytes.get(last_pos..) { escaped.extend_from_slice(raw); } // SAFETY: we operate on UTF-8 input and search for an one byte chars only, diff --git a/src/se/simple_type.rs b/src/se/simple_type.rs index 58a4ca3b..2df415d5 100644 --- a/src/se/simple_type.rs +++ b/src/se/simple_type.rs @@ -6,6 +6,8 @@ use crate::errors::serialize::DeError; use crate::escapei::_escape; use crate::se::{Indent, QuoteLevel}; +use crate::utils::MergeIter; +use memchr::{memchr2_iter, memchr3_iter, memchr_iter}; use serde::ser::{ Impossible, Serialize, SerializeSeq, SerializeTuple, SerializeTupleStruct, Serializer, }; @@ -29,67 +31,96 @@ fn escape_item(value: &str, target: QuoteTarget, level: QuoteLevel) -> Cow use QuoteLevel::*; use QuoteTarget::*; + let bytes = value.as_bytes(); + match (target, level) { - (_, Full) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items, cannot be used in the item - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' | b'>' | b'\'' | b'\"' => true, - _ => false, - }), + (_, Full) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<', '>', '\'', '"': Required characters to escape + MergeIter::new( + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + memchr3_iter(b'>', b'\'', b'"', bytes), + ), + ), //---------------------------------------------------------------------- - (Text, Partial) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items, cannot be used in the item - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' | b'>' => true, - _ => false, - }), - (Text, Minimal) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items, cannot be used in the item - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' => true, - _ => false, - }), + (Text, Partial) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<', '>': Required characters to escape + MergeIter::new( + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + memchr_iter(b'>', bytes), + ), + ), + (Text, Minimal) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<': Required characters to escape + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + ), //---------------------------------------------------------------------- - (DoubleQAttr, Partial) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items, cannot be used in the item - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' | b'>' => true, - // Double quoted attribute should escape quote - b'"' => true, - _ => false, - }), - (DoubleQAttr, Minimal) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items, cannot be used in the item - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' => true, - // Double quoted attribute should escape quote - b'"' => true, - _ => false, - }), + (DoubleQAttr, Partial) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<', '>': Required characters to escape + MergeIter::new( + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + memchr2_iter(b'>', b'"', bytes), + ), + ), + (DoubleQAttr, Minimal) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<': Required characters to escape + // '"': Double quoted attribute should escape quote + MergeIter::new( + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + memchr_iter(b'"', bytes), + ), + ), //---------------------------------------------------------------------- - (SingleQAttr, Partial) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' | b'>' => true, - // Single quoted attribute should escape quote - b'\'' => true, - _ => false, - }), - (SingleQAttr, Minimal) => _escape(value, |ch| match ch { - // Spaces used as delimiters of list items - b' ' | b'\r' | b'\n' | b'\t' => true, - // Required characters to escape - b'&' | b'<' => true, - // Single quoted attribute should escape quote - b'\'' => true, - _ => false, - }), + (SingleQAttr, Partial) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<', '>': Required characters to escape + // '\'': Single quoted attribute should escape quote + MergeIter::new( + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + memchr2_iter(b'>', b'\'', bytes), + ), + ), + (SingleQAttr, Minimal) => _escape( + value, + // ' ', '\r', '\n', '\t': Spaces used as delimiters of list items, cannot be used in the item + // '&', '<': Required characters to escape + // '\'': Single quoted attribute should escape quote + MergeIter::new( + MergeIter::new( + memchr3_iter(b' ', b'\r', b'\n', bytes), + memchr3_iter(b'\t', b'&', b'<', bytes), + ), + memchr_iter(b'\'', bytes), + ), + ), } } @@ -98,53 +129,61 @@ fn escape_list(value: &str, target: QuoteTarget, level: QuoteLevel) -> Cow use QuoteLevel::*; use QuoteTarget::*; + let bytes = value.as_bytes(); + match (target, level) { - (_, Full) => _escape(value, |ch| match ch { - // Required characters to escape - b'&' | b'<' | b'>' | b'\'' | b'\"' => true, - _ => false, - }), + (_, Full) => _escape( + value, + // '&', '<', '>', '\'', '"': Required characters to escape + MergeIter::new( + memchr3_iter(b'&', b'<', b'>', bytes), + memchr2_iter(b'\'', b'"', bytes), + ), + ), //---------------------------------------------------------------------- - (Text, Partial) => _escape(value, |ch| match ch { - // Required characters to escape - b'&' | b'<' | b'>' => true, - _ => false, - }), - (Text, Minimal) => _escape(value, |ch| match ch { - // Required characters to escape - b'&' | b'<' => true, - _ => false, - }), + (Text, Partial) => _escape( + value, + // '&', '<', '>': Required characters to escape + memchr3_iter(b'&', b'<', b'>', bytes), + ), + (Text, Minimal) => _escape( + value, + // '&', '<': Required characters to escape + memchr2_iter(b'&', b'<', bytes), + ), //---------------------------------------------------------------------- - (DoubleQAttr, Partial) => _escape(value, |ch| match ch { + (DoubleQAttr, Partial) => _escape( + value, + // '&', '<', '>': Required characters to escape + // '"': Double quoted attribute should escape quote + MergeIter::new( + memchr3_iter(b'&', b'<', b'>', bytes), + memchr_iter(b'"', bytes), + ), + ), + (DoubleQAttr, Minimal) => _escape( + value, + // '&', '<': Required characters to escape + // '"': Double quoted attribute should escape quote // Required characters to escape - b'&' | b'<' | b'>' => true, - // Double quoted attribute should escape quote - b'"' => true, - _ => false, - }), - (DoubleQAttr, Minimal) => _escape(value, |ch| match ch { - // Required characters to escape - b'&' | b'<' => true, - // Double quoted attribute should escape quote - b'"' => true, - _ => false, - }), + memchr3_iter(b'&', b'<', b'"', bytes), + ), //---------------------------------------------------------------------- - (SingleQAttr, Partial) => _escape(value, |ch| match ch { - // Required characters to escape - b'&' | b'<' | b'>' => true, - // Single quoted attribute should escape quote - b'\'' => true, - _ => false, - }), - (SingleQAttr, Minimal) => _escape(value, |ch| match ch { - // Required characters to escape - b'&' | b'<' => true, - // Single quoted attribute should escape quote - b'\'' => true, - _ => false, - }), + (SingleQAttr, Partial) => _escape( + value, + // '&', '<', '>': Required characters to escape + // '\'': Single quoted attribute should escape quote + MergeIter::new( + memchr3_iter(b'&', b'<', b'>', bytes), + memchr_iter(b'\'', bytes), + ), + ), + (SingleQAttr, Minimal) => _escape( + value, + // '&', '<': Required characters to escape + // '\': Single quoted attribute should escape quote + memchr3_iter(b'&', b'<', b'\'', bytes), + ), } } diff --git a/src/utils.rs b/src/utils.rs index adf22899..9719f88a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,6 @@ use std::borrow::{Borrow, Cow}; use std::fmt::{self, Debug, Formatter}; +use std::iter::Peekable; use std::ops::Deref; #[cfg(feature = "serialize")] @@ -35,6 +36,62 @@ pub fn write_byte_string(f: &mut Formatter, byte_string: &[u8]) -> fmt::Result { Ok(()) } +pub(crate) struct MergeIter +where + It1: Iterator, + It2: Iterator, +{ + it1: Peekable, + it2: Peekable, +} + +impl MergeIter +where + It1: Iterator, + It2: Iterator, +{ + pub fn new(it1: It1, it2: It2) -> Self { + Self { + it1: it1.peekable(), + it2: it2.peekable(), + } + } +} + +impl Iterator for MergeIter +where + It1: Iterator, + It2: Iterator, + It1::Item: PartialOrd, +{ + type Item = It1::Item; + + fn next(&mut self) -> Option { + match (self.it1.peek(), self.it2.peek()) { + (None, _) => self.it2.next(), + (_, None) => self.it1.next(), + (Some(a), Some(b)) => { + if a < b { + self.it1.next() + } else { + self.it2.next() + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let (min1, max1) = self.it1.size_hint(); + let (min2, max2) = self.it2.size_hint(); + let min = min1.saturating_add(min2); + let max = match (max1, max2) { + (Some(max1), Some(max2)) => max1.checked_add(max2), + _ => None, + }; + (min, max) + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// /// A version of [`Cow`] that can borrow from two different buffers, one of them From dcb210479b8e5e2102ac451b9cfbf8aba8b7bc50 Mon Sep 17 00:00:00 2001 From: Zachary Dremann Date: Wed, 11 Oct 2023 22:46:34 -0400 Subject: [PATCH 2/3] escape into a string directly, rather than bytes, then converting I believe this is quite a bit faster, because rust only has to verify that each string slicing operation starts/ends at character boundaries, none of the inner bytes need to be checked for UTF8-ness since they're already from a `&str` Also, when initially creating the escaped string, preallocate a little extra room, since we know the string will grow. --- src/escapei.rs | 46 ++++++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/src/escapei.rs b/src/escapei.rs index 107d6871..a6efd221 100644 --- a/src/escapei.rs +++ b/src/escapei.rs @@ -110,23 +110,30 @@ where let mut escaped = None; let mut last_pos = 0; for i in escapes { - let escaped = escaped.get_or_insert_with(|| Vec::with_capacity(raw.len())); + // If we have an escape, the escaped string will be at least some larger than the raw string, + // reserve a little more space, so we might not resize at all if only a few escapes are found. + let escaped = escaped.get_or_insert_with(|| String::with_capacity(raw.len() + 64)); let byte = bytes[i]; - escaped.extend_from_slice(&bytes[last_pos..i]); + // SAFETY: the escapes iterator should only return indexes of bytes we know how to escape. + // if one of those bytes are found, it _must_ be a complete character, so `i` must be a + // character boundary. + // last_pos will only be either 0 or i+1, and all supported chars are one byte long, + // last_pos will also always be at a char boundary + escaped.push_str(&raw[last_pos..i]); match byte { - b'<' => escaped.extend_from_slice(b"<"), - b'>' => escaped.extend_from_slice(b">"), - b'\'' => escaped.extend_from_slice(b"'"), - b'&' => escaped.extend_from_slice(b"&"), - b'"' => escaped.extend_from_slice(b"""), + b'<' => escaped.push_str("<"), + b'>' => escaped.push_str(">"), + b'\'' => escaped.push_str("'"), + b'&' => escaped.push_str("&"), + b'"' => escaped.push_str("""), // This set of escapes handles characters that should be escaped // in elements of xs:lists, because those characters works as // delimiters of list elements - b'\t' => escaped.extend_from_slice(b" "), - b'\n' => escaped.extend_from_slice(b" "), - b'\r' => escaped.extend_from_slice(b" "), - b' ' => escaped.extend_from_slice(b" "), + b'\t' => escaped.push_str(" "), + b'\n' => escaped.push_str(" "), + b'\r' => escaped.push_str(" "), + b' ' => escaped.push_str(" "), _ => unreachable!( "Only '<', '>','\', '&', '\"', '\\t', '\\r', '\\n', and ' ' are escaped" ), @@ -135,14 +142,8 @@ where } if let Some(mut escaped) = escaped { - if let Some(raw) = bytes.get(last_pos..) { - escaped.extend_from_slice(raw); - } - // SAFETY: we operate on UTF-8 input and search for an one byte chars only, - // so all slices that was put to the `escaped` is a valid UTF-8 encoded strings - // TODO: Can be replaced with `unsafe { String::from_utf8_unchecked() }` - // if unsafe code will be allowed - Cow::Owned(String::from_utf8(escaped).unwrap()) + escaped.push_str(&raw[last_pos..]); + Cow::Owned(escaped) } else { Cow::Borrowed(raw) } @@ -182,17 +183,14 @@ where match iter.next() { Some(end) if bytes[end] == b';' => { // append valid data - if unescaped.is_none() { - unescaped = Some(String::with_capacity(raw.len())); - } - let unescaped = unescaped.as_mut().expect("initialized"); + let unescaped = unescaped.get_or_insert_with(|| String::with_capacity(raw.len())); unescaped.push_str(&raw[last_end..start]); // search for character correctness let pat = &raw[start + 1..end]; if let Some(entity) = pat.strip_prefix('#') { let codepoint = parse_number(entity, start..end)?; - unescaped.push_str(codepoint.encode_utf8(&mut [0u8; 4])); + unescaped.push(codepoint); } else if let Some(value) = named_entity(pat) { unescaped.push_str(value); } else if let Some(value) = resolve_entity(pat) { From c6ec23afb42c1d347c15788d501b8668c279f2d8 Mon Sep 17 00:00:00 2001 From: Zachary Dremann Date: Thu, 12 Oct 2023 18:28:30 -0400 Subject: [PATCH 3/3] Add unit tests for MergeIter --- src/utils.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/utils.rs b/src/utils.rs index 9719f88a..cc13c21d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -36,6 +36,7 @@ pub fn write_byte_string(f: &mut Formatter, byte_string: &[u8]) -> fmt::Result { Ok(()) } +/// An iterator that merges two sorted iterators into one sorted iterator pub(crate) struct MergeIter where It1: Iterator, @@ -284,4 +285,36 @@ mod tests { ]); assert_eq!(format!("{:?}", bytes), r##""Class IRI=\"#B\"""##); } + + #[test] + fn merge_empty() { + let iter = MergeIter::new(vec![].into_iter(), vec![].into_iter()); + assert_eq!(iter.collect::>(), vec![]); + } + + #[test] + fn merge_single_empty() { + let iter = MergeIter::new(vec![1].into_iter(), vec![].into_iter()); + assert_eq!(iter.collect::>(), vec![1]); + let iter = MergeIter::new(vec![].into_iter(), vec![1].into_iter()); + assert_eq!(iter.collect::>(), vec![1]); + } + + #[test] + fn merge_whole_side_before() { + let iter = MergeIter::new(vec![1, 2, 3].into_iter(), vec![4, 5, 6].into_iter()); + assert_eq!(iter.collect::>(), vec![1, 2, 3, 4, 5, 6]); + } + + #[test] + fn merge_interleave() { + let iter = MergeIter::new( + vec![1, 2, 8, 20].into_iter(), + vec![3, 4, 22, 23].into_iter(), + ); + assert_eq!( + iter.collect::>(), + vec![1, 2, 3, 4, 8, 20, 22, 23] + ); + } }