Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve EDNS OPT record in a response if truncation occurs #1364

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/proto/src/op/edns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ impl Edns {
pub fn set_option(&mut self, option: EdnsOption) {
self.options.insert(option);
}

pub(crate) fn len(&self) -> u16 {
1 // Name::root
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should consider adding some consts for these things, or perhaps instead use std::mem::size_of It would make the code a bit more self-documenting. Thoughts?

I know we have a similar practice elsewhere in the code, maybe I'll file a cleanup issue for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, could we use EncodedSize for this?
Let's do this in another PR anyway

+ 2 // RecordType::OPT
+ 2 // DNSClass
+ 4 // TTL (rcode_high, version and flags)
+ self.options.options().iter().map(|o| {
2 // Option code
+ 2 // Len field
+ o.1.len()
}).sum::<u16>()
}
}

impl<'a> From<&'a Record> for Edns {
Expand Down Expand Up @@ -208,6 +220,7 @@ impl BinEncodable for Edns {
opt::emit(encoder, &self.options)?;
let len = encoder.len_since_place(&place);
assert!(len <= u16::max_value() as usize);
debug_assert_eq!(len, self.len() as usize);

place.replace(encoder, len as u16)?;
Ok(())
Expand Down
24 changes: 20 additions & 4 deletions crates/proto/src/op/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

//! Basic protocol message for DNS

use std::iter;
use std::mem;
use std::ops::Deref;
use std::sync::Arc;
Expand Down Expand Up @@ -769,6 +768,23 @@ where
N: EmitAndCount,
D: EmitAndCount,
{
let real_encoder_max_size = encoder.max_size();
if let Some(edns) = edns {
// From RFC 6891 section-7:
// The minimal response MUST be the DNS header, question section, and an
// OPT record. This MUST also occur when a truncated response (using
// the DNS header's TC bit) is returned.
// Hence, we reserve some space for the EDNS OPT record
encoder.set_max_size(match real_encoder_max_size.checked_sub(edns.len()) {
Some(size) => size,
None => {
return Err(
ProtoErrorKind::MaxBufferSizeExceeded(real_encoder_max_size as usize).into(),
)
}
});
}

let include_sig0: bool = encoder.mode() != EncodeMode::Signing;
let place = encoder.place::<Header>()?;

Expand All @@ -780,10 +796,10 @@ where
let mut additional_count = count_was_truncated(additionals.emit(encoder))?;

if let Some(edns) = edns {
encoder.set_max_size(real_encoder_max_size);
// need to commit the error code
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this comment mean btw?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is referring to the fact that the response error code may need to be committed to the EDNS record... i.e. in EDNS the response (error code) is has high-order bits stored in the EDNS record that augment those in the original DNS Header. So I think that's what this may be talking about...

let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(edns))))?;
additional_count.0 += count.0;
additional_count.1 |= count.1;
Record::from(edns).emit(encoder)?;
additional_count.0 += 1;
}

// this is a little hacky, but if we are Verifying a signature, i.e. the original Message
Expand Down
19 changes: 17 additions & 2 deletions crates/proto/src/serialize/binary/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::op::Header;
// this is private to make sure there is no accidental access to the inner buffer.
mod private {
use crate::error::{ProtoErrorKind, ProtoResult};
use std::convert::TryFrom;

/// A wrapper for a buffer that guarantees writes never exceed a defined set of bytes
pub struct MaximalBuf<'a> {
Expand All @@ -43,6 +44,11 @@ mod private {
self.max_size = max as usize;
}

/// Returns the enforced maximum size
pub fn max_size(&mut self) -> u16 {
u16::try_from(self.max_size).expect("max_size is known to fit into u16")
}

/// returns an error if the maximum buffer size would be exceeded with the addition number of elements
///
/// and reserves the additional space in the buffer
Expand Down Expand Up @@ -145,6 +151,11 @@ impl<'a> BinEncoder<'a> {
self.buffer.set_max_size(max);
}

/// Returns the enforced maximum size of the buffer
pub(crate) fn max_size(&mut self) -> u16 {
self.buffer.max_size()
}

/// Returns a reference to the internal buffer
pub fn into_bytes(self) -> &'a Vec<u8> {
self.buffer.into_bytes()
Expand Down Expand Up @@ -400,8 +411,12 @@ impl<'a> BinEncoder<'a> {
let len = T::size_of();

// resize the buffer
self.buffer
.enforced_write(len, |buffer| buffer.resize(index + len, 0))?;
match (len + index).checked_sub(self.buffer.len()) {
None => {}
Some(resize_to) => self
.buffer
.enforced_write(resize_to, |buffer| buffer.resize(index + len, 0))?,
}

// update the offset
self.offset += len;
Expand Down
75 changes: 65 additions & 10 deletions crates/server/src/authority/message_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,7 @@ mod tests {
let mut encoder = BinEncoder::new(&mut buf);
encoder.set_max_size(512);

let answer = Record::new()
.set_name(Name::from_str("www.example.com.").unwrap())
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34)))
.set_dns_class(DNSClass::NONE)
.clone();
let answer = make_example_record();

let message = MessageResponse {
header: Header::new(),
Expand Down Expand Up @@ -259,11 +255,7 @@ mod tests {
let mut encoder = BinEncoder::new(&mut buf);
encoder.set_max_size(512);

let answer = Record::new()
.set_name(Name::from_str("www.example.com.").unwrap())
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34)))
.set_dns_class(DNSClass::NONE)
.clone();
let answer = make_example_record();

let message = MessageResponse {
header: Header::new(),
Expand All @@ -286,4 +278,67 @@ mod tests {
assert_eq!(response.answer_count(), 0);
assert!(response.name_server_count() > 1);
}

#[test]
fn test_edns_persists_in_truncated_message() {
let mut buf = Vec::with_capacity(512);
{
let mut encoder = BinEncoder::new(&mut buf);
encoder.set_max_size(512);

let answer = make_example_record();

let message = MessageResponse {
header: Header::new(),
queries: None,
answers: iter::repeat(&answer),
name_servers: iter::once(&answer),
soa: iter::once(&answer),
additionals: iter::once(&answer),
sig0: vec![],
edns: Some(Edns::new()),
};

message
.destructive_emit(&mut encoder)
.expect("failed to encode");
}

let response = Message::from_vec(&buf).expect("failed to decode");
assert!(response.header().truncated());
assert!(response.answer_count() > 1);
// should never have written the name server field...
assert_eq!(response.name_server_count(), 0);
assert!(response.edns().is_some());
}

#[test]
fn test_too_low_max_size() {
let mut buf = Vec::with_capacity(512);
let mut encoder = BinEncoder::new(&mut buf);
encoder.set_max_size(2);

let answer = make_example_record();

let message = MessageResponse {
header: Header::new(),
queries: None,
answers: iter::repeat(&answer),
name_servers: iter::once(&answer),
soa: iter::once(&answer),
additionals: iter::once(&answer),
sig0: vec![],
edns: Some(Edns::new()),
};

assert!(message.destructive_emit(&mut encoder).is_err())
}

fn make_example_record() -> Record {
Record::new()
.set_name(Name::from_str("www.example.com.").unwrap())
.set_rdata(RData::A(Ipv4Addr::new(93, 184, 216, 34)))
.set_dns_class(DNSClass::NONE)
.clone()
}
}