diff --git a/src/expand.rs b/src/expand.rs index 06c4ba3..6cc6ef0 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -3,12 +3,13 @@ use crate::parse::Item; use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf}; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; +use std::collections::BTreeSet as Set; use syn::punctuated::Punctuated; -use syn::visit_mut::VisitMut; +use syn::visit_mut::{self, VisitMut}; use syn::{ parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat, PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound, - WhereClause, + TypePath, WhereClause, }; macro_rules! parse_quote_spanned { @@ -34,6 +35,7 @@ enum Context<'a> { }, Impl { impl_generics: &'a Generics, + associated_type_impl_traits: &'a Set, }, } @@ -71,7 +73,7 @@ pub fn expand(input: &mut Item, is_local: bool) { method.attrs.push(parse_quote!(#[must_use])); if let Some(block) = block { has_self |= has_self_in_block(block); - transform_block(sig, block); + transform_block(context, sig, block); method.attrs.push(lint_suppress_with_body()); } else { method.attrs.push(lint_suppress_without_body()); @@ -90,8 +92,18 @@ pub fn expand(input: &mut Item, is_local: bool) { let elided = lifetimes.elided; input.generics.params = parse_quote!(#(#elided,)* #params); + let mut associated_type_impl_traits = Set::new(); + for inner in &input.items { + if let ImplItem::Type(assoc) = inner { + if let Type::ImplTrait(_) = assoc.ty { + associated_type_impl_traits.insert(assoc.ident.clone()); + } + } + } + let context = Context::Impl { impl_generics: &input.generics, + associated_type_impl_traits: &associated_type_impl_traits, }; for inner in &mut input.items { if let ImplItem::Method(method) = inner { @@ -99,7 +111,7 @@ pub fn expand(input: &mut Item, is_local: bool) { if sig.asyncness.is_some() { let block = &mut method.block; let has_self = has_self_in_sig(sig) || has_self_in_block(block); - transform_block(sig, block); + transform_block(context, sig, block); transform_sig(context, sig, has_self, false, is_local); method.attrs.push(lint_suppress_with_body()); } @@ -296,7 +308,7 @@ fn transform_sig( // // ___ret // }) -fn transform_block(sig: &mut Signature, block: &mut Block) { +fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) { if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() { if block.stmts.len() == 1 && item.to_string() == ";" { return; @@ -345,18 +357,24 @@ fn transform_block(sig: &mut Signature, block: &mut Block) { } let stmts = &block.stmts; - let let_ret = match &sig.output { + let let_ret = match &mut sig.output { ReturnType::Default => quote_spanned! {block.brace_token.span=> let _: () = { #(#decls)* #(#stmts)* }; }, - ReturnType::Type(_, ret) => quote_spanned! {block.brace_token.span=> - if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> { - return __ret; + ReturnType::Type(_, ret) => { + if contains_associated_type_impl_trait(context, ret) { + quote!(#(#decls)* #(#stmts)*) + } else { + quote_spanned! {block.brace_token.span=> + if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> { + return __ret; + } + let __ret: #ret = { #(#decls)* #(#stmts)* }; + #[allow(unreachable_code)] + __ret + } } - let __ret: #ret = { #(#decls)* #(#stmts)* }; - #[allow(unreachable_code)] - __ret - }, + } }; let box_pin = quote_spanned!(block.brace_token.span=> Box::pin(async move { #let_ret }) @@ -380,6 +398,41 @@ fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool { false } +fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool { + struct AssociatedTypeImplTraits<'a> { + set: &'a Set, + contains: bool, + } + + impl<'a> VisitMut for AssociatedTypeImplTraits<'a> { + fn visit_type_path_mut(&mut self, ty: &mut TypePath) { + if ty.qself.is_none() + && ty.path.segments.len() == 2 + && ty.path.segments[0].ident == "Self" + && self.set.contains(&ty.path.segments[1].ident) + { + self.contains = true; + } + visit_mut::visit_type_path_mut(self, ty); + } + } + + match context { + Context::Trait { .. } => false, + Context::Impl { + associated_type_impl_traits, + .. + } => { + let mut visit = AssociatedTypeImplTraits { + set: associated_type_impl_traits, + contains: false, + }; + visit.visit_type_mut(ret); + visit.contains + } + } +} + fn where_clause_or_default(clause: &mut Option) -> &mut WhereClause { clause.get_or_insert_with(|| WhereClause { where_token: Default::default(), diff --git a/tests/test.rs b/tests/test.rs index 06c0918..6f74bee 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,6 @@ #![cfg_attr( async_trait_nightly_testing, - feature(min_specialization, min_const_generics) + feature(min_specialization, min_const_generics, type_alias_impl_trait) )] #![allow( clippy::let_underscore_drop, @@ -1278,3 +1278,25 @@ pub mod issue149 { } } } + +// https://github.com/dtolnay/async-trait/issues/152 +#[cfg(async_trait_nightly_testing)] +pub mod issue152 { + use async_trait::async_trait; + + #[async_trait] + trait Trait { + type Assoc; + + async fn f(&self) -> Self::Assoc; + } + + struct Struct; + + #[async_trait] + impl Trait for Struct { + type Assoc = impl Sized; + + async fn f(&self) -> Self::Assoc {} + } +}