diff --git a/src/expand.rs b/src/expand.rs index fbf9aba..53918cb 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -9,9 +9,9 @@ use std::mem; use syn::punctuated::Punctuated; use syn::visit_mut::{self, VisitMut}; use syn::{ - parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericParam, Generics, Ident, - ImplItem, Lifetime, LifetimeDef, Pat, PatIdent, Receiver, ReturnType, Signature, Stmt, Token, - TraitItem, Type, TypePath, WhereClause, + parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam, + Generics, Ident, ImplItem, Lifetime, LifetimeDef, Pat, PatIdent, PathArguments, Receiver, + ReturnType, Signature, Stmt, Token, TraitItem, Type, TypePath, WhereClause, }; impl ToTokens for Item { @@ -229,23 +229,46 @@ fn transform_sig( .push(parse_quote_spanned!(default_span=> 'async_trait)); if has_self { - let bounds = match sig.inputs.iter().next() { + let bounds: &[InferredBound] = match sig.inputs.iter().next() { Some(FnArg::Receiver(Receiver { reference: Some(_), mutability: None, .. - })) => [InferredBound::Sync], + })) => &[InferredBound::Sync], Some(FnArg::Typed(arg)) - if match (arg.pat.as_ref(), arg.ty.as_ref()) { - (Pat::Ident(pat), Type::Reference(ty)) => { - pat.ident == "self" && ty.mutability.is_none() - } + if match arg.pat.as_ref() { + Pat::Ident(pat) => pat.ident == "self", _ => false, } => { - [InferredBound::Sync] + match arg.ty.as_ref() { + // self: &Self + Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync], + // self: Arc + Type::Path(ty) + if { + let segment = ty.path.segments.last().unwrap(); + segment.ident == "Arc" + && match &segment.arguments { + PathArguments::AngleBracketed(arguments) => { + arguments.args.len() == 1 + && match &arguments.args[0] { + GenericArgument::Type(Type::Path(arg)) => { + arg.path.is_ident("Self") + } + _ => false, + } + } + _ => false, + } + } => + { + &[InferredBound::Sync, InferredBound::Send] + } + _ => &[InferredBound::Send], + } } - _ => [InferredBound::Send], + _ => &[InferredBound::Send], }; let bounds = bounds.iter().filter_map(|bound| { diff --git a/tests/test.rs b/tests/test.rs index 7c8b751..23d8f80 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1450,3 +1450,14 @@ pub mod issue204 { async fn g(arg: *const impl Trait); } } + +// https://github.com/dtolnay/async-trait/issues/210 +pub mod issue210 { + use async_trait::async_trait; + use std::sync::Arc; + + #[async_trait] + pub trait Trait { + async fn f(self: Arc) {} + } +}