diff --git a/crates/proto/src/error.rs b/crates/proto/src/error.rs index b9fbb02ec3..9866188967 100644 --- a/crates/proto/src/error.rs +++ b/crates/proto/src/error.rs @@ -26,6 +26,7 @@ use ring::error::Unspecified; use thiserror::Error; use crate::rr::{Name, RecordType}; +use crate::serialize::binary::DecodeError; #[cfg(feature = "backtrace")] lazy_static! { @@ -269,6 +270,12 @@ impl From for ProtoError { } } +impl From for ProtoError { + fn from(err: DecodeError) -> ProtoError { + ProtoErrorKind::Msg(err.to_string()).into() + } +} + impl From<&'static str> for ProtoError { fn from(msg: &'static str) -> ProtoError { ProtoErrorKind::Message(msg).into() diff --git a/crates/proto/src/rr/domain/name.rs b/crates/proto/src/rr/domain/name.rs index 693411af5f..a892b0bcc4 100644 --- a/crates/proto/src/rr/domain/name.rs +++ b/crates/proto/src/rr/domain/name.rs @@ -1128,7 +1128,7 @@ fn read_inner( label_data: &mut TinyVec<[u8; 32]>, label_ends: &mut TinyVec<[u8; 16]>, max_idx: Option, -) -> ProtoResult<()> { +) -> Result<(), DecodeError> { let mut state: LabelParseState = LabelParseState::LabelLengthOrPointer; let name_start = decoder.index(); @@ -1141,18 +1141,14 @@ fn read_inner( // this protects against overlapping labels if let Some(max_idx) = max_idx { if decoder.index() >= max_idx { - return Err(ProtoErrorKind::LabelOverlapsWithOther { - label: name_start, - other: max_idx, - } - .into()); + return Err(DecodeError::LabelOverlapsWithOther); } } // enforce max length of name let cur_len = label_data.len() + label_ends.len(); if cur_len > 255 { - return Err(ProtoErrorKind::DomainNameTooLong(cur_len).into()); + return Err(DecodeError::DomainNameTooLong); } state = match state { @@ -1165,15 +1161,15 @@ fn read_inner( Some(0) | None => LabelParseState::Root, Some(byte) if byte & 0b1100_0000 == 0b1100_0000 => LabelParseState::Pointer, Some(byte) if byte & 0b1100_0000 == 0b0000_0000 => LabelParseState::Label, - Some(byte) => return Err(ProtoErrorKind::UnrecognizedLabelCode(byte).into()), + Some(_) => return Err(DecodeError::UnrecognizedLabelCode), } } // labels must have a maximum length of 63 LabelParseState::Label => { let label = decoder - .read_character_data_max(Some(63))? + .read_character_data()? .verify_unwrap(|l| l.len() <= 63) - .map_err(|_| ProtoError::from("label exceeds maximum length of 63"))?; + .map_err(|_| DecodeError::LabelTooLong)?; label_data.extend_from_slice(label); label_ends.push(label_data.len() as u8); @@ -1203,7 +1199,6 @@ fn read_inner( // domain header). A zero offset specifies the first byte of the ID field, // etc. LabelParseState::Pointer => { - let pointer_location = decoder.index(); let location = decoder .read_u16()? .map(|u| { @@ -1214,12 +1209,7 @@ fn read_inner( // all labels must appear "prior" to this Name (*ptr as usize) < name_start }) - .map_err(|e| { - ProtoError::from(ProtoErrorKind::PointerNotPriorToLabel { - idx: pointer_location, - ptr: e, - }) - })?; + .map_err(|_| DecodeError::PointerNotPriorToLabel)?; let mut pointer = decoder.clone(location); read_inner(&mut pointer, label_data, label_ends, Some(name_start))?; diff --git a/crates/proto/src/rr/record_type.rs b/crates/proto/src/rr/record_type.rs index c041e8d78e..fbf106bcc9 100644 --- a/crates/proto/src/rr/record_type.rs +++ b/crates/proto/src/rr/record_type.rs @@ -249,12 +249,12 @@ impl BinEncodable for RecordType { impl<'r> BinDecodable<'r> for RecordType { fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult { - decoder + Ok(decoder .read_u16() .map( Restrict::unverified, /*RecordType is safe with any u16*/ ) - .map(Self::from) + .map(Self::from)?) } } diff --git a/crates/proto/src/serialize/binary/bin_tests.rs b/crates/proto/src/serialize/binary/bin_tests.rs index 31eefb4cbf..a32f6b3eb9 100644 --- a/crates/proto/src/serialize/binary/bin_tests.rs +++ b/crates/proto/src/serialize/binary/bin_tests.rs @@ -58,7 +58,7 @@ fn get_u16_data() -> Vec<(u16, Vec)> { #[test] fn read_u16() { test_read_data_set(get_u16_data(), |mut d| { - d.read_u16().map(Restrict::unverified) + d.read_u16().map(Restrict::unverified).map_err(Into::into) }); } @@ -83,7 +83,7 @@ fn get_i32_data() -> Vec<(i32, Vec)> { #[test] fn read_i32() { test_read_data_set(get_i32_data(), |mut d| { - d.read_i32().map(Restrict::unverified) + d.read_i32().map(Restrict::unverified).map_err(Into::into) }); } @@ -109,7 +109,7 @@ fn get_u32_data() -> Vec<(u32, Vec)> { #[test] fn read_u32() { test_read_data_set(get_u32_data(), |mut d| { - d.read_u32().map(Restrict::unverified) + d.read_u32().map(Restrict::unverified).map_err(Into::into) }); } diff --git a/crates/proto/src/serialize/binary/decoder.rs b/crates/proto/src/serialize/binary/decoder.rs index 6ae6373fbc..aae06dab74 100644 --- a/crates/proto/src/serialize/binary/decoder.rs +++ b/crates/proto/src/serialize/binary/decoder.rs @@ -14,8 +14,8 @@ * limitations under the License. */ -use crate::error::{ProtoError, ProtoErrorKind, ProtoResult}; use crate::serialize::binary::Restrict; +use thiserror::Error; /// This is non-destructive to the inner buffer, b/c for pointer types we need to perform a reverse /// seek to lookup names @@ -26,7 +26,42 @@ use crate::serialize::binary::Restrict; /// binary DNS protocols. pub struct BinDecoder<'a> { buffer: &'a [u8], - index: usize, + remaining: &'a [u8], +} + +pub type DecodeResult = Result; + +/// And error that can occur deep in a decoder +/// This type is kept very small so that function that use it inline often +#[derive(Debug, Error)] +pub enum DecodeError { + /// Insufficient data in the buffer for a read operation + #[error("unexpected end of input reached")] + InsufficientBytes, + + /// slice_from was called with an invalid index + #[error("index antecedes upper bound")] + InvalidPreviousIndex, + + /// Pointer points to an index within or after the current label + #[error("pointer does not point to an index before the current label")] + PointerNotPriorToLabel, + + /// Unreachable + #[error("label exceeds length of 63")] + LabelTooLong, + + /// Invalid code for a label + #[error("the start of a label must be a pointer or a length less than 64")] + UnrecognizedLabelCode, + + /// Domain names cannot exceed 255 + #[error("domain name exceeds maximum length")] + DomainNameTooLong, + + /// Labels may not overlap + #[error("label overlaps with other")] + LabelOverlapsWithOther, } impl<'a> BinDecoder<'a> { @@ -35,18 +70,23 @@ impl<'a> BinDecoder<'a> { /// # Arguments /// /// * `buffer` - buffer from which all data will be read + #[inline] pub fn new(buffer: &'a [u8]) -> Self { - BinDecoder { buffer, index: 0 } + BinDecoder { + buffer, + remaining: buffer, + } } /// Pop one byte from the buffer - pub fn pop(&mut self) -> ProtoResult> { - if self.index < self.buffer.len() { - let byte = self.buffer[self.index]; - self.index += 1; + pub fn pop(&mut self) -> DecodeResult> { + if self.remaining.len() > 0 { + let (byte, remaining) = self.remaining.split_at(1); + self.remaining = remaining; + let byte = byte[0]; Ok(Restrict::new(byte)) } else { - Err("unexpected end of input reached".into()) + Err(DecodeError::InsufficientBytes) } } @@ -62,7 +102,7 @@ impl<'a> BinDecoder<'a> { /// assert_eq!(decoder.len(), 1); /// ``` pub fn len(&self) -> usize { - self.buffer.len().saturating_sub(self.index) + self.remaining.len() } /// Returns `true` if the buffer is empty @@ -72,16 +112,12 @@ impl<'a> BinDecoder<'a> { /// Peed one byte forward, without moving the current index forward pub fn peek(&self) -> Option> { - if self.index < self.buffer.len() { - Some(Restrict::new(self.buffer[self.index])) - } else { - None - } + Some(Restrict::new(*self.remaining.get(0)?)) } /// Returns the current index in the buffer pub fn index(&self) -> usize { - self.index + self.buffer.len() - self.remaining.len() } /// This is a pretty efficient clone, as the buffer is never cloned, and only the index is set @@ -89,7 +125,7 @@ impl<'a> BinDecoder<'a> { pub fn clone(&self, index_at: u16) -> BinDecoder<'a> { BinDecoder { buffer: self.buffer, - index: index_at as usize, + remaining: &self.buffer[index_at as usize..], } } @@ -105,32 +141,8 @@ impl<'a> BinDecoder<'a> { /// # Returns /// /// A String version of the character data - pub fn read_character_data(&mut self) -> ProtoResult> { - self.read_character_data_max(None) - } - - /// Reads to a maximum length of data, returns an error if this is exceeded - pub fn read_character_data_max( - &mut self, - max_len: Option, - ) -> ProtoResult> { - let length = self - .pop()? - .map(|u| u as usize) - .verify_unwrap(|length| { - if let Some(max_len) = max_len { - *length <= max_len - } else { - true - } - }) - .map_err(|length| { - ProtoError::from(ProtoErrorKind::CharacterDataTooLong { - max: max_len.unwrap_or_default(), - len: length, - }) - })?; - + pub fn read_character_data(&mut self) -> DecodeResult> { + let length = self.pop()?.unverified() as usize; self.read_slice(length) } @@ -143,7 +155,7 @@ impl<'a> BinDecoder<'a> { /// # Returns /// /// The Vec of the specified length, otherwise an error - pub fn read_vec(&mut self, len: usize) -> ProtoResult>> { + pub fn read_vec(&mut self, len: usize) -> DecodeResult>> { self.read_slice(len).map(|s| s.map(ToOwned::to_owned)) } @@ -156,30 +168,26 @@ impl<'a> BinDecoder<'a> { /// # Returns /// /// The slice of the specified length, otherwise an error - pub fn read_slice(&mut self, len: usize) -> ProtoResult> { - let end = self - .index - .checked_add(len) - .ok_or_else(|| ProtoError::from("invalid length for slice"))?; - if end > self.buffer.len() { - return Err("buffer exhausted".into()); + pub fn read_slice(&mut self, len: usize) -> DecodeResult> { + if len > self.remaining.len() { + return Err(DecodeError::InsufficientBytes); } - let slice: &'a [u8] = &self.buffer[self.index..end]; - self.index += len; - Ok(Restrict::new(slice)) + let (read, remaining) = self.remaining.split_at(len); + self.remaining = remaining; + Ok(Restrict::new(read)) } /// Reads a slice from a previous index to the current - pub fn slice_from(&self, index: usize) -> ProtoResult<&'a [u8]> { - if index > self.index { - return Err("index antecedes upper bound".into()); + pub fn slice_from(&self, index: usize) -> DecodeResult<&'a [u8]> { + if index > self.index() { + return Err(DecodeError::InvalidPreviousIndex); } - Ok(&self.buffer[index..self.index]) + Ok(&self.buffer[index..self.index()]) } /// Reads a byte from the buffer, equivalent to `Self::pop()` - pub fn read_u8(&mut self) -> ProtoResult> { + pub fn read_u8(&mut self) -> DecodeResult> { self.pop() } @@ -191,7 +199,7 @@ impl<'a> BinDecoder<'a> { /// # Return /// /// Return the u16 from the buffer - pub fn read_u16(&mut self) -> ProtoResult> { + pub fn read_u16(&mut self) -> DecodeResult> { Ok(self .read_slice(2)? .map(|s| u16::from_be_bytes([s[0], s[1]]))) @@ -205,10 +213,11 @@ impl<'a> BinDecoder<'a> { /// # Return /// /// Return the i32 from the buffer - pub fn read_i32(&mut self) -> ProtoResult> { - Ok(self - .read_slice(4)? - .map(|s| i32::from_be_bytes([s[0], s[1], s[2], s[3]]))) + pub fn read_i32(&mut self) -> DecodeResult> { + Ok(self.read_slice(4)?.map(|s| { + assert!(s.len() == 4); + i32::from_be_bytes([s[0], s[1], s[2], s[3]]) + })) } /// Reads the next four bytes into u32. @@ -219,10 +228,11 @@ impl<'a> BinDecoder<'a> { /// # Return /// /// Return the u32 from the buffer - pub fn read_u32(&mut self) -> ProtoResult> { - Ok(self - .read_slice(4)? - .map(|s| u32::from_be_bytes([s[0], s[1], s[2], s[3]]))) + pub fn read_u32(&mut self) -> DecodeResult> { + Ok(self.read_slice(4)?.map(|s| { + assert!(s.len() == 4); + u32::from_be_bytes([s[0], s[1], s[2], s[3]]) + })) } } diff --git a/crates/proto/src/serialize/binary/mod.rs b/crates/proto/src/serialize/binary/mod.rs index 05223dc8c8..1253c87b07 100644 --- a/crates/proto/src/serialize/binary/mod.rs +++ b/crates/proto/src/serialize/binary/mod.rs @@ -20,7 +20,7 @@ mod decoder; mod encoder; mod restrict; -pub use self::decoder::BinDecoder; +pub use self::decoder::{BinDecoder, DecodeError}; pub use self::encoder::BinEncoder; pub use self::encoder::EncodeMode; pub use self::restrict::{Restrict, RestrictedMath, Verified}; @@ -67,7 +67,10 @@ impl BinEncodable for u16 { impl<'r> BinDecodable<'r> for u16 { fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult { - decoder.read_u16().map(Restrict::unverified) + decoder + .read_u16() + .map(Restrict::unverified) + .map_err(Into::into) } } @@ -79,7 +82,10 @@ impl BinEncodable for i32 { impl<'r> BinDecodable<'r> for i32 { fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult { - decoder.read_i32().map(Restrict::unverified) + decoder + .read_i32() + .map(Restrict::unverified) + .map_err(Into::into) } } @@ -91,7 +97,10 @@ impl BinEncodable for u32 { impl<'r> BinDecodable<'r> for u32 { fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult { - decoder.read_u32().map(Restrict::unverified) + decoder + .read_u32() + .map(Restrict::unverified) + .map_err(Into::into) } }