Skip to content

Commit

Permalink
Merge pull request #153 from dtolnay/impltrait
Browse files Browse the repository at this point in the history
Support for impl Trait in associated type (type_alias_impl_trait)
  • Loading branch information
dtolnay committed Mar 7, 2021
2 parents 6d3cf66 + cc9c90b commit c03a0c6
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 14 deletions.
79 changes: 66 additions & 13 deletions src/expand.rs
Expand Up @@ -3,12 +3,13 @@ use crate::parse::Item;
use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::punctuated::Punctuated;
use syn::visit_mut::VisitMut;
use syn::visit_mut::{self, VisitMut};
use syn::{
parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat,
PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound,
WhereClause,
TypePath, WhereClause,
};

macro_rules! parse_quote_spanned {
Expand All @@ -34,6 +35,7 @@ enum Context<'a> {
},
Impl {
impl_generics: &'a Generics,
associated_type_impl_traits: &'a Set<Ident>,
},
}

Expand Down Expand Up @@ -71,7 +73,7 @@ pub fn expand(input: &mut Item, is_local: bool) {
method.attrs.push(parse_quote!(#[must_use]));
if let Some(block) = block {
has_self |= has_self_in_block(block);
transform_block(sig, block);
transform_block(context, sig, block);
method.attrs.push(lint_suppress_with_body());
} else {
method.attrs.push(lint_suppress_without_body());
Expand All @@ -90,16 +92,26 @@ pub fn expand(input: &mut Item, is_local: bool) {
let elided = lifetimes.elided;
input.generics.params = parse_quote!(#(#elided,)* #params);

let mut associated_type_impl_traits = Set::new();
for inner in &input.items {
if let ImplItem::Type(assoc) = inner {
if let Type::ImplTrait(_) = assoc.ty {
associated_type_impl_traits.insert(assoc.ident.clone());
}
}
}

let context = Context::Impl {
impl_generics: &input.generics,
associated_type_impl_traits: &associated_type_impl_traits,
};
for inner in &mut input.items {
if let ImplItem::Method(method) = inner {
let sig = &mut method.sig;
if sig.asyncness.is_some() {
let block = &mut method.block;
let has_self = has_self_in_sig(sig) || has_self_in_block(block);
transform_block(sig, block);
transform_block(context, sig, block);
transform_sig(context, sig, has_self, false, is_local);
method.attrs.push(lint_suppress_with_body());
}
Expand Down Expand Up @@ -296,7 +308,7 @@ fn transform_sig(
//
// ___ret
// })
fn transform_block(sig: &mut Signature, block: &mut Block) {
fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
if block.stmts.len() == 1 && item.to_string() == ";" {
return;
Expand Down Expand Up @@ -345,18 +357,24 @@ fn transform_block(sig: &mut Signature, block: &mut Block) {
}

let stmts = &block.stmts;
let let_ret = match &sig.output {
let let_ret = match &mut sig.output {
ReturnType::Default => quote_spanned! {block.brace_token.span=>
let _: () = { #(#decls)* #(#stmts)* };
},
ReturnType::Type(_, ret) => quote_spanned! {block.brace_token.span=>
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
return __ret;
ReturnType::Type(_, ret) => {
if contains_associated_type_impl_trait(context, ret) {
quote!(#(#decls)* #(#stmts)*)
} else {
quote_spanned! {block.brace_token.span=>
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
return __ret;
}
let __ret: #ret = { #(#decls)* #(#stmts)* };
#[allow(unreachable_code)]
__ret
}
}
let __ret: #ret = { #(#decls)* #(#stmts)* };
#[allow(unreachable_code)]
__ret
},
}
};
let box_pin = quote_spanned!(block.brace_token.span=>
Box::pin(async move { #let_ret })
Expand All @@ -380,6 +398,41 @@ fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
false
}

fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
struct AssociatedTypeImplTraits<'a> {
set: &'a Set<Ident>,
contains: bool,
}

impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
if ty.qself.is_none()
&& ty.path.segments.len() == 2
&& ty.path.segments[0].ident == "Self"
&& self.set.contains(&ty.path.segments[1].ident)
{
self.contains = true;
}
visit_mut::visit_type_path_mut(self, ty);
}
}

match context {
Context::Trait { .. } => false,
Context::Impl {
associated_type_impl_traits,
..
} => {
let mut visit = AssociatedTypeImplTraits {
set: associated_type_impl_traits,
contains: false,
};
visit.visit_type_mut(ret);
visit.contains
}
}
}

fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
clause.get_or_insert_with(|| WhereClause {
where_token: Default::default(),
Expand Down
24 changes: 23 additions & 1 deletion tests/test.rs
@@ -1,6 +1,6 @@
#![cfg_attr(
async_trait_nightly_testing,
feature(min_specialization, min_const_generics)
feature(min_specialization, min_const_generics, type_alias_impl_trait)
)]
#![allow(
clippy::let_underscore_drop,
Expand Down Expand Up @@ -1278,3 +1278,25 @@ pub mod issue149 {
}
}
}

// https://github.com/dtolnay/async-trait/issues/152
#[cfg(async_trait_nightly_testing)]
pub mod issue152 {
use async_trait::async_trait;

#[async_trait]
trait Trait {
type Assoc;

async fn f(&self) -> Self::Assoc;
}

struct Struct;

#[async_trait]
impl Trait for Struct {
type Assoc = impl Sized;

async fn f(&self) -> Self::Assoc {}
}
}

0 comments on commit c03a0c6

Please sign in to comment.