Skip to content

Commit

Permalink
[mockall-derive] Fix asomers#363 (asomers#380)
Browse files Browse the repository at this point in the history
* fix `Box<dyn A>: A` required in returning `&dyn A` for asomers#363

* Tweak trait object return test

* Update changelog

* Add return value description for dedynify
  • Loading branch information
Frank King authored and onalante-msft committed May 7, 2022
1 parent 1dc7781 commit eb956e0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 30 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -12,6 +12,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
receiver. For example, `PartialEq::eq` has a signature like
`fn eq(&self, other: &Self) -> bool`
([#373](https://github.com/asomers/mockall/pull/373))
- Fixed mocking methods that return a reference to a `dyn T` trait object,
when that trait is not already implemented for `Box<dyn T>`.
([#380](https://github.com/asomers/mockall/pull/380))

## [ 0.11.0 ] - 2021-12-11

Expand Down
50 changes: 29 additions & 21 deletions mockall/tests/mock_return_dyn_trait.rs
Expand Up @@ -3,47 +3,55 @@
#![deny(warnings)]

use mockall::*;
use std::fmt::Debug;

trait T: Debug + Sync {
fn mutate(&mut self) {}
trait Test: Sync {
fn value(&self) -> i32;
fn mutate(&mut self);
}

impl T for u32 {}
impl Test for i32 {
fn value(&self) -> i32 {
*self
}

impl<Q> T for Q where Q: Debug + Sync + AsMut<dyn T> {}
fn mutate(&mut self) {
*self = 0;
}
}

mock!{
mock! {
Foo {
fn foo(&self) -> &dyn Debug;
fn bar(&self) -> &'static dyn T;
fn baz(&mut self) -> &mut dyn T;
fn ref_dyn(&self) -> &dyn Test;
fn static_dyn(&self) -> &'static dyn Test;
fn mut_dyn(&mut self) -> &mut dyn Test;
}
}

#[test]
fn return_const() {
fn ref_dyn() {
let mut mock = MockFoo::new();
mock.expect_foo()
.return_const(Box::new(42u32) as Box<dyn Debug>);
mock.expect_ref_dyn()
.return_const(Box::new(42) as Box<dyn Test>);

assert_eq!("42", format!("{:?}", mock.foo()));
assert_eq!(42, mock.ref_dyn().value());
}

#[test]
fn static_ref() {
fn static_dyn() {
let mut mock = MockFoo::new();
mock.expect_bar()
.return_const(&42u32 as &dyn T);
mock.expect_static_dyn()
.return_const(&42 as &'static dyn Test);

assert_eq!("42", format!("{:?}", mock.bar()));
assert_eq!(42, mock.static_dyn().value());
}

#[test]
fn return_var() {
fn mut_dyn() {
let mut mock = MockFoo::new();
mock.expect_baz()
.return_var(Box::new(42u32) as Box<dyn T>);
mock.expect_mut_dyn()
.return_var(Box::new(42) as Box<dyn Test>);

mock.baz().mutate();
assert_eq!(42, mock.mut_dyn().value());
mock.mut_dyn().mutate();
assert_eq!(0, mock.mut_dyn().value());
}
38 changes: 29 additions & 9 deletions mockall_derive/src/mock_function.rs
Expand Up @@ -4,7 +4,11 @@ use super::*;
use quote::ToTokens;

/// Convert a trait object reference into a reference to a Boxed trait
fn dedynify(ty: &mut Type) {
///
/// # Returns
///
/// Returns `true` if it was necessary to box the type.
fn dedynify(ty: &mut Type) -> bool {
if let Type::Reference(ref mut tr) = ty {
if let Type::TraitObject(ref tto) = tr.elem.as_ref() {
if let Some(lt) = &tr.lifetime {
Expand All @@ -17,13 +21,15 @@ fn dedynify(ty: &mut Type) {
// when methods like returning add a `+ Send` to the output
// type.
*tr.elem = parse2(quote!((#tto))).unwrap();
return;
return false;
}
}

*tr.elem = parse2(quote!(Box<#tto>)).unwrap();
return true;
}
}
false
}

/// Convert a special reference type like "&str" into a reference to its owned
Expand Down Expand Up @@ -207,16 +213,19 @@ impl<'a> Builder<'a> {
is_static = false;
}
}
let output = match self.sig.output {
ReturnType::Default => Type::Tuple(TypeTuple {
let (output, boxed) = match self.sig.output {
ReturnType::Default => (
Type::Tuple(TypeTuple {
paren_token: token::Paren::default(),
elems: Punctuated::new()
elems: Punctuated::new(),
}),
false,
),
ReturnType::Type(_, ref ty) => {
let mut output_ty = supersuperfy(&**ty, self.levels);
destrify(&mut output_ty);
dedynify(&mut output_ty);
output_ty
let boxed = dedynify(&mut output_ty);
(output_ty, boxed)
}
};
supersuperfy_generics(&mut declosured_generics, self.levels);
Expand Down Expand Up @@ -284,6 +293,7 @@ impl<'a> Builder<'a> {
mod_ident: self.parent.unwrap_or(&Ident::new("FIXME", Span::call_site())).clone(),
output,
owned_output,
boxed,
predexprs,
predty,
refpredty,
Expand Down Expand Up @@ -389,6 +399,8 @@ pub(crate) struct MockFunction {
/// If the real output type is a non-'static reference, then it will differ
/// from this field.
owned_output: Type,
/// True if the `owned_type` is boxed by `Box<>`.
boxed: bool,
/// Expressions that create the predicate arguments from the call arguments
predexprs: Vec<TokenStream>,
/// Types used for Predicates. Will be almost the same as args, but every
Expand Down Expand Up @@ -452,14 +464,22 @@ impl MockFunction {
} else {
Ident::new("call", Span::call_site())
};
let mut deref = quote!();
if self.boxed {
if self.return_ref {
deref = quote!(&**);
} else if self.return_refmut {
deref = quote!(&mut **);
}
}
if self.is_static {
let outer_mod_path = self.outer_mod_path(modname);
quote!(
// Don't add a doc string. The original is included in #attrs
#(#attrs)*
#vis #sig {
let no_match_msg = #no_match_msg;
{
#deref {
let __mockall_guard = #outer_mod_path::EXPECTATIONS
.lock().unwrap();
/*
Expand All @@ -479,7 +499,7 @@ impl MockFunction {
#(#attrs)*
#vis #sig {
let no_match_msg = #no_match_msg;
self.#substruct_obj #name.#call#tbf(#(#call_exprs,)*)
#deref self.#substruct_obj #name.#call#tbf(#(#call_exprs,)*)
.expect(&no_match_msg)
}

Expand Down

0 comments on commit eb956e0

Please sign in to comment.