From 6c4a0f06b9475282b13451e1411afa7b6cc757a3 Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Sat, 24 Oct 2020 05:20:52 +0900 Subject: [PATCH] Fix handling of Self keyword in type definition --- serde_derive/Cargo.toml | 2 +- serde_derive/src/de.rs | 8 +- serde_derive/src/internals/mod.rs | 4 + serde_derive/src/internals/receiver.rs | 186 +++++++++++++++++++++++++ serde_derive/src/internals/respan.rs | 22 +++ serde_derive/src/lib.rs | 4 +- serde_derive_internals/Cargo.toml | 2 +- test_suite/tests/test_self.rs | 101 ++++++++++++++ 8 files changed, 323 insertions(+), 6 deletions(-) create mode 100644 serde_derive/src/internals/receiver.rs create mode 100644 serde_derive/src/internals/respan.rs create mode 100644 test_suite/tests/test_self.rs diff --git a/serde_derive/Cargo.toml b/serde_derive/Cargo.toml index 4dfdd92b18..04ac7fd853 100644 --- a/serde_derive/Cargo.toml +++ b/serde_derive/Cargo.toml @@ -22,7 +22,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "1.0.33", features = ["visit"] } +syn = { version = "1.0.33", features = ["visit", "visit-mut"] } [dev-dependencies] serde = { version = "1.0", path = "../serde" } diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 1f5733a6d5..d6684fd64a 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -8,12 +8,16 @@ use bound; use dummy; use fragment::{Expr, Fragment, Match, Stmts}; use internals::ast::{Container, Data, Field, Style, Variant}; -use internals::{attr, ungroup, Ctxt, Derive}; +use internals::{attr, replace_receiver, ungroup, Ctxt, Derive}; use pretend; use std::collections::BTreeSet; -pub fn expand_derive_deserialize(input: &syn::DeriveInput) -> Result> { +pub fn expand_derive_deserialize( + input: &mut syn::DeriveInput, +) -> Result> { + replace_receiver(input); + let ctxt = Ctxt::new(); let cont = match Container::from_ast(&ctxt, input, Derive::Deserialize) { Some(cont) => cont, diff --git a/serde_derive/src/internals/mod.rs b/serde_derive/src/internals/mod.rs index d36b6e45c4..5e9f416c46 100644 --- a/serde_derive/src/internals/mod.rs +++ b/serde_derive/src/internals/mod.rs @@ -4,8 +4,12 @@ pub mod attr; mod ctxt; pub use self::ctxt::Ctxt; +mod receiver; +pub use self::receiver::replace_receiver; + mod case; mod check; +mod respan; mod symbol; use syn::Type; diff --git a/serde_derive/src/internals/receiver.rs b/serde_derive/src/internals/receiver.rs new file mode 100644 index 0000000000..0a7e4b789c --- /dev/null +++ b/serde_derive/src/internals/receiver.rs @@ -0,0 +1,186 @@ +use super::respan::respan; +use proc_macro2::{Group, Spacing, Span, TokenStream, TokenTree}; +use quote::{quote, quote_spanned}; +use std::{iter::FromIterator, mem}; +use syn::{ + parse_quote, + punctuated::Punctuated, + visit_mut::{self, VisitMut}, + DeriveInput, ExprPath, Macro, Path, PathArguments, QSelf, Type, TypePath, +}; + +pub fn replace_receiver(input: &mut DeriveInput) { + let self_ty = { + let ident = &input.ident; + let ty_generics = input.generics.split_for_impl().1; + parse_quote!(#ident #ty_generics) + }; + let mut visitor = ReplaceReceiver(&self_ty); + visitor.visit_generics_mut(&mut input.generics); + visitor.visit_data_mut(&mut input.data); +} + +struct ReplaceReceiver<'a>(&'a TypePath); + +impl ReplaceReceiver<'_> { + fn self_ty(&self, span: Span) -> TypePath { + respan(self.0, span) + } + + fn self_to_qself(&self, qself: &mut Option, path: &mut Path) { + if path.leading_colon.is_some() { + return; + } + + // Make borrow checker happy + { + let first = &path.segments[0]; + if first.ident != "Self" || !first.arguments.is_empty() { + return; + } + } + + if path.segments.len() == 1 { + self.self_to_expr_path(path); + return; + } + + let span = path.segments[0].ident.span(); + *qself = Some(QSelf { + lt_token: Token![<](span), + ty: Box::new(self.self_ty(span).into()), + position: 0, + as_token: None, + gt_token: Token![>](span), + }); + + path.leading_colon = Some(**path.segments.pairs().next().unwrap().punct().unwrap()); + + let segments = mem::replace(&mut path.segments, Punctuated::new()); + path.segments = segments.into_pairs().skip(1).collect(); + } + + fn self_to_expr_path(&self, path: &mut Path) { + if path.leading_colon.is_some() { + return; + } + + // Make borrow checker happy + { + let first = &path.segments[0]; + if first.ident != "Self" || !first.arguments.is_empty() { + return; + } + } + + let self_ty = self.self_ty(path.segments[0].ident.span()); + let variant = mem::replace(path, self_ty.path); + for segment in &mut path.segments { + if let PathArguments::AngleBracketed(bracketed) = &mut segment.arguments { + if bracketed.colon2_token.is_none() && !bracketed.args.is_empty() { + bracketed.colon2_token = Some(Default::default()); + } + } + } + if variant.segments.len() > 1 { + path.segments.push_punct(Default::default()); + path.segments.extend(variant.segments.into_pairs().skip(1)); + } + } + + fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool { + let mut out = Vec::new(); + let mut modified = false; + let mut iter = tokens.clone().into_iter().peekable(); + while let Some(tt) = iter.next() { + match tt { + TokenTree::Ident(ident) => { + if ident == "Self" { + modified = true; + let self_ty = self.self_ty(ident.span()); + match iter.peek() { + Some(TokenTree::Punct(p)) + if p.as_char() == ':' && p.spacing() == Spacing::Joint => + { + let next = iter.next().unwrap(); + match iter.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == ':' => { + let span = ident.span(); + out.extend(quote_spanned!(span=> <#self_ty>)); + } + _ => out.extend(quote!(#self_ty)), + } + out.push(next); + } + _ => out.extend(quote!(#self_ty)), + } + } else { + out.push(TokenTree::Ident(ident)); + } + } + TokenTree::Group(group) => { + let mut content = group.stream(); + modified |= self.visit_token_stream(&mut content); + let mut new = Group::new(group.delimiter(), content); + new.set_span(group.span()); + out.push(TokenTree::Group(new)); + } + other => out.push(other), + } + } + if modified { + *tokens = TokenStream::from_iter(out); + } + modified + } +} + +impl VisitMut for ReplaceReceiver<'_> { + // `Self` -> `Receiver` + fn visit_type_mut(&mut self, ty: &mut Type) { + if let Type::Path(node) = ty { + if node.qself.is_none() && node.path.is_ident("Self") { + *ty = self.self_ty(node.path.segments[0].ident.span()).into(); + } else { + self.visit_type_path_mut(node); + } + } else { + visit_mut::visit_type_mut(self, ty); + } + } + + // `Self::Assoc` -> `::Assoc` + fn visit_type_path_mut(&mut self, ty: &mut TypePath) { + if ty.qself.is_none() { + self.self_to_qself(&mut ty.qself, &mut ty.path); + } + visit_mut::visit_type_path_mut(self, ty); + } + + // `Self::method` -> `::method` + fn visit_expr_path_mut(&mut self, expr: &mut ExprPath) { + if expr.qself.is_none() { + self.self_to_qself(&mut expr.qself, &mut expr.path); + } + visit_mut::visit_expr_path_mut(self, expr); + } + + fn visit_macro_mut(&mut self, mac: &mut Macro) { + // We can't tell in general whether `self` inside a macro invocation + // refers to the self in the argument list or a different self + // introduced within the macro. Heuristic: if the macro input contains + // `fn`, then `self` is more likely to refer to something other than the + // outer function's self argument. + if !contains_fn(mac.tokens.clone()) { + self.visit_token_stream(&mut mac.tokens); + } + } +} + +fn contains_fn(tokens: TokenStream) -> bool { + tokens.into_iter().any(|tt| match tt { + TokenTree::Ident(ident) => ident == "fn", + TokenTree::Group(group) => contains_fn(group.stream()), + _ => false, + }) +} diff --git a/serde_derive/src/internals/respan.rs b/serde_derive/src/internals/respan.rs new file mode 100644 index 0000000000..38f6612c41 --- /dev/null +++ b/serde_derive/src/internals/respan.rs @@ -0,0 +1,22 @@ +use proc_macro2::{Span, TokenStream}; +use quote::ToTokens; +use syn::parse::Parse; + +pub(crate) fn respan(node: &T, span: Span) -> T +where + T: ToTokens + Parse, +{ + let tokens = node.to_token_stream(); + let respanned = respan_tokens(tokens, span); + syn::parse2(respanned).unwrap() +} + +fn respan_tokens(tokens: TokenStream, span: Span) -> TokenStream { + tokens + .into_iter() + .map(|mut token| { + token.set_span(span); + token + }) + .collect() +} diff --git a/serde_derive/src/lib.rs b/serde_derive/src/lib.rs index 1711340bad..abe7ca9997 100644 --- a/serde_derive/src/lib.rs +++ b/serde_derive/src/lib.rs @@ -86,8 +86,8 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { #[proc_macro_derive(Deserialize, attributes(serde))] pub fn derive_deserialize(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - de::expand_derive_deserialize(&input) + let mut input = parse_macro_input!(input as DeriveInput); + de::expand_derive_deserialize(&mut input) .unwrap_or_else(to_compile_errors) .into() } diff --git a/serde_derive_internals/Cargo.toml b/serde_derive_internals/Cargo.toml index 64ee1c6297..6e92145587 100644 --- a/serde_derive_internals/Cargo.toml +++ b/serde_derive_internals/Cargo.toml @@ -16,7 +16,7 @@ path = "lib.rs" [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "1.0.33", default-features = false, features = ["derive", "parsing", "printing", "clone-impls"] } +syn = { version = "1.0.33", default-features = false, features = ["derive", "parsing", "printing", "clone-impls", "visit-mut"] } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] diff --git a/test_suite/tests/test_self.rs b/test_suite/tests/test_self.rs new file mode 100644 index 0000000000..d2241749c2 --- /dev/null +++ b/test_suite/tests/test_self.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; + +#[test] +fn test_self() { + macro_rules! mac { + ($($tt:tt)*) => { + $($tt)* + }; + } + + pub trait Trait { + type Assoc; + } + + #[derive(Deserialize, Serialize)] + pub struct Generics> + where + Self: Trait, + ::Assoc: Sized, + mac!(Self): Trait, + { + _f: T, + } + + impl> Trait for Generics { + type Assoc = Self; + } + + #[derive(Deserialize, Serialize)] + pub struct Struct { + _f1: Box, + _f2: Box<::Assoc>, + _f3: Box, + _f4: [(); Self::ASSOC], + _f5: [(); Self::assoc()], + _f6: [(); mac!(Self::assoc())], + } + + impl Struct { + const ASSOC: usize = 1; + const fn assoc() -> usize { + 0 + } + } + + impl Trait for Struct { + type Assoc = Self; + } + + #[derive(Deserialize, Serialize)] + struct Tuple( + Box, + Box<::Assoc>, + Box, + [(); Self::ASSOC], + [(); Self::assoc()], + [(); mac!(Self::assoc())], + ); + + impl Tuple { + const ASSOC: usize = 1; + const fn assoc() -> usize { + 0 + } + } + + impl Trait for Tuple { + type Assoc = Self; + } + + #[derive(Deserialize, Serialize)] + enum Enum { + Struct { + _f1: Box, + _f2: Box<::Assoc>, + _f3: Box, + _f4: [(); Self::ASSOC], + _f5: [(); Self::assoc()], + _f6: [(); mac!(Self::assoc())], + }, + Tuple( + Box, + Box<::Assoc>, + Box, + [(); Self::ASSOC], + [(); Self::assoc()], + [(); mac!(Self::assoc())], + ), + } + + impl Enum { + const ASSOC: usize = 1; + const fn assoc() -> usize { + 0 + } + } + + impl Trait for Enum { + type Assoc = Self; + } +}