Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for impl Trait in associated type (type_alias_impl_trait) #153

Merged
merged 2 commits into from Mar 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 66 additions & 13 deletions src/expand.rs
Expand Up @@ -3,12 +3,13 @@ use crate::parse::Item;
use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::punctuated::Punctuated;
use syn::visit_mut::VisitMut;
use syn::visit_mut::{self, VisitMut};
use syn::{
parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat,
PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound,
WhereClause,
TypePath, WhereClause,
};

macro_rules! parse_quote_spanned {
Expand All @@ -34,6 +35,7 @@ enum Context<'a> {
},
Impl {
impl_generics: &'a Generics,
associated_type_impl_traits: &'a Set<Ident>,
},
}

Expand Down Expand Up @@ -71,7 +73,7 @@ pub fn expand(input: &mut Item, is_local: bool) {
method.attrs.push(parse_quote!(#[must_use]));
if let Some(block) = block {
has_self |= has_self_in_block(block);
transform_block(sig, block);
transform_block(context, sig, block);
method.attrs.push(lint_suppress_with_body());
} else {
method.attrs.push(lint_suppress_without_body());
Expand All @@ -90,16 +92,26 @@ pub fn expand(input: &mut Item, is_local: bool) {
let elided = lifetimes.elided;
input.generics.params = parse_quote!(#(#elided,)* #params);

let mut associated_type_impl_traits = Set::new();
for inner in &input.items {
if let ImplItem::Type(assoc) = inner {
if let Type::ImplTrait(_) = assoc.ty {
associated_type_impl_traits.insert(assoc.ident.clone());
}
}
}

let context = Context::Impl {
impl_generics: &input.generics,
associated_type_impl_traits: &associated_type_impl_traits,
};
for inner in &mut input.items {
if let ImplItem::Method(method) = inner {
let sig = &mut method.sig;
if sig.asyncness.is_some() {
let block = &mut method.block;
let has_self = has_self_in_sig(sig) || has_self_in_block(block);
transform_block(sig, block);
transform_block(context, sig, block);
transform_sig(context, sig, has_self, false, is_local);
method.attrs.push(lint_suppress_with_body());
}
Expand Down Expand Up @@ -296,7 +308,7 @@ fn transform_sig(
//
// ___ret
// })
fn transform_block(sig: &mut Signature, block: &mut Block) {
fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
if block.stmts.len() == 1 && item.to_string() == ";" {
return;
Expand Down Expand Up @@ -345,18 +357,24 @@ fn transform_block(sig: &mut Signature, block: &mut Block) {
}

let stmts = &block.stmts;
let let_ret = match &sig.output {
let let_ret = match &mut sig.output {
ReturnType::Default => quote_spanned! {block.brace_token.span=>
let _: () = { #(#decls)* #(#stmts)* };
},
ReturnType::Type(_, ret) => quote_spanned! {block.brace_token.span=>
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
return __ret;
ReturnType::Type(_, ret) => {
if contains_associated_type_impl_trait(context, ret) {
quote!(#(#decls)* #(#stmts)*)
} else {
quote_spanned! {block.brace_token.span=>
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
return __ret;
}
let __ret: #ret = { #(#decls)* #(#stmts)* };
#[allow(unreachable_code)]
__ret
}
}
let __ret: #ret = { #(#decls)* #(#stmts)* };
#[allow(unreachable_code)]
__ret
},
}
};
let box_pin = quote_spanned!(block.brace_token.span=>
Box::pin(async move { #let_ret })
Expand All @@ -380,6 +398,41 @@ fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
false
}

fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
struct AssociatedTypeImplTraits<'a> {
set: &'a Set<Ident>,
contains: bool,
}

impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
if ty.qself.is_none()
&& ty.path.segments.len() == 2
&& ty.path.segments[0].ident == "Self"
&& self.set.contains(&ty.path.segments[1].ident)
{
self.contains = true;
}
visit_mut::visit_type_path_mut(self, ty);
}
}

match context {
Context::Trait { .. } => false,
Context::Impl {
associated_type_impl_traits,
..
} => {
let mut visit = AssociatedTypeImplTraits {
set: associated_type_impl_traits,
contains: false,
};
visit.visit_type_mut(ret);
visit.contains
}
}
}

fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
clause.get_or_insert_with(|| WhereClause {
where_token: Default::default(),
Expand Down
24 changes: 23 additions & 1 deletion tests/test.rs
@@ -1,6 +1,6 @@
#![cfg_attr(
async_trait_nightly_testing,
feature(min_specialization, min_const_generics)
feature(min_specialization, min_const_generics, type_alias_impl_trait)
)]
#![allow(
clippy::let_underscore_drop,
Expand Down Expand Up @@ -1278,3 +1278,25 @@ pub mod issue149 {
}
}
}

// https://github.com/dtolnay/async-trait/issues/152
#[cfg(async_trait_nightly_testing)]
pub mod issue152 {
use async_trait::async_trait;

#[async_trait]
trait Trait {
type Assoc;

async fn f(&self) -> Self::Assoc;
}

struct Struct;

#[async_trait]
impl Trait for Struct {
type Assoc = impl Sized;

async fn f(&self) -> Self::Assoc {}
}
}