From 4d29b04c1762f65b649e32e889fdaf582cfcd0c7 Mon Sep 17 00:00:00 2001 From: Ben Kimock Date: Wed, 3 Mar 2021 15:10:19 -0500 Subject: [PATCH] Reimplement Name using a pair of SmallVec, remove Index impls --- crates/client/src/rr/lower_name.rs | 11 +- crates/proto/src/rr/domain/name.rs | 219 ++++++++++++++++++----------- 2 files changed, 141 insertions(+), 89 deletions(-) diff --git a/crates/client/src/rr/lower_name.rs b/crates/client/src/rr/lower_name.rs index b0b98b3e59..a7bea3ccf1 100644 --- a/crates/client/src/rr/lower_name.rs +++ b/crates/client/src/rr/lower_name.rs @@ -11,14 +11,13 @@ use std::borrow::Borrow; use std::cmp::{Ordering, PartialEq}; use std::fmt; use std::hash::{Hash, Hasher}; -use std::ops::Index; use std::str::FromStr; use crate::proto::error::*; #[cfg(feature = "serde-config")] use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; -use crate::rr::{Label, Name}; +use crate::rr::Name; use crate::serialize::binary::*; /// them should be through references. As a workaround the Strings are all Rc as well as the array @@ -191,14 +190,6 @@ impl fmt::Display for LowerName { } } -impl Index for LowerName { - type Output = Label; - - fn index(&self, _index: usize) -> &Label { - &(self.0[_index]) - } -} - impl PartialOrd for LowerName { fn partial_cmp(&self, other: &LowerName) -> Option { Some(self.cmp(other)) diff --git a/crates/proto/src/rr/domain/name.rs b/crates/proto/src/rr/domain/name.rs index 280da4f29f..d51aad179e 100644 --- a/crates/proto/src/rr/domain/name.rs +++ b/crates/proto/src/rr/domain/name.rs @@ -7,14 +7,11 @@ //! domain name, aka labels, implementation -use std::borrow::Borrow; use std::char; use std::cmp::{Ordering, PartialEq}; use std::fmt::{self, Write}; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::ops::Index; -use std::slice::Iter; use std::str::FromStr; use crate::error::*; @@ -26,11 +23,12 @@ use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use smallvec::SmallVec; -/// Them should be through references. As a workaround the Strings are all Rc as well as the array +/// A domain name #[derive(Clone, Default, Debug, Eq)] pub struct Name { is_fqdn: bool, - labels: SmallVec<[Label; 4]>, + label_data: SmallVec<[u8; 32]>, + label_ends: SmallVec<[u8; 16]>, } impl Name { @@ -59,7 +57,7 @@ impl Name { /// assert_eq!(&root.to_string(), "."); /// ``` pub fn is_root(&self) -> bool { - self.labels.is_empty() && self.is_fqdn() + self.label_ends.is_empty() && self.is_fqdn() } /// Returns true if the name is a fully qualified domain name. @@ -97,7 +95,12 @@ impl Name { /// Returns an iterator over the labels pub fn iter(&self) -> LabelIter<'_> { - LabelIter(self.labels.iter()) + LabelIter { + name: self, + index: 0, + started: false, + finished: false, + } } /// Appends the label to the end of this name @@ -113,8 +116,10 @@ impl Name { /// assert_eq!(name, Name::from_str("www.example.com").unwrap()); /// ``` pub fn append_label(mut self, label: L) -> ProtoResult { - self.labels.push(label.into_label()?); - if self.labels.len() > 255 { + self.label_data + .extend_from_slice(label.into_label()?.as_bytes()); + self.label_ends.push(self.label_data.len() as u8); + if self.len() > 255 { return Err("labels exceed maximum length of 255".into()); }; Ok(self) @@ -138,7 +143,7 @@ impl Name { /// /// // Force a set of bytes into labels (this is none-standard and potentially dangerous) /// let from_labels = Name::from_labels(vec!["bad chars".as_bytes(), "example".as_bytes(), "com".as_bytes()]).unwrap(); - /// assert_eq!(from_labels[0].as_bytes(), "bad chars".as_bytes()); + /// assert_eq!(from_labels.iter().next(), Some(&b"bad chars"[..])); /// /// let root = Name::from_labels(Vec::<&str>::new()).unwrap(); /// assert!(root.is_root()); @@ -152,7 +157,7 @@ impl Name { .into_iter() .map(IntoLabel::into_label) .partition(Result::is_ok); - let labels: SmallVec<_> = labels.into_iter().map(Result::unwrap).collect(); + let labels: Vec<_> = labels.into_iter().map(Result::unwrap).collect(); let errors: Vec<_> = errors.into_iter().map(Result::unwrap_err).collect(); if labels.len() > 255 { @@ -162,9 +167,17 @@ impl Name { return Err(format!("error converting some labels: {:?}", errors).into()); }; + let mut label_ends = SmallVec::new(); + let mut label_data = SmallVec::new(); + for label in labels { + label_data.extend_from_slice(label.as_bytes()); + label_ends.push(label_data.len() as u8); + } + Ok(Name { is_fqdn: true, - labels, + label_data, + label_ends, }) } @@ -194,9 +207,9 @@ impl Name { /// assert!(name.is_fqdn()); /// ``` pub fn append_name(mut self, other: &Self) -> Self { - self.labels.reserve_exact(other.labels.len()); - for label in &other.labels { - self.labels.push(label.clone()); + for label in other.iter() { + self.label_data.extend_from_slice(label); + self.label_ends.push(self.label_data.len() as u8); } self.is_fqdn = other.is_fqdn; @@ -241,14 +254,15 @@ impl Name { /// assert!(example_com.to_lowercase().eq_case(&Name::from_str("example.com").unwrap())); /// ``` pub fn to_lowercase(&self) -> Self { - let mut new_labels = SmallVec::with_capacity(self.labels.len()); - for label in &self.labels { - new_labels.push(label.to_lowercase()) - } - + let new_label_data = self + .label_data + .iter() + .map(|c| c.to_ascii_lowercase()) + .collect(); Name { is_fqdn: self.is_fqdn, - labels: new_labels, + label_data: new_label_data, + label_ends: self.label_ends.clone(), } } @@ -266,7 +280,7 @@ impl Name { /// assert_eq!(Name::root().base_name(), Name::root()); /// ``` pub fn base_name(&self) -> Name { - let length = self.labels.len(); + let length = self.label_ends.len(); if length > 0 { return self.trim_to(length - 1); } @@ -288,21 +302,17 @@ impl Name { /// assert_eq!(example_com.trim_to(3), Name::from_str("example.com.").unwrap()); /// ``` pub fn trim_to(&self, num_labels: usize) -> Name { - if self.labels.len() >= num_labels { - let trim = self.labels.len() - num_labels; - Name { - is_fqdn: self.is_fqdn, - labels: SmallVec::from(&self.labels[trim..]), - } - } else { + if num_labels > self.label_ends.len() { self.clone() + } else { + Name::from_labels(self.iter().skip(self.label_ends.len() - num_labels)).unwrap() } } /// same as `zone_of` allows for case sensitive call pub fn zone_of_case(&self, name: &Self) -> bool { - let self_len = self.labels.len(); - let name_len = name.labels.len(); + let self_len = self.label_ends.len(); + let name_len = name.label_ends.len(); if self_len == 0 { return true; } @@ -371,11 +381,11 @@ impl Name { pub fn num_labels(&self) -> u8 { // it is illegal to have more than 256 labels. - let num = self.labels.len() as u8; + let num = self.label_ends.len() as u8; - self.labels - .first() - .map(|l| if l.is_wildcard() { num - 1 } else { num }) + self.iter() + .next() + .map(|l| if l == b"*" { num - 1 } else { num }) .unwrap_or(num) } @@ -395,12 +405,12 @@ impl Name { /// assert_eq!(Name::root().len(), 1); /// ``` pub fn len(&self) -> usize { - let dots = if !self.labels.is_empty() { - self.labels.len() + let dots = if !self.label_ends.is_empty() { + self.label_ends.len() } else { 1 }; - self.labels.iter().fold(dots, |acc, item| acc + item.len()) + dots + self.label_data.len() } /// Returns whether the length of the labels, in bytes is 0. In practice, since '.' counts as @@ -419,7 +429,7 @@ impl Name { /// /// let name = Name::from_str("example.com.").unwrap(); /// assert_eq!(name.base_name(), Name::from_str("com.").unwrap()); - /// assert_eq!(name[0].to_string(), "example"); + /// assert_eq!(name.iter().next(), Some(&b"example"[..])); /// ``` pub fn parse(local: &str, origin: Option<&Self>) -> ProtoResult { Self::from_encoded_str::(local, origin) @@ -521,7 +531,7 @@ impl Name { match state { ParseState::Label => match ch { '.' => { - name.labels.push(E::to_label(&label)?); + name = name.append_label(E::to_label(&label)?)?; label.clear(); } '\\' => state = ParseState::Escape1, @@ -570,7 +580,7 @@ impl Name { } if !label.is_empty() { - name.labels.push(E::to_label(&label)?); + name = name.append_label(E::to_label(&label)?)?; } if local.ends_with('.') { @@ -593,10 +603,10 @@ impl Name { let buf_len = encoder.len(); // lazily assert the size is less than 255... // lookup the label in the BinEncoder // if it exists, write the Pointer - let labels: &[Label] = &self.labels; + let labels = self.iter(); // start index of each label - let mut labels_written: Vec = Vec::with_capacity(labels.len()); + let mut labels_written: Vec = Vec::with_capacity(self.label_ends.len()); if canonical { for label in labels { @@ -676,22 +686,24 @@ impl Name { /// compares with the other label, ignoring case fn cmp_with_f(&self, other: &Self) -> Ordering { - if self.labels.is_empty() && other.labels.is_empty() { + if self.label_ends.is_empty() && other.label_ends.is_empty() { return Ordering::Equal; } // we reverse the iters so that we are comparing from the root/domain to the local... - let self_labels = self.labels.iter().rev(); - let other_labels = other.labels.iter().rev(); + let self_labels = self.iter().rev(); + let other_labels = other.iter().rev(); for (l, r) in self_labels.zip(other_labels) { - match l.cmp_with_f::(r) { + let l = Label::from_raw_bytes(l).unwrap(); + let r = Label::from_raw_bytes(r).unwrap(); + match l.cmp_with_f::(&r) { Ordering::Equal => continue, not_eq => return not_eq, } } - self.labels.len().cmp(&other.labels.len()) + self.label_ends.len().cmp(&other.label_ends.len()) } /// Case sensitive comparison @@ -785,14 +797,14 @@ impl Name { } fn write_labels(&self, f: &mut W) -> Result<(), fmt::Error> { - let mut iter = self.labels.iter(); + let mut iter = self.iter().map(|b| Label::from_raw_bytes(b).unwrap()); if let Some(label) = iter.next() { - E::write_label(f, label)?; + E::write_label(f, &label)?; } for label in iter { write!(f, ".")?; - E::write_label(f, label)?; + E::write_label(f, &label)?; } // if it was the root name @@ -841,7 +853,7 @@ impl Name { /// assert!(!name.is_wildcard()); /// ``` pub fn is_wildcard(&self) -> bool { - self.labels.first().map_or(false, Label::is_wildcard) + self.iter().next().map_or(false, |l| l == b"*") } /// Converts a name to a wildcard, by replacing the first label with `*` @@ -859,13 +871,22 @@ impl Name { /// let name = Name::root().into_wildcard(); /// assert_eq!(name, Name::root()); /// ``` - pub fn into_wildcard(mut self) -> Self { - let wildcard = Label::wildcard(); - if let Some(first) = self.labels.first_mut() { - *first = wildcard; + pub fn into_wildcard(self) -> Self { + if self.label_ends.is_empty() { + Name::root() + } else { + let mut label_data = smallvec::smallvec![b'*']; + let mut label_ends = smallvec::smallvec![1]; + for label in self.iter().skip(1) { + label_data.extend_from_slice(label); + label_ends.push(label_data.len() as u8); + } + Name { + label_data, + label_ends, + is_fqdn: self.is_fqdn, + } } - - self } } @@ -900,20 +921,63 @@ impl LabelEnc for LabelEncUtf8 { } /// An iterator over labels in a name -pub struct LabelIter<'a>(Iter<'a, Label>); +pub struct LabelIter<'a> { + name: &'a Name, + index: usize, + started: bool, + finished: bool, +} impl<'a> Iterator for LabelIter<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { - self.0.next().map(Borrow::borrow) + if self.finished { + return None; + } + self.started = true; + let end = *self.name.label_ends.get(self.index)?; + let start = if self.index == 0 { + 0 + } else { + self.name.label_ends[self.index - 1] + }; + self.index += 1; + if self.index == self.name.label_ends.len() { + self.finished = true; + } + Some(&self.name.label_data[start as usize..end as usize]) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.name.label_ends.len() - self.index; + (len, Some(len)) } } impl<'a> ExactSizeIterator for LabelIter<'a> {} + impl<'a> DoubleEndedIterator for LabelIter<'a> { fn next_back(&mut self) -> Option { - self.0.next_back().map(Borrow::borrow) + if self.finished { + return None; + } + if !self.started { + self.index = self.name.label_ends.len().checked_sub(1)?; + } + self.started = true; + let end = *self.name.label_ends.get(self.index)?; + let start = if self.index == 0 { + 0 + } else { + self.name.label_ends[self.index - 1] + }; + if self.index == 0 { + self.finished = true; + } else { + self.index -= 1; + } + Some(&self.name.label_data[start as usize..end as usize]) } } @@ -1017,7 +1081,10 @@ impl Hash for Name { self.is_fqdn.hash(state); // this needs to be CaseInsensitive like PartialEq - for l in self.labels.iter().map(Label::to_lowercase) { + for l in self + .iter() + .map(|l| Label::from_raw_bytes(l).unwrap().to_lowercase()) + { l.hash(state); } } @@ -1043,23 +1110,25 @@ impl<'r> BinDecodable<'r> for Name { /// all names will be stored lowercase internally. /// This will consume the portions of the `Vec` which it is reading... fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult { - let mut labels = SmallVec::new(); - read_inner(decoder, &mut labels, None)?; + let mut label_data = SmallVec::new(); + let mut label_ends = SmallVec::new(); + read_inner(decoder, &mut label_data, &mut label_ends, None)?; Ok(Name { is_fqdn: true, - labels, + label_data, + label_ends, }) } } fn read_inner( decoder: &mut BinDecoder<'_>, - labels: &mut SmallVec<[Label; 4]>, + label_data: &mut SmallVec<[u8; 32]>, + label_ends: &mut SmallVec<[u8; 16]>, max_idx: Option, ) -> ProtoResult<()> { let mut state: LabelParseState = LabelParseState::LabelLengthOrPointer; let name_start = decoder.index(); - let mut run_len: usize = labels.iter().map(Label::len).sum(); // assume all chars are utf-8. We're doing byte-by-byte operations, no endianess issues... // reserved: (1000 0000 aka 0800) && (0100 0000 aka 0400) @@ -1079,7 +1148,7 @@ fn read_inner( } // enforce max length of name - let cur_len = run_len + labels.len(); + let cur_len = label_data.len() + label_ends.len(); if cur_len > 255 { return Err(ProtoErrorKind::DomainNameTooLong(cur_len).into()); } @@ -1104,8 +1173,8 @@ fn read_inner( .verify_unwrap(|l| l.len() <= 63) .map_err(|_| ProtoError::from("label exceeds maximum length of 63"))?; - run_len += label.len(); - labels.push(Label::from_raw_bytes(label).unwrap()); + label_data.extend_from_slice(label); + label_ends.push(label_data.len() as u8); // reset to collect more data LabelParseState::LabelLengthOrPointer @@ -1151,7 +1220,7 @@ fn read_inner( })?; let mut pointer = decoder.clone(location); - read_inner(&mut pointer, labels, Some(name_start))?; + read_inner(&mut pointer, label_data, label_ends, Some(name_start))?; // Pointers always finish the name, break like Root. break; @@ -1173,14 +1242,6 @@ impl fmt::Display for Name { } } -impl Index for Name { - type Output = Label; - - fn index(&self, _index: usize) -> &Label { - &self.labels[_index] - } -} - impl PartialOrd for Name { fn partial_cmp(&self, other: &Name) -> Option { Some(self.cmp(other))