diff --git a/mockall/tests/mock_return_dyn_trait.rs b/mockall/tests/mock_return_dyn_trait.rs index 6011791c..ae6ba077 100644 --- a/mockall/tests/mock_return_dyn_trait.rs +++ b/mockall/tests/mock_return_dyn_trait.rs @@ -10,6 +10,23 @@ trait T: Debug + Sync { } impl T for u32 {} +trait Base { + fn as_t(&self) -> &dyn T; +} +trait Derived: Base { + fn as_base(&self) -> &dyn Base; +} + +impl Base for u32 { + fn as_t(&self) -> &dyn T { + self + } +} +impl Derived for u32 { + fn as_base(&self) -> &dyn Base { + self + } +} impl T for Q where Q: Debug + Sync + AsMut {} @@ -21,6 +38,16 @@ mock!{ } } +mock! { + Bar {} + impl Base for Bar { + fn as_t(&self) -> &dyn T; + } + impl Derived for Bar { + fn as_base(&self) -> &dyn Base; + } +} + #[test] fn return_const() { let mut mock = MockFoo::new(); @@ -47,3 +74,12 @@ fn return_var() { mock.baz().mutate(); } + +#[test] +fn dyn_upcast() { + let mut mock = MockBar::new(); + mock.expect_as_t() + .return_const(Box::new(42u32) as Box); + + assert_eq!("42", format!("{:?}", mock.as_t())); +} diff --git a/mockall_derive/src/mock_function.rs b/mockall_derive/src/mock_function.rs index b560c60a..e704d9ba 100644 --- a/mockall_derive/src/mock_function.rs +++ b/mockall_derive/src/mock_function.rs @@ -4,7 +4,7 @@ use super::*; use quote::ToTokens; /// Convert a trait object reference into a reference to a Boxed trait -fn dedynify(ty: &mut Type) { +fn dedynify(ty: &mut Type) -> bool { if let Type::Reference(ref mut tr) = ty { if let Type::TraitObject(ref tto) = tr.elem.as_ref() { if let Some(lt) = &tr.lifetime { @@ -17,13 +17,15 @@ fn dedynify(ty: &mut Type) { // when methods like returning add a `+ Send` to the output // type. *tr.elem = parse2(quote!((#tto))).unwrap(); - return; + return false; } } *tr.elem = parse2(quote!(Box<#tto>)).unwrap(); + return true; } } + false } /// Convert a special reference type like "&str" into a reference to its owned @@ -207,16 +209,19 @@ impl<'a> Builder<'a> { is_static = false; } } - let output = match self.sig.output { - ReturnType::Default => Type::Tuple(TypeTuple { + let (output, boxed) = match self.sig.output { + ReturnType::Default => ( + Type::Tuple(TypeTuple { paren_token: token::Paren::default(), - elems: Punctuated::new() + elems: Punctuated::new(), }), + false, + ), ReturnType::Type(_, ref ty) => { let mut output_ty = supersuperfy(&**ty, self.levels); destrify(&mut output_ty); - dedynify(&mut output_ty); - output_ty + let boxed = dedynify(&mut output_ty); + (output_ty, boxed) } }; supersuperfy_generics(&mut declosured_generics, self.levels); @@ -284,6 +289,7 @@ impl<'a> Builder<'a> { mod_ident: self.parent.unwrap_or(&Ident::new("FIXME", Span::call_site())).clone(), output, owned_output, + boxed, predexprs, predty, refpredty, @@ -389,6 +395,8 @@ pub(crate) struct MockFunction { /// If the real output type is a non-'static reference, then it will differ /// from this field. owned_output: Type, + /// True if the `owned_type` is boxed by `Box<>`. + boxed: bool, /// Expressions that create the predicate arguments from the call arguments predexprs: Vec, /// Types used for Predicates. Will be almost the same as args, but every @@ -452,6 +460,14 @@ impl MockFunction { } else { Ident::new("call", Span::call_site()) }; + let mut deref = quote!(); + if self.boxed { + if self.return_ref { + deref = quote! ( & * * ); + } else if self.return_refmut { + deref = quote ! ( & mut * * ); + } + } if self.is_static { let outer_mod_path = self.outer_mod_path(modname); quote!( @@ -459,7 +475,7 @@ impl MockFunction { #(#attrs)* #vis #sig { let no_match_msg = #no_match_msg; - { + #deref { let __mockall_guard = #outer_mod_path::EXPECTATIONS .lock().unwrap(); /* @@ -479,7 +495,7 @@ impl MockFunction { #(#attrs)* #vis #sig { let no_match_msg = #no_match_msg; - self.#substruct_obj #name.#call#tbf(#(#call_exprs,)*) + #deref self.#substruct_obj #name.#call#tbf(#(#call_exprs,)*) .expect(&no_match_msg) }