diff --git a/snafu-derive/src/lib.rs b/snafu-derive/src/lib.rs index d12733af..f89e9659 100644 --- a/snafu-derive/src/lib.rs +++ b/snafu-derive/src/lib.rs @@ -194,15 +194,28 @@ impl SourceField { } enum Transformation { - None { ty: syn::Type }, - Transform { ty: syn::Type, expr: syn::Expr }, + None { + ty: syn::Type, + }, + Transform { + source_ty: syn::Type, + target_ty: syn::Type, + expr: syn::Expr, + }, } impl Transformation { - fn ty(&self) -> &syn::Type { + fn source_ty(&self) -> &syn::Type { + match self { + Transformation::None { ty } => ty, + Transformation::Transform { source_ty, .. } => source_ty, + } + } + + fn target_ty(&self) -> &syn::Type { match self { Transformation::None { ty } => ty, - Transformation::Transform { ty, .. } => ty, + Transformation::Transform { target_ty, .. } => target_ty, } } @@ -936,7 +949,11 @@ fn field_container( name, ty, provide, .. } = field; let transformation = maybe_transformation - .map(|(ty, expr)| Transformation::Transform { ty, expr }) + .map(|(source_ty, expr)| Transformation::Transform { + source_ty, + target_ty: ty.clone(), + expr, + }) .unwrap_or_else(|| Transformation::None { ty }); source_fields.add( @@ -1237,12 +1254,15 @@ fn parse_snafu_tuple_struct( return Err(vec![one_field_error(span)]); } + let ty = inner.into_value().ty; let (maybe_transformation, errs) = transformations.finish(); let transformation = maybe_transformation - .map(|(ty, expr)| Transformation::Transform { ty, expr }) - .unwrap_or_else(|| Transformation::None { - ty: inner.into_value().ty, - }); + .map(|(source_ty, expr)| Transformation::Transform { + source_ty, + target_ty: ty.clone(), + expr, + }) + .unwrap_or_else(|| Transformation::None { ty }); errors.extend(errs); let (maybe_crate_root, errs) = crate_roots.finish(); @@ -1878,7 +1898,7 @@ impl TupleStructInfo { provides, } = self; - let inner_type = transformation.ty(); + let inner_type = transformation.source_ty(); let transformation = transformation.transformation(); let where_clauses: Vec<_> = generics diff --git a/snafu-derive/src/shared.rs b/snafu-derive/src/shared.rs index 8580ea65..7d7c342f 100644 --- a/snafu-derive/src/shared.rs +++ b/snafu-derive/src/shared.rs @@ -314,12 +314,20 @@ pub mod context_selector { self.construct_implicit_fields() }; - let (source_ty, transfer_source_field) = match source_field { + let (source_ty, transform_source, transfer_source_field) = match source_field { Some(source_field) => { - let (ty, transfer) = build_source_info(source_field); - (quote! { #ty }, transfer) + let SourceInfo { + source_field_type, + transform_source, + transfer_source_field, + } = build_source_info(source_field); + ( + quote! { #source_field_type }, + Some(transform_source), + Some(transfer_source_field), + ) } - None => (quote! { #crate_root::NoneError }, quote! {}), + None => (quote! { #crate_root::NoneError }, None, None), }; let track_caller = track_caller(); @@ -334,6 +342,7 @@ pub mod context_selector { #track_caller fn into_error(self, error: Self::Source) -> #parameterized_error_name { + #transform_source; #error_constructor_name { #construct_implicit_fields #transfer_source_field @@ -360,7 +369,7 @@ pub mod context_selector { let (source_ty, transfer_source_field, empty_source_field) = match source_field { Some(f) => { - let source_field_type = f.transformation.ty(); + let source_field_type = f.transformation.source_ty(); let source_field_name = &f.name; let source_transformation = f.transformation.transformation(); @@ -411,7 +420,11 @@ pub mod context_selector { let user_field_generics = self.user_field_generics(); let where_clauses = self.where_clauses; - let (source_field_type, transfer_source_field) = build_source_info(source_field); + let SourceInfo { + source_field_type, + transform_source, + transfer_source_field, + } = build_source_info(source_field); let track_caller = track_caller(); @@ -422,6 +435,7 @@ pub mod context_selector { { #track_caller fn from(error: #source_field_type) -> Self { + #transform_source; #error_constructor_name { #construct_implicit_fields_with_source #transfer_source_field @@ -432,16 +446,28 @@ pub mod context_selector { } } + struct SourceInfo<'a> { + source_field_type: &'a syn::Type, + transform_source: TokenStream, + transfer_source_field: TokenStream, + } + // Assumes that the error is in a variable called "error" - fn build_source_info(source_field: &crate::SourceField) -> (&syn::Type, TokenStream) { + fn build_source_info(source_field: &crate::SourceField) -> SourceInfo<'_> { let source_field_name = source_field.name(); - let source_field_type = source_field.transformation.ty(); + let source_field_type = source_field.transformation.source_ty(); + let target_field_type = source_field.transformation.target_ty(); let source_transformation = source_field.transformation.transformation(); - ( + let transform_source = + quote! { let error: #target_field_type = (#source_transformation)(error) }; + let transfer_source_field = quote! { #source_field_name: error, }; + + SourceInfo { source_field_type, - quote! { #source_field_name: (#source_transformation)(error), }, - ) + transform_source, + transfer_source_field, + } } fn track_caller() -> proc_macro2::TokenStream { @@ -750,7 +776,8 @@ pub mod error { .source_field() .filter(|f| f.provide); - let source_provide_ref = provided_source.map(|f| (f.transformation.ty(), f.name())); + let source_provide_ref = + provided_source.map(|f| (f.transformation.source_ty(), f.name())); let provide_refs = provide_refs.chain(source_provide_ref); diff --git a/tests/implicit.rs b/tests/implicit.rs index 41ed519f..4c7737b8 100644 --- a/tests/implicit.rs +++ b/tests/implicit.rs @@ -157,3 +157,34 @@ mod with_and_without_source { assert_eq!(e.data.0, ItWas::GenerateWithSource); } } + +mod converted_sources { + use snafu::{prelude::*, IntoError}; + + #[derive(Debug)] + struct ImplicitData; + + impl snafu::GenerateImplicitData for ImplicitData { + fn generate() -> Self { + Self + } + } + + #[derive(Debug, Snafu)] + struct HasSource { + backtrace: snafu::Backtrace, + + #[snafu(implicit)] + data: ImplicitData, + + #[snafu(source(from(String, Into::into)))] + source: Box, + } + + #[test] + fn receives_the_error_after_conversion() { + let e = HasSourceSnafu.into_error(String::from("bad")); + // Mostly testing that this compiles; assertion is bonus + assert_eq!(e.source.to_string(), "bad"); + } +}