Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support #[from] on an Option field #147

Merged
merged 2 commits into from Aug 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 32 additions & 9 deletions impl/src/expand.rs
Expand Up @@ -2,7 +2,7 @@ use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};
use syn::{Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, Visibility};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
Expand Down Expand Up @@ -131,7 +131,7 @@ fn impl_struct(input: Struct) -> TokenStream {

let from_impl = input.from_field().map(|from_field| {
let backtrace_field = input.distinct_backtrace_field();
let from = from_field.ty;
let from = unoptional_type(from_field.ty);
let body = from_initializer(from_field, backtrace_field);
quote! {
#[allow(unused_qualifications)]
Expand Down Expand Up @@ -351,7 +351,7 @@ fn impl_enum(input: Enum) -> TokenStream {
let from_field = variant.from_field()?;
let backtrace_field = variant.distinct_backtrace_field();
let variant = &variant.ident;
let from = from_field.ty;
let from = unoptional_type(from_field.ty);
let body = from_initializer(from_field, backtrace_field);
Some(quote! {
#[allow(unused_qualifications)]
Expand Down Expand Up @@ -394,6 +394,11 @@ fn fields_pat(fields: &[Field]) -> TokenStream {

fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
let from_member = &from_field.member;
let some_source = if type_is_option(from_field.ty) {
quote!(std::option::Option::Some(source))
} else {
quote!(source)
};
let backtrace = backtrace_field.map(|backtrace_field| {
let backtrace_member = &backtrace_field.member;
if type_is_option(backtrace_field.ty) {
Expand All @@ -407,25 +412,43 @@ fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> Toke
}
});
quote!({
#from_member: source,
#from_member: #some_source,
#backtrace
})
}

fn type_is_option(ty: &Type) -> bool {
type_parameter_of_option(ty).is_some()
}

fn unoptional_type(ty: &Type) -> TokenStream {
let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
quote!(#unoptional)
}

fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
let path = match ty {
Type::Path(ty) => &ty.path,
_ => return false,
_ => return None,
};

let last = path.segments.last().unwrap();
if last.ident != "Option" {
return false;
return None;
}

let bracketed = match &last.arguments {
PathArguments::AngleBracketed(bracketed) => bracketed,
_ => return None,
};

if bracketed.args.len() != 1 {
return None;
}

match &last.arguments {
PathArguments::AngleBracketed(bracketed) => bracketed.args.len() == 1,
_ => false,
match &bracketed.args[0] {
GenericArgument::Type(arg) => Some(arg),
_ => None,
}
}

Expand Down
23 changes: 23 additions & 0 deletions tests/test_from.rs
Expand Up @@ -14,10 +14,21 @@ pub struct ErrorStruct {
source: io::Error,
}

#[derive(Error, Debug)]
#[error("...")]
pub struct ErrorStructOptional {
#[from]
source: Option<io::Error>,
}

#[derive(Error, Debug)]
#[error("...")]
pub struct ErrorTuple(#[from] io::Error);

#[derive(Error, Debug)]
#[error("...")]
pub struct ErrorTupleOptional(#[from] Option<io::Error>);

#[derive(Error, Debug)]
#[error("...")]
pub enum ErrorEnum {
Expand All @@ -27,6 +38,15 @@ pub enum ErrorEnum {
},
}

#[derive(Error, Debug)]
#[error("...")]
pub enum ErrorEnumOptional {
Test {
#[from]
source: Option<io::Error>,
},
}

#[derive(Error, Debug)]
#[error("...")]
pub enum Many {
Expand All @@ -39,7 +59,10 @@ fn assert_impl<T: From<io::Error>>() {}
#[test]
fn test_from() {
assert_impl::<ErrorStruct>();
assert_impl::<ErrorStructOptional>();
assert_impl::<ErrorTuple>();
assert_impl::<ErrorTupleOptional>();
assert_impl::<ErrorEnum>();
assert_impl::<ErrorEnumOptional>();
assert_impl::<Many>();
}