Skip to content

Commit

Permalink
der_derive: fix #[asn1(type = "...")] attribute (#204)
Browse files Browse the repository at this point in the history
PR #202 removed support for the `#[asn1(type = "...")]` attribute when
deriving `Sequence`.

This PR restores it, and also slightly changes the way the `default`
attribute is handled, with added tests.

Additionally, this commit introduces a reference type for `OPTIONAL`
similar to the one for `CONTEXT-SPECIFIC` called `OptionalRef` that can
be used as an `Encodable` trait object for `Option<&T>`. This is used by
the custom derive when encoding fields with defaults.
  • Loading branch information
tarcieri committed Nov 11, 2021
1 parent ce68844 commit a036a4c
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 22 deletions.
23 changes: 10 additions & 13 deletions der/derive/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl TypeAttrs {
// `tag_mode = "..."` attribute
if let Some(mode) = attr.parse_value("tag_mode") {
if tag_mode.is_some() {
abort!(attr.value, "duplicate ASN.1 `tag_mode` attribute");
abort!(attr.name, "duplicate ASN.1 `tag_mode` attribute");
}

tag_mode = Some(mode);
Expand Down Expand Up @@ -79,9 +79,7 @@ impl FieldAttrs {
pub fn parse(attrs: &[Attribute], type_attrs: &TypeAttrs) -> Self {
let mut asn1_type = None;
let mut context_specific = None;

let mut tag_mode = None;

let mut default = None;

let mut parsed_attrs = Vec::new();
Expand All @@ -98,22 +96,24 @@ impl FieldAttrs {
// `type = "..."` attribute
} else if let Some(ty) = attr.parse_value("type") {
if asn1_type.is_some() {
abort!(attr.value, "duplicate ASN.1 `type` attribute: {}");
abort!(attr.name, "duplicate ASN.1 `type` attribute: {}");
}

asn1_type = Some(ty);
} else if let Some(mode) = attr.parse_value("tag_mode") {
if tag_mode.is_some() {
abort!(attr.value, "duplicate ASN.1 `tag_mode` attribute");
abort!(attr.name, "duplicate ASN.1 `tag_mode` attribute");
}

tag_mode = Some(mode);
} else if attr.parse_value::<String>("default").is_some() {
if default.is_some() {
abort!(attr.value, "duplicate ASN.1 `default` attribute");
abort!(attr.name, "duplicate ASN.1 `default` attribute");
}

default = Some(attr.lit_str.parse::<Path>().unwrap());
default = Some(attr.value.parse().unwrap_or_else(|e| {
abort!(attr.value, "error parsing ASN.1 `default` attribute: {}", e)
}));
} else {
abort!(
attr.name,
Expand Down Expand Up @@ -204,10 +204,7 @@ struct AttrNameValue {
pub name: Path,

/// Attribute value.
pub value: String,

/// Attribute value.
pub lit_str: LitStr,
pub value: LitStr,
}

impl AttrNameValue {
Expand All @@ -231,8 +228,7 @@ impl AttrNameValue {
..
})) => out.push(Self {
name: path.clone(),
value: lit_str.value(),
lit_str: lit_str.clone(),
value: lit_str.clone(),
}),
_ => abort!(nested, "malformed `asn1` attribute"),
}
Expand All @@ -249,6 +245,7 @@ impl AttrNameValue {
if self.name.is_ident(name) {
Some(
self.value
.value()
.parse()
.unwrap_or_else(|_| abort!(self.name, "error parsing `{}` attribute")),
)
Expand Down
1 change: 1 addition & 0 deletions der/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
//! [`der::asn1::Utf8String`]: https://docs.rs/der/latest/der/asn1/struct.Utf8String.html

#![crate_type = "proc-macro"]
#![forbid(unsafe_code, clippy::unwrap_used)]
#![warn(rust_2018_idioms, trivial_casts, unused_qualifications)]

mod asn1_type;
Expand Down
48 changes: 41 additions & 7 deletions der/derive/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ impl SequenceField {
abort!(ident, "IMPLICIT tagging not supported for `Sequence`");
}

if attrs.asn1_type.is_some() && attrs.default.is_some() {
abort!(
ident,
"ASN.1 `type` and `default` options cannot be combined"
);
}

Self {
ident,
attrs,
Expand All @@ -145,22 +152,49 @@ impl SequenceField {

/// Derive code for decoding a field of a sequence.
fn to_decode_tokens(&self) -> TokenStream {
// NOTE: `type` and `default` are mutually exclusive
// This is checked above in `SequenceField::new`
debug_assert!(self.attrs.asn1_type.is_none() || self.attrs.default.is_none());

let ident = &self.ident;
let ty = self.field_type.clone();
if let Some(default) = &self.attrs.default {
quote!(let mut #ident = Some(decoder.decode::<#ty>()?.unwrap_or_else(#default));)
let ty = &self.field_type;

if self.attrs.asn1_type.is_some() {
let dec = self.attrs.decoder();
quote! {
let #ident = #dec.try_into()?;
}
} else if let Some(default) = &self.attrs.default {
quote! {
let #ident = decoder.decode::<Option<#ty>>()?.unwrap_or_else(#default);
}
} else {
quote!(let #ident = decoder.decode()?;)
quote! {
let #ident = decoder.decode()?;
}
}
}

/// Derive code for encoding a field of a sequence.
fn to_encode_tokens(&self) -> TokenStream {
// NOTE: `type` and `default` are mutually exclusive
// This is checked above in `SequenceField::new`
debug_assert!(self.attrs.asn1_type.is_none() || self.attrs.default.is_none());

let ident = &self.ident;
let binding = quote!(&self.#ident);
let binding_noref = quote!(self.#ident);
if let Some(default) = &self.attrs.default {
quote!(&if #binding_noref == Some(#default()) {None} else {Some(#binding_noref)})

if let Some(ty) = &self.attrs.asn1_type {
let encoder = ty.encoder(&binding);
quote!(&#encoder?)
} else if let Some(default) = &self.attrs.default {
quote! {
&::der::asn1::OptionalRef(if #binding == &#default() {
None
} else {
Some(#binding)
})
}
} else {
quote!(#binding)
}
Expand Down
1 change: 1 addition & 0 deletions der/src/asn1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub use self::{
integer::bigint::UIntBytes,
null::Null,
octet_string::OctetString,
optional::OptionalRef,
printable_string::PrintableString,
sequence::Sequence,
sequence_of::{SequenceOf, SequenceOfIter},
Expand Down
24 changes: 24 additions & 0 deletions der/src/asn1/optional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,27 @@ where
}
}
}

/// A reference to an ASN.1 `OPTIONAL` type, used for encoding only.
pub struct OptionalRef<'a, T>(pub Option<&'a T>);

impl<'a, T> Encodable for OptionalRef<'a, T>
where
T: Encodable,
{
fn encoded_len(&self) -> Result<Length> {
if let Some(encodable) = self.0 {
encodable.encoded_len()
} else {
Ok(0u8.into())
}
}

fn encode(&self, encoder: &mut Encoder<'_>) -> Result<()> {
if let Some(encodable) = self.0 {
encodable.encode(encoder)
} else {
Ok(())
}
}
}
20 changes: 18 additions & 2 deletions der/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ mod enumerated {
/// Custom derive test cases for the `Sequence` macro.
mod sequence {
use der::{
asn1::{Any, BitString, ObjectIdentifier},
asn1::{Any, ObjectIdentifier},
Decodable, Encodable, Sequence,
};
use hex_literal::hex;
Expand All @@ -217,8 +217,24 @@ mod sequence {
#[derive(Copy, Clone, Debug, Eq, PartialEq, Sequence)]
pub struct SubjectPublicKeyInfo<'a> {
pub algorithm: AlgorithmIdentifier<'a>,
#[asn1(type = "BIT STRING")]
pub subject_public_key: &'a [u8],
}

/// X.509 extension
// TODO(tarcieri): tests for code derived with the `default` attribute
#[derive(Clone, Debug, Eq, PartialEq, Sequence)]
pub struct Extension<'a> {
extn_id: ObjectIdentifier,
#[asn1(default = "critical_default")]
critical: bool,
#[asn1(type = "OCTET STRING")]
extn_value: &'a [u8],
}

pub subject_public_key: BitString<'a>,
/// Default value of the `critical` bit
fn critical_default() -> bool {
false
}

const ID_EC_PUBLIC_KEY_OID: ObjectIdentifier = ObjectIdentifier::new("1.2.840.10045.2.1");
Expand Down

0 comments on commit a036a4c

Please sign in to comment.