Skip to content

Commit

Permalink
proto: add into_parts methods (#1397)
Browse files Browse the repository at this point in the history
* proto: add into_parts methods

* proto: deprecate options() for OPT

* proto: use XParts types in public API

* proto: fix cargo destructure

* proto: use cfg_if

* proto: fix lint"
"

* proto: attempt to satisfy CI

* add clone to digest
  • Loading branch information
leshow committed Mar 8, 2021
1 parent 61122bc commit a44d441
Show file tree
Hide file tree
Showing 16 changed files with 236 additions and 16 deletions.
6 changes: 3 additions & 3 deletions crates/client/src/error/dnssec_error.rs
Expand Up @@ -161,7 +161,7 @@ impl From<SslErrorStack> for Error {
pub mod not_openssl {
use std;

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub struct SslErrorStack;

impl std::fmt::Display for SslErrorStack {
Expand All @@ -181,10 +181,10 @@ pub mod not_openssl {
pub mod not_ring {
use std;

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub struct KeyRejected;

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub struct Unspecified;

impl std::fmt::Display for KeyRejected {
Expand Down
4 changes: 2 additions & 2 deletions crates/proto/src/error.rs
Expand Up @@ -332,7 +332,7 @@ pub mod not_openssl {
use std;

/// SslErrorStac stub
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub struct SslErrorStack;

impl std::fmt::Display for SslErrorStack {
Expand All @@ -354,7 +354,7 @@ pub mod not_ring {
use std;

/// The Unspecified error replacement
#[derive(Debug)]
#[derive(Debug, Copy, Clone)]
pub struct Unspecified;

impl std::fmt::Display for Unspecified {
Expand Down
2 changes: 2 additions & 0 deletions crates/proto/src/lib.rs
Expand Up @@ -8,6 +8,7 @@

#![warn(
missing_docs,
missing_copy_implementations,
clippy::dbg_macro,
clippy::print_stdout,
clippy::unimplemented
Expand Down Expand Up @@ -194,6 +195,7 @@ pub trait Time {

/// New type which is implemented using tokio::time::{Delay, Timeout}
#[cfg(any(test, feature = "tokio-runtime"))]
#[derive(Debug, Clone, Copy)]
pub struct TokioTime;

#[cfg(any(test, feature = "tokio-runtime"))]
Expand Down
2 changes: 1 addition & 1 deletion crates/proto/src/op/header.rs
Expand Up @@ -55,7 +55,7 @@ use crate::serialize::binary::*;
///
/// ```
///
#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Header {
id: u16,
message_type: MessageType,
Expand Down
57 changes: 56 additions & 1 deletion crates/proto/src/op/message.rs
Expand Up @@ -95,7 +95,8 @@ pub fn update_header_counts(
assert!(counts.nameserver_count <= u16::max_value() as usize);
assert!(counts.additional_count <= u16::max_value() as usize);

let mut header = current_header.clone();
// TODO: should the function just take by value?
let mut header = *current_header;
header
.set_query_count(counts.query_count as u16)
.set_answer_count(counts.answer_count as u16)
Expand All @@ -109,6 +110,7 @@ pub fn update_header_counts(
/// Tracks the counts of the records in the Message.
///
/// This is only used internally during serialization.
#[derive(Debug, Copy, Clone)]
pub struct HeaderCounts {
/// The number of queries in the Message
pub query_count: usize,
Expand Down Expand Up @@ -678,6 +680,58 @@ impl Message {

Ok(())
}

/// Consumes `Message` and returns into components
pub fn into_parts(self) -> MessageParts {
self.into()
}
}

/// Consumes `Message` giving public access to fields in `Message` so they can be
/// destructured and taken by value
/// ```rust
/// let msg = Message::new();
/// let MessageParts { queries, .. } = msg.into_parts();
/// ```
#[derive(Clone, Debug, PartialEq, Default)]
pub struct MessageParts {
/// message header
pub header: Header,
/// message queries
pub queries: Vec<Query>,
/// message answers
pub answers: Vec<Record>,
/// message name_servers
pub name_servers: Vec<Record>,
/// message additional records
pub additionals: Vec<Record>,
/// sig0
pub sig0: Vec<Record>,
/// optional edns records
pub edns: Option<Edns>,
}

impl From<Message> for MessageParts {
fn from(msg: Message) -> Self {
let Message {
header,
queries,
answers,
name_servers,
additionals,
sig0,
edns,
} = msg;
MessageParts {
header,
queries,
answers,
name_servers,
additionals,
sig0,
edns,
}
}
}

impl Deref for Message {
Expand Down Expand Up @@ -709,6 +763,7 @@ pub trait MessageFinalizer: Send + Sync + 'static {
/// A MessageFinalizer which does nothing
///
/// *WARNING* This should only be used in None context, it will panic in all cases where finalize is called.
#[derive(Debug, Clone, Copy)]
pub struct NoopMessageFinalizer;

impl NoopMessageFinalizer {
Expand Down
49 changes: 49 additions & 0 deletions crates/proto/src/op/query.rs
Expand Up @@ -159,6 +159,55 @@ impl Query {
pub fn mdns_unicast_response(&self) -> bool {
self.mdns_unicast_response
}

/// Consumes `Query` and returns it's components
pub fn into_parts(self) -> QueryParts {
self.into()
}
}

/// Consumes `Query` giving public access to fields of `Query` so they can
/// be destructured and taken by value.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct QueryParts {
/// QNAME
pub name: Name,
/// QTYPE
pub query_type: RecordType,
/// QCLASS
pub query_class: DNSClass,
/// mDNS unicast-response bit set or not
#[cfg(feature = "mdns")]
pub mdns_unicast_response: bool,
}

impl From<Query> for QueryParts {
fn from(q: Query) -> Self {
cfg_if::cfg_if! {
if #[cfg(feature = "mdns")] {
let Query {
name,
query_type,
query_class,
mdns_unicast_response,
} = q;
} else {
let Query {
name,
query_type,
query_class,
} = q;
}
}

Self {
name,
query_type,
query_class,
#[cfg(feature = "mdns")]
mdns_unicast_response,
}
}
}

impl BinEncodable for Query {
Expand Down
1 change: 1 addition & 0 deletions crates/proto/src/rr/dnssec/ec_public_key.rs
Expand Up @@ -8,6 +8,7 @@
use super::Algorithm;
use crate::error::*;

#[derive(Copy, Clone)]
pub struct ECPublicKey {
buf: [u8; MAX_LEN],
len: usize,
Expand Down
1 change: 1 addition & 0 deletions crates/proto/src/rr/dnssec/mod.rs
Expand Up @@ -49,6 +49,7 @@ pub use ring::digest::Digest;

/// This is an empty type, enable Ring or OpenSSL for this feature
#[cfg(not(any(feature = "openssl", feature = "ring")))]
#[derive(Copy, Clone, Debug)]
pub struct Digest;

#[cfg(not(any(feature = "openssl", feature = "ring")))]
Expand Down
15 changes: 14 additions & 1 deletion crates/proto/src/rr/rdata/opt.rs
Expand Up @@ -181,6 +181,7 @@ impl OPT {
OPT { options }
}

#[deprecated(note = "Please use as_ref() or as_mut() for shared/mutable references")]
/// The entire map of options
pub fn options(&self) -> &HashMap<EdnsCode, EdnsOption> {
&self.options
Expand All @@ -202,6 +203,18 @@ impl OPT {
}
}

impl AsMut<HashMap<EdnsCode, EdnsOption>> for OPT {
fn as_mut(&mut self) -> &mut HashMap<EdnsCode, EdnsOption> {
&mut self.options
}
}

impl AsRef<HashMap<EdnsCode, EdnsOption>> for OPT {
fn as_ref(&self) -> &HashMap<EdnsCode, EdnsOption> {
&self.options
}
}

/// Read the RData from the given Decoder
pub fn read(decoder: &mut BinDecoder<'_>, rdata_length: Restrict<u16>) -> ProtoResult<OPT> {
let mut state: OptReadState = OptReadState::ReadCode;
Expand Down Expand Up @@ -277,7 +290,7 @@ pub fn read(decoder: &mut BinDecoder<'_>, rdata_length: Restrict<u16>) -> ProtoR

/// Write the RData from the given Decoder
pub fn emit(encoder: &mut BinEncoder<'_>, opt: &OPT) -> ProtoResult<()> {
for (edns_code, edns_option) in opt.options().iter() {
for (edns_code, edns_option) in opt.as_ref().iter() {
encoder.emit_u16(u16::from(*edns_code))?;
encoder.emit_u16(edns_option.len())?;
edns_option.emit(encoder)?
Expand Down
58 changes: 58 additions & 0 deletions crates/proto/src/rr/resource.rs
Expand Up @@ -264,6 +264,64 @@ impl Record {
pub fn into_data(self) -> RData {
self.rdata
}

/// Consumes `Record` and returns its components
pub fn into_parts(self) -> RecordParts {
self.into()
}
}

/// Consumes `Record` giving public access to fields of `Record` so they can
/// be destructured and taken by value
pub struct RecordParts {
/// label names
pub name_labels: Name,
/// record type
pub rr_type: RecordType,
/// dns class
pub dns_class: DNSClass,
/// time to live
pub ttl: u32,
/// rdata
pub rdata: RData,
/// mDNS cache flush
#[cfg(feature = "mdns")]
pub mdns_cache_flush: bool,
}

impl From<Record> for RecordParts {
fn from(record: Record) -> Self {
cfg_if::cfg_if! {
if #[cfg(feature = "mdns")] {
let Record {
name_labels,
rr_type,
dns_class,
ttl,
rdata,
mdns_cache_flush,
} = record;
} else {
let Record {
name_labels,
rr_type,
dns_class,
ttl,
rdata,
} = record;
}
}

RecordParts {
name_labels,
rr_type,
dns_class,
ttl,
rdata,
#[cfg(feature = "mdns")]
mdns_cache_flush,
}
}
}

#[allow(deprecated)]
Expand Down
41 changes: 41 additions & 0 deletions crates/proto/src/rr/rr_set.rs
Expand Up @@ -431,6 +431,47 @@ impl RecordSet {

removed
}

/// Consumes `RecordSet` and returns its components
pub fn into_parts(self) -> RecordSetParts {
self.into()
}
}

/// Consumes `RecordSet` giving public access to fields of `RecordSet` so they can
/// be destructured and taken by value
#[derive(Clone, Debug, PartialEq)]
pub struct RecordSetParts {
pub name: Name,
pub record_type: RecordType,
pub dns_class: DNSClass,
pub ttl: u32,
pub records: Vec<Record>,
pub rrsigs: Vec<Record>,
pub serial: u32, // serial number at which this record was modifie,
}

impl From<RecordSet> for RecordSetParts {
fn from(rset: RecordSet) -> Self {
let RecordSet {
name,
record_type,
dns_class,
ttl,
records,
rrsigs,
serial,
} = rset;
RecordSetParts {
name,
record_type,
dns_class,
ttl,
records,
rrsigs,
serial,
}
}
}

impl From<Record> for RecordSet {
Expand Down
2 changes: 1 addition & 1 deletion crates/proto/src/xfer/dns_request.rs
Expand Up @@ -12,7 +12,7 @@ use std::ops::{Deref, DerefMut};
use crate::op::Message;

/// A set of options for expressing options to how requests should be treated
#[derive(Clone, Default)]
#[derive(Clone, Copy, Debug, Default)]
pub struct DnsRequestOptions {
/// When true, the underlying DNS protocols will not return on the first response received.
///
Expand Down
2 changes: 1 addition & 1 deletion crates/proto/src/xfer/dns_response.rs
Expand Up @@ -443,7 +443,7 @@ impl From<SmallVec<[Message; 1]>> for DnsResponse {
/// NXT and SIG records MUST be added.
///
/// ```
#[derive(Eq, PartialEq, Debug)]
#[derive(Eq, PartialEq, Debug, Clone, Copy)]
pub enum NegativeType {
/// ```text
/// NXDOMAIN RESPONSE: TYPE 1.
Expand Down
2 changes: 1 addition & 1 deletion crates/resolver/src/caching_client.rs
Expand Up @@ -177,7 +177,7 @@ where

let response_message = client
.client
.lookup(query.clone(), options.clone())
.lookup(query.clone(), options)
.await
.map_err(E::into);

Expand Down

0 comments on commit a44d441

Please sign in to comment.