From fd5b9d43d063318b757db9d3895c9bed8db91b21 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Mon, 28 Nov 2022 23:14:26 -0800 Subject: [PATCH] Create enum for the different possible inferred bounds --- src/bound.rs | 46 ++++++++++++++++++++++++++++++++++++++++++++++ src/expand.rs | 29 ++++++----------------------- src/lib.rs | 1 + 3 files changed, 53 insertions(+), 23 deletions(-) create mode 100644 src/bound.rs diff --git a/src/bound.rs b/src/bound.rs new file mode 100644 index 0000000..f0a79c4 --- /dev/null +++ b/src/bound.rs @@ -0,0 +1,46 @@ +use proc_macro2::{Ident, Span}; +use syn::punctuated::Punctuated; +use syn::{Token, TypeParamBound}; + +pub type Supertraits = Punctuated; + +pub enum InferredBound { + Send, + Sync, +} + +pub fn has_bound(supertraits: &Supertraits, bound: &InferredBound) -> bool { + for supertrait in supertraits { + if let TypeParamBound::Trait(supertrait) = supertrait { + if supertrait.path.is_ident(bound) + || supertrait.path.segments.len() == 3 + && (supertrait.path.segments[0].ident == "std" + || supertrait.path.segments[0].ident == "core") + && supertrait.path.segments[1].ident == "marker" + && supertrait.path.segments[2].ident == *bound + { + return true; + } + } + } + false +} + +impl InferredBound { + fn as_str(&self) -> &str { + match self { + InferredBound::Send => "Send", + InferredBound::Sync => "Sync", + } + } + + pub fn spanned_ident(&self, span: Span) -> Ident { + Ident::new(self.as_str(), span) + } +} + +impl PartialEq for Ident { + fn eq(&self, bound: &InferredBound) -> bool { + self == bound.as_str() + } +} diff --git a/src/expand.rs b/src/expand.rs index b0dea84..ce99c6f 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -1,3 +1,4 @@ +use crate::bound::{has_bound, InferredBound, Supertraits}; use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes}; use crate::parse::Item; use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf}; @@ -10,7 +11,7 @@ use syn::visit_mut::{self, VisitMut}; use syn::{ parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, LifetimeDef, Pat, PatIdent, Receiver, ReturnType, Signature, Stmt, Token, - TraitItem, Type, TypeParamBound, TypePath, WhereClause, + TraitItem, Type, TypePath, WhereClause, }; impl ToTokens for Item { @@ -51,8 +52,6 @@ impl Context<'_> { } } -type Supertraits = Punctuated; - pub fn expand(input: &mut Item, is_local: bool) { match input { Item::Trait(input) => { @@ -235,7 +234,7 @@ fn transform_sig( reference: Some(_), mutability: None, .. - })) => Ident::new("Sync", default_span), + })) => InferredBound::Sync, Some(FnArg::Typed(arg)) if match (arg.pat.as_ref(), arg.ty.as_ref()) { (Pat::Ident(pat), Type::Reference(ty)) => { @@ -244,9 +243,9 @@ fn transform_sig( _ => false, } => { - Ident::new("Sync", default_span) + InferredBound::Sync } - _ => Ident::new("Send", default_span), + _ => InferredBound::Send, }; let assume_bound = match context { @@ -258,6 +257,7 @@ fn transform_sig( where_clause.predicates.push(if assume_bound || is_local { parse_quote_spanned!(default_span=> Self: 'async_trait) } else { + let bound = bound.spanned_ident(default_span); parse_quote_spanned!(default_span=> Self: ::core::marker::#bound + 'async_trait) }); } @@ -402,23 +402,6 @@ fn positional_arg(i: usize, pat: &Pat) -> Ident { format_ident!("__arg{}", i, span = span) } -fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool { - for bound in supertraits { - if let TypeParamBound::Trait(bound) = bound { - if bound.path.is_ident(marker) - || bound.path.segments.len() == 3 - && (bound.path.segments[0].ident == "std" - || bound.path.segments[0].ident == "core") - && bound.path.segments[1].ident == "marker" - && bound.path.segments[2].ident == *marker - { - return true; - } - } - } - false -} - fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool { struct AssociatedTypeImplTraits<'a> { set: &'a Set, diff --git a/src/lib.rs b/src/lib.rs index 2f8ffcb..8ca5032 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -318,6 +318,7 @@ extern crate proc_macro; mod args; +mod bound; mod expand; mod lifetime; mod parse;