From b35f32f39bb9d9dd3dc1c256be4f6fd858143c31 Mon Sep 17 00:00:00 2001 From: Steve Fan <29133953+stevefan1999-personal@users.noreply.github.com> Date: Mon, 15 Jan 2024 22:17:36 +0800 Subject: [PATCH] try to add enum explicit variant support once again --- serde/src/private/ser.rs | 31 +- serde_derive/src/de.rs | 428 ++++++++++++++++++++++----- serde_derive/src/internals/attr.rs | 54 +++- serde_derive/src/internals/symbol.rs | 13 +- serde_derive/src/ser.rs | 165 +++++++++-- test_suite/tests/test_annotations.rs | 176 +++++++++++ 6 files changed, 757 insertions(+), 110 deletions(-) diff --git a/serde/src/private/ser.rs b/serde/src/private/ser.rs index 50bcb251e..1fa55ec5f 100644 --- a/serde/src/private/ser.rs +++ b/serde/src/private/ser.rs @@ -14,17 +14,18 @@ pub fn constrain(t: &T) -> &T { } /// Not public API. -pub fn serialize_tagged_newtype( +pub fn serialize_tagged_newtype( serializer: S, type_ident: &'static str, variant_ident: &'static str, tag: &'static str, - variant_name: &'static str, + variant_name: I, value: &T, ) -> Result where S: Serializer, T: Serialize, + I: Serialize, { value.serialize(TaggedSerializer { type_ident, @@ -35,11 +36,11 @@ where }) } -struct TaggedSerializer { +struct TaggedSerializer { type_ident: &'static str, variant_ident: &'static str, tag: &'static str, - variant_name: &'static str, + variant_name: I, delegate: S, } @@ -79,9 +80,10 @@ impl Display for Unsupported { } } -impl TaggedSerializer +impl TaggedSerializer where S: Serializer, + I: Serialize, { fn bad_type(self, what: Unsupported) -> S::Error { ser::Error::custom(format_args!( @@ -91,9 +93,10 @@ where } } -impl Serializer for TaggedSerializer +impl Serializer for TaggedSerializer where S: Serializer, + I: Serialize, { type Ok = S::Ok; type Error = S::Error; @@ -183,13 +186,13 @@ where fn serialize_unit(self) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(1))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); map.end() } fn serialize_unit_struct(self, _: &'static str) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(1))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); map.end() } @@ -200,7 +203,7 @@ where inner_variant: &'static str, ) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_entry(inner_variant, &())); map.end() } @@ -227,7 +230,7 @@ where T: Serialize, { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_entry(inner_variant, inner_value)); map.end() } @@ -270,7 +273,7 @@ where len: usize, ) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_key(inner_variant)); Ok(SerializeTupleVariantAsMapValue::new( map, @@ -281,7 +284,7 @@ where fn serialize_map(self, len: Option) -> Result { let mut map = tri!(self.delegate.serialize_map(len.map(|len| len + 1))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); Ok(map) } @@ -291,7 +294,7 @@ where len: usize, ) -> Result { let mut state = tri!(self.delegate.serialize_struct(name, len + 1)); - tri!(state.serialize_field(self.tag, self.variant_name)); + tri!(state.serialize_field(self.tag, &self.variant_name)); Ok(state) } @@ -317,7 +320,7 @@ where len: usize, ) -> Result { let mut map = tri!(self.delegate.serialize_map(Some(2))); - tri!(map.serialize_entry(self.tag, self.variant_name)); + tri!(map.serialize_entry(self.tag, &self.variant_name)); tri!(map.serialize_key(inner_variant)); Ok(SerializeStructVariantAsMapValue::new( map, diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index e3b737c61..ea20f52b2 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -3,6 +3,7 @@ use crate::internals::ast::{Container, Data, Field, Style, Variant}; use crate::internals::{attr, replace_receiver, ungroup, Ctxt, Derive}; use crate::{bound, dummy, pretend, this}; use proc_macro2::{Literal, Span, TokenStream}; +use quote::format_ident; use quote::{quote, quote_spanned, ToTokens}; use std::collections::BTreeSet; use std::ptr; @@ -1214,7 +1215,7 @@ fn prepare_enum_variant_enum( variants: &[Variant], cattrs: &attr::Container, ) -> (TokenStream, Stmts) { - let mut deserialized_variants = variants + let deserialized_variants = variants .iter() .enumerate() .filter(|&(_, variant)| !variant.attrs.skip_deserializing()); @@ -1231,6 +1232,7 @@ fn prepare_enum_variant_enum( .collect(); let fallthrough = deserialized_variants + .clone() .position(|(_, variant)| variant.attrs.other()) .map(|other_idx| { let ignore_variant = variant_names_idents[other_idx].1.clone(); @@ -1245,12 +1247,62 @@ fn prepare_enum_variant_enum( } }; + let repr = match (cattrs.use_repr(), cattrs.repr_type().map(|ty| ty.clone())) { + (true, Some(repr_type)) => { + let mut discriminants = Vec::new(); + let mut last_discriminant = None; + let mut iterations_without_discriminant = 0; + for (_, variant) in deserialized_variants.clone() { + let discriminant = variant.original.discriminant.as_ref().map(|(_, d)| d); + let discriminant = if let Some(expr) = discriminant { + last_discriminant = Some(expr); + parse_quote!(#expr) + } else if let Some(expr) = last_discriminant { + match expr { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit_int), + .. + }) => { + let value = lit_int.base10_parse::().unwrap(); + let value = value + iterations_without_discriminant; + let value = syn::Lit::Int(syn::LitInt::new( + &value.to_string(), + Span::call_site(), + )); + parse_quote!(#value) + } + _ => { + let iterations_without_discriminant = syn::Lit::Int(syn::LitInt::new( + &iterations_without_discriminant.to_string(), + Span::call_site(), + )); + parse_quote!(#expr + #iterations_without_discriminant) + } + } + } else { + let iterations_without_discriminant = syn::Lit::Int(syn::LitInt::new( + &iterations_without_discriminant.to_string(), + Span::call_site(), + )); + parse_quote!(#iterations_without_discriminant) + }; + + discriminants.push(discriminant); + iterations_without_discriminant += 1; + } + + Some((discriminants, repr_type)) + } + _ => None, + }; + let variant_visitor = Stmts(deserialize_generated_identifier( &variant_names_idents, cattrs, true, None, fallthrough, + repr, )); (variants_stmt, variant_visitor) @@ -1989,6 +2041,7 @@ fn deserialize_generated_identifier( is_variant: bool, ignore_variant: Option, fallthrough: Option, + repr: Option<(Vec, syn::Type)>, ) -> Fragment { let this_value = quote!(__Field); let field_idents: &Vec<_> = &fields.iter().map(|(_, ident, _)| ident).collect(); @@ -2001,6 +2054,7 @@ fn deserialize_generated_identifier( None, !is_variant && cattrs.has_flatten(), None, + repr, )); let lifetime = if !is_variant && cattrs.has_flatten() { @@ -2062,6 +2116,7 @@ fn deserialize_field_identifier( false, ignore_variant, fallthrough, + None, )) } @@ -2144,6 +2199,53 @@ fn deserialize_custom_identifier( Some(fields) }; + let repr = match (cattrs.use_repr(), cattrs.repr_type().map(|ty| ty.clone())) { + (true, Some(repr_type)) => { + let mut discriminants = Vec::new(); + + let mut last_discriminant = None; + let mut iterations_without_discriminant = 0; + + for variant in ordinary { + let discriminant = variant.original.discriminant.as_ref().map(|(_, d)| d); + let discriminant = if let Some(expr) = discriminant { + last_discriminant = Some(expr); + parse_quote!(#expr) + } else if let Some(expr) = last_discriminant { + match expr { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit_int), + .. + }) => { + let value = lit_int.base10_parse::().unwrap(); + let value = value + iterations_without_discriminant; + parse_quote!(#value) + } + _ => { + let iterations_without_discriminant = syn::Lit::Int(syn::LitInt::new( + &iterations_without_discriminant.to_string(), + Span::call_site(), + )); + parse_quote!(#expr + #iterations_without_discriminant) + } + } + } else { + let iterations_without_discriminant = syn::Lit::Int(syn::LitInt::new( + &iterations_without_discriminant.to_string(), + Span::call_site(), + )); + parse_quote!(#iterations_without_discriminant) + }; + + discriminants.push(discriminant); + iterations_without_discriminant += 1; + } + + Some((discriminants, repr_type)) + } + _ => None, + }; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = split_with_de_lifetime(params); let delife = params.borrowed.de_lifetime(); @@ -2155,6 +2257,7 @@ fn deserialize_custom_identifier( fallthrough_borrowed, false, cattrs.expecting(), + repr, )); quote_block! { @@ -2180,6 +2283,34 @@ fn deserialize_custom_identifier( } } +fn filter_repr_by_type( + repr: &Option<(Vec, syn::Type)>, + ty_symbol: &str, +) -> Option> { + if let Some((discriminants, repr_type)) = repr { + match repr_type { + syn::Type::Path(syn::TypePath { + path: syn::Path { segments, .. }, + .. + }) => { + if segments.len() == 1 { + let segment = &segments[0]; + if segment.ident.to_string() == ty_symbol { + Some(discriminants.clone()) + } else { + None + } + } else { + None + } + } + _ => None, + } + } else { + None + } +} + fn deserialize_identifier( this_value: &TokenStream, fields: &[(&str, Ident, &BTreeSet)], @@ -2188,6 +2319,7 @@ fn deserialize_identifier( fallthrough_borrowed: Option, collect_other_fields: bool, expecting: Option<&str>, + repr: Option<(Vec, syn::Type)>, ) -> Fragment { let str_mapping = fields.iter().map(|(_, ident, aliases)| { // `aliases` also contains a main name @@ -2201,12 +2333,30 @@ fn deserialize_identifier( quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident)) }); + let discriminant_identifiers: Vec<_> = fields + .iter() + .enumerate() + .map(|(i, _)| Ident::new(&format!("__DISCRIMINANT_{}", i), Span::call_site())) + .collect(); + + let constructors: &Vec<_> = &fields + .iter() + .map(|(_, ident, _)| quote!(#this_value::#ident)) + .collect(); + + let main_constructors: Vec<_> = fields + .iter() + .map(|(_, ident, _)| quote!(#this_value::#ident)) + .collect(); + let expecting = expecting.unwrap_or(if is_variant { "variant identifier" } else { "field identifier" }); + let index_expecting = if is_variant { "variant" } else { "field" }; + let bytes_to_str = if fallthrough.is_some() || collect_other_fields { None } else { @@ -2240,7 +2390,7 @@ fn deserialize_identifier( }; let fallthrough_arm_tokens; - let fallthrough_arm = if let Some(fallthrough) = &fallthrough { + let fallthrough_arm: &TokenStream = if let Some(fallthrough) = &fallthrough { fallthrough } else if is_variant { fallthrough_arm_tokens = quote! { @@ -2254,69 +2404,232 @@ fn deserialize_identifier( &fallthrough_arm_tokens }; - let visit_other = if collect_other_fields { - quote! { - fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Bool(__value))) - } + let field_str: Option> = if let Some((discriminants, _)) = &repr { + discriminants + .iter() + .map(|expr| match expr { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit_int), + .. + }) => { + let value = lit_int.base10_parse::().unwrap(); + let value = value.to_string(); + Some(syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new(&value, Span::call_site())), + })) + } + _ => None, + }) + .collect() + } else { + Some( + fields + .iter() + .map(|(name, _, _)| { + syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new(*name, Span::call_site())), + }) + }) + .collect(), + ) + }; + let field_byte: Option> = if let Some((discriminants, _)) = &repr { + discriminants + .iter() + .map(|expr| match expr { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit_int), + .. + }) => { + let value = lit_int.base10_parse::().unwrap(); + let value = value.to_string(); + Some(syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::ByteStr(syn::LitByteStr::new( + value.as_bytes(), + Span::call_site(), + )), + })) + } + _ => None, + }) + .collect() + } else { + Some( + fields + .iter() + .map(|(name, _, _)| { + syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::ByteStr(syn::LitByteStr::new( + name.as_bytes(), + Span::call_site(), + )), + }) + }) + .collect(), + ) + }; + let field_strs: &Vec<_> = &fields.iter().map(|(name, _, _)| name).collect(); + let field_bytes: &Vec<_> = &fields + .iter() + .map(|(name, _, _)| Literal::byte_string(name.as_bytes())) + .collect(); - fn visit_i8<__E>(self, __value: i8) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I8(__value))) + let visit_generator = |ty_str, content, signed| { + let ty = format_ident!("{ty_str}"); + let visit_name = format_ident!("visit_{ty}"); + let content = format_ident!("{content}"); + let value = if signed { + quote! { + _serde::de::Unexpected::Signed(__value as i64) } - - fn visit_i16<__E>(self, __value: i16) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I16(__value))) + } else { + quote! { + _serde::de::Unexpected::Unsigned(__value as u64) } + }; - fn visit_i32<__E>(self, __value: i32) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I32(__value))) + if let Some(fields) = filter_repr_by_type(&repr, ty_str) { + let fallthrough_arm_tokens; + let fallthrough = if let Some(fallthrough) = &fallthrough { + fallthrough + } else { + let fallthrough_msg = format!("{} discriminant", index_expecting); + fallthrough_arm_tokens = quote! { + _serde::__private::Err(_serde::de::Error::invalid_value( + #value, + &#fallthrough_msg, + )) + }; + &fallthrough_arm_tokens + }; + + quote! { + fn #visit_name<__E>(self, __value: #ty) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + struct __Discriminants; + impl __Discriminants { + #(const #discriminant_identifiers: #ty = #fields;)* + } + + match __value { + #( + __Discriminants::#discriminant_identifiers => _serde::__private::Ok(#main_constructors), + )* + _ => #fallthrough, + } + } } + } else if collect_other_fields { + quote! { + fn #visit_name<__E>(self, __value: #ty) -> _serde::__private::Result + where + __E: _serde::de::Error, + { + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::#content(__value))) + } + } + } else { + quote! {} + } + }; + + let visit_i8 = visit_generator("i8", "I8", true); + let visit_i16 = visit_generator("i16", "I16", true); + let visit_i32 = visit_generator("i32", "I32", true); + let visit_i64 = visit_generator("i64", "I64", true); + + let visit_u8 = visit_generator("u8", "U8", false); + let visit_u16 = visit_generator("u16", "U16", false); + let visit_u32 = visit_generator("u32", "U32", false); + let visit_u64 = visit_generator("u64", "U64", false); - fn visit_i64<__E>(self, __value: i64) -> _serde::__private::Result + let visit_str = if let Some(field_str) = &field_str { + quote! { + fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I64(__value))) + match __value { + #( + #field_str => _serde::__private::Ok(#constructors), + )* + _ => { + #value_as_str_content + #fallthrough_arm + } + } } - - fn visit_u8<__E>(self, __value: u8) -> _serde::__private::Result + } + } else { + quote! { + fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U8(__value))) + match __value { + #( + #field_strs => _serde::__private::Ok(#constructors), + )* + _ => { + #value_as_str_content + #fallthrough_arm + } + } } + } + }; - fn visit_u16<__E>(self, __value: u16) -> _serde::__private::Result + let visit_bytes = if let Some(field_byte) = &field_byte { + quote! { + fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U16(__value))) + match __value { + #( + #field_byte => _serde::__private::Ok(#constructors), + )* + _ => { + #bytes_to_str + #value_as_bytes_content + #fallthrough_arm + } + } } - - fn visit_u32<__E>(self, __value: u32) -> _serde::__private::Result + } + } else { + quote! { + fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U32(__value))) + match __value { + #( + #field_bytes => _serde::__private::Ok(#constructors), + )* + _ => { + #bytes_to_str + #value_as_bytes_content + #fallthrough_arm + } + } } + } + }; - fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result + let visit_other = if collect_other_fields { + quote! { + fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result where __E: _serde::de::Error, { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U64(__value))) + _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Bool(__value))) } fn visit_f32<__E>(self, __value: f32) -> _serde::__private::Result @@ -2357,7 +2670,6 @@ fn deserialize_identifier( let u64_fallthrough_arm = if let Some(fallthrough) = &fallthrough { fallthrough } else { - let index_expecting = if is_variant { "variant" } else { "field" }; let fallthrough_msg = format!("{} index 0 <= i < {}", index_expecting, fields.len()); u64_fallthrough_arm_tokens = quote! { _serde::__private::Err(_serde::de::Error::invalid_value( @@ -2422,35 +2734,17 @@ fn deserialize_identifier( _serde::__private::Formatter::write_str(__formatter, #expecting) } + #visit_i8 + #visit_u8 + #visit_i16 + #visit_u16 + #visit_i32 + #visit_u32 + #visit_i64 + #visit_u64 #visit_other - - fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#str_mapping,)* - _ => { - #value_as_str_content - #fallthrough_arm - } - } - } - - fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#bytes_mapping,)* - _ => { - #bytes_to_str - #value_as_bytes_content - #fallthrough_arm - } - } - } - + #visit_str + #visit_bytes #visit_borrowed } } diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index bb9de328a..f7e4b733c 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -8,7 +8,7 @@ use std::iter::FromIterator; use syn::meta::ParseNestedMeta; use syn::parse::ParseStream; use syn::punctuated::Punctuated; -use syn::{parse_quote, token, Ident, Lifetime, Token}; +use syn::{parse_quote, token, Ident, Lifetime, Meta, Token}; // This module handles parsing of `#[serde(...)]` attributes. The entrypoints // are `attr::Container::from_ast`, `attr::Variant::from_ast`, and @@ -219,6 +219,8 @@ pub struct Container { has_flatten: bool, serde_path: Option, is_packed: bool, + repr_type: Option, + use_repr: bool, /// Error message generated when type can't be deserialized expecting: Option, non_exhaustive: bool, @@ -307,6 +309,7 @@ impl Container { let mut variant_identifier = BoolAttr::none(cx, VARIANT_IDENTIFIER); let mut serde_path = Attr::none(cx, CRATE); let mut expecting = Attr::none(cx, EXPECTING); + let mut use_repr = BoolAttr::none(cx, USE_REPR); let mut non_exhaustive = false; for attr in &item.attrs { @@ -530,6 +533,9 @@ impl Container { } else if meta.path == VARIANT_IDENTIFIER { // #[serde(variant_identifier)] variant_identifier.set_true(&meta.path); + } else if meta.path == USE_REPR { + // #[serde(use_repr)] + use_repr.set_true(&meta.path); } else if meta.path == CRATE { // #[serde(crate = "foo")] if let Some(path) = parse_lit_into_path(cx, CRATE, &meta)? { @@ -566,6 +572,14 @@ impl Container { } } + let repr_type = item + .attrs + .iter() + .flat_map(|attr| get_repr_type(cx, attr)) + .filter(Option::is_some) + .map(Option::unwrap) + .next(); + Container { name: Name::from_attrs(unraw(&item.ident), ser_name, de_name, None), transparent: transparent.get(), @@ -590,6 +604,8 @@ impl Container { has_flatten: false, serde_path: serde_path.get(), is_packed, + repr_type, + use_repr: use_repr.get(), expecting: expecting.get(), non_exhaustive, } @@ -672,6 +688,14 @@ impl Container { .map_or_else(|| Cow::Owned(parse_quote!(_serde)), Cow::Borrowed) } + pub fn repr_type(&self) -> Option<&syn::Type> { + self.repr_type.as_ref() + } + + pub fn use_repr(&self) -> bool { + self.use_repr + } + /// Error message generated when type can't be deserialized. /// If `None`, default message will be used pub fn expecting(&self) -> Option<&str> { @@ -1460,6 +1484,34 @@ fn get_where_predicates( Ok((ser.at_most_one(), de.at_most_one())) } +pub fn get_repr_type(cx: &Ctxt, attr: &syn::Attribute) -> Result, ()> { + fn get_repr_type_inner(attr: &syn::Attribute) -> syn::Result> { + if attr.path() == REPR { + let nested = attr.parse_args_with(Punctuated::::parse_terminated)?; + + for meta in nested { + if let Meta::Path(path) = meta { + if let Ok(ty_) = syn::parse::(path.to_token_stream().into()) { + if matches!( + ty_.to_string().as_str(), + "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" + ) { + return Ok(Some(syn::parse(ty_.to_token_stream().into())?)); + } + } + } + } + } + + Ok(None) + } + + get_repr_type_inner(attr).map_err(|err| { + cx.syn_error(err); + () + }) +} + fn get_lit_str( cx: &Ctxt, attr_name: Symbol, diff --git a/serde_derive/src/internals/symbol.rs b/serde_derive/src/internals/symbol.rs index 572391a80..8ce013f8d 100644 --- a/serde_derive/src/internals/symbol.rs +++ b/serde_derive/src/internals/symbol.rs @@ -11,8 +11,8 @@ pub const CONTENT: Symbol = Symbol("content"); pub const CRATE: Symbol = Symbol("crate"); pub const DEFAULT: Symbol = Symbol("default"); pub const DENY_UNKNOWN_FIELDS: Symbol = Symbol("deny_unknown_fields"); -pub const DESERIALIZE: Symbol = Symbol("deserialize"); pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with"); +pub const DESERIALIZE: Symbol = Symbol("deserialize"); pub const EXPECTING: Symbol = Symbol("expecting"); pub const FIELD_IDENTIFIER: Symbol = Symbol("field_identifier"); pub const FLATTEN: Symbol = Symbol("flatten"); @@ -22,21 +22,22 @@ pub const INTO: Symbol = Symbol("into"); pub const NON_EXHAUSTIVE: Symbol = Symbol("non_exhaustive"); pub const OTHER: Symbol = Symbol("other"); pub const REMOTE: Symbol = Symbol("remote"); -pub const RENAME: Symbol = Symbol("rename"); -pub const RENAME_ALL: Symbol = Symbol("rename_all"); pub const RENAME_ALL_FIELDS: Symbol = Symbol("rename_all_fields"); +pub const RENAME_ALL: Symbol = Symbol("rename_all"); +pub const RENAME: Symbol = Symbol("rename"); pub const REPR: Symbol = Symbol("repr"); pub const SERDE: Symbol = Symbol("serde"); -pub const SERIALIZE: Symbol = Symbol("serialize"); pub const SERIALIZE_WITH: Symbol = Symbol("serialize_with"); -pub const SKIP: Symbol = Symbol("skip"); +pub const SERIALIZE: Symbol = Symbol("serialize"); pub const SKIP_DESERIALIZING: Symbol = Symbol("skip_deserializing"); -pub const SKIP_SERIALIZING: Symbol = Symbol("skip_serializing"); pub const SKIP_SERIALIZING_IF: Symbol = Symbol("skip_serializing_if"); +pub const SKIP_SERIALIZING: Symbol = Symbol("skip_serializing"); +pub const SKIP: Symbol = Symbol("skip"); pub const TAG: Symbol = Symbol("tag"); pub const TRANSPARENT: Symbol = Symbol("transparent"); pub const TRY_FROM: Symbol = Symbol("try_from"); pub const UNTAGGED: Symbol = Symbol("untagged"); +pub const USE_REPR: Symbol = Symbol("use_repr"); pub const VARIANT_IDENTIFIER: Symbol = Symbol("variant_identifier"); pub const WITH: Symbol = Symbol("with"); diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 3be51ee52..9f2124b36 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -3,9 +3,10 @@ use crate::internals::ast::{Container, Data, Field, Style, Variant}; use crate::internals::{attr, replace_receiver, Ctxt, Derive}; use crate::{bound, dummy, pretend, this}; use proc_macro2::{Span, TokenStream}; +use quote::ToTokens; use quote::{quote, quote_spanned}; use syn::spanned::Spanned; -use syn::{parse_quote, Ident, Index, Member}; +use syn::{parse_quote, Ident, Index, Lit, LitInt, LitStr, Member}; pub fn expand_derive_serialize(input: &mut syn::DeriveInput) -> syn::Result { replace_receiver(input); @@ -401,13 +402,50 @@ fn serialize_enum(params: &Parameters, variants: &[Variant], cattrs: &attr::Cont let self_var = ¶ms.self_var; - let mut arms: Vec<_> = variants - .iter() - .enumerate() - .map(|(variant_index, variant)| { - serialize_variant(params, variant, variant_index as u32, cattrs) - }) - .collect(); + let mut arms = Vec::new(); + let mut last_discriminant = None; + let mut iterations_without_discriminant = 0; + for (variant_index, variant) in variants.iter().enumerate() { + let discriminant = variant.original.discriminant.as_ref().map(|(_, d)| d); + let discriminant = if let Some(expr) = discriminant { + last_discriminant = Some(expr); + parse_quote!(#expr) + } else if let Some(expr) = last_discriminant { + match expr { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit_int), + .. + }) => { + let value = lit_int.base10_parse::().unwrap(); + let value = value + iterations_without_discriminant; + parse_quote!(#value) + } + _ => { + let iterations_without_discriminant = Lit::Int(LitInt::new( + &iterations_without_discriminant.to_string(), + Span::call_site(), + )); + parse_quote!(#expr + #iterations_without_discriminant) + } + } + } else { + let iterations_without_discriminant = Lit::Int(LitInt::new( + &iterations_without_discriminant.to_string(), + Span::call_site(), + )); + parse_quote!(#iterations_without_discriminant) + }; + + arms.push(serialize_variant( + params, + variant, + variant_index as u32, + &discriminant, + cattrs, + )); + + iterations_without_discriminant += 1; + } if cattrs.remote().is_some() && cattrs.non_exhaustive() { arms.push(quote! { @@ -426,6 +464,7 @@ fn serialize_variant( params: &Parameters, variant: &Variant, variant_index: u32, + variant_discriminant: &syn::Expr, cattrs: &attr::Container, ) -> TokenStream { let this_value = ¶ms.this_value; @@ -477,18 +516,27 @@ fn serialize_variant( }; let body = Match(match (cattrs.tag(), variant.attrs.untagged()) { - (attr::TagType::External, false) => { - serialize_externally_tagged_variant(params, variant, variant_index, cattrs) - } - (attr::TagType::Internal { tag }, false) => { - serialize_internally_tagged_variant(params, variant, cattrs, tag) - } + (attr::TagType::External, false) => serialize_externally_tagged_variant( + params, + variant, + variant_index, + variant_discriminant, + cattrs, + ), + (attr::TagType::Internal { tag }, false) => serialize_internally_tagged_variant( + params, + variant, + variant_discriminant, + cattrs, + tag, + ), (attr::TagType::Adjacent { tag, content }, false) => { serialize_adjacently_tagged_variant( params, variant, cattrs, variant_index, + variant_discriminant, tag, content, ) @@ -508,10 +556,31 @@ fn serialize_externally_tagged_variant( params: &Parameters, variant: &Variant, variant_index: u32, + variant_discriminant: &syn::Expr, cattrs: &attr::Container, ) -> Fragment { let type_name = cattrs.name().serialize_name(); - let variant_name = variant.attrs.name().serialize_name(); + let variant_name = variant.attrs.name().serialize_name().to_string(); + let variant_name = match (cattrs.use_repr(), variant_discriminant) { + (true, syn::Expr::Lit(lit)) => { + match lit.lit { + syn::Lit::Int(ref i) => { + if let Ok(i) = i.base10_parse::() { + i.to_string() + } else { + variant_name + } + }, + _ => panic!("Can only use external tagging with integer discriminants but found non integer type"), + } + } + (false, _) => variant_name, + _ => panic!( + "Can only use external tagging with integer discriminants but found expression, {:?}", + variant_discriminant.to_token_stream() + ), // TODO + }; + let variant_name = variant_name.as_str(); if let Some(path) = variant.attrs.serialize_with() { let ser = wrap_serialize_variant_with(params, path, variant); @@ -580,12 +649,38 @@ fn serialize_externally_tagged_variant( fn serialize_internally_tagged_variant( params: &Parameters, variant: &Variant, + variant_discriminant: &syn::Expr, cattrs: &attr::Container, tag: &str, ) -> Fragment { let type_name = cattrs.name().serialize_name(); let variant_name = variant.attrs.name().serialize_name(); + let variant_name = &match (cattrs.use_repr(), variant_discriminant) { + (true, syn::Expr::Lit(lit)) => { + match lit.lit { + syn::Lit::Int(ref i) if i.base10_parse::().is_ok() => { + if let Some(repr_type) = cattrs.repr_type() { + parse_quote! { ((#variant_discriminant) as #repr_type) } + } else { + panic!("No #[repr(...)] attribute found"); + } + + }, + _ => panic!("Can only use external tagging with integer discriminants but found non integer type"), + } + } + (false, _) => { + syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(LitStr::new(&variant_name, Span::call_site())), + }) + }, + _ => panic!( + "Can only use external tagging with integer discriminants but found expression, {:?}", + variant_discriminant.to_token_stream() + ), // TODO + }; let enum_ident_str = params.type_name(); let variant_ident_str = variant.ident.to_string(); @@ -609,7 +704,7 @@ fn serialize_internally_tagged_variant( let mut __struct = _serde::Serializer::serialize_struct( __serializer, #type_name, 1)?; _serde::ser::SerializeStruct::serialize_field( - &mut __struct, #tag, #variant_name)?; + &mut __struct, #tag, &#variant_name)?; _serde::ser::SerializeStruct::end(__struct) } } @@ -648,17 +743,39 @@ fn serialize_adjacently_tagged_variant( variant: &Variant, cattrs: &attr::Container, variant_index: u32, + variant_discriminant: &syn::Expr, tag: &str, content: &str, ) -> Fragment { let this_type = ¶ms.this_type; let type_name = cattrs.name().serialize_name(); - let variant_name = variant.attrs.name().serialize_name(); + let variant_name = variant.attrs.name().serialize_name().to_string(); + let variant_identifier = match (cattrs.use_repr(), variant_discriminant) { + (true, syn::Expr::Lit(lit)) => { + match lit.lit { + syn::Lit::Int(ref i) => { + if let Ok(i) = i.base10_parse::() { + i.to_string() + } else { + variant_name.clone() + } + }, + _ => panic!("Can only use external tagging with integer discriminants but found non integer type"), + } + } + (false, _) => variant_name.clone(), + _ => panic!( + "Can only use external tagging with integer discriminants but found expression, {:?}", + variant_discriminant.to_token_stream() + ), // TODO + }; + let variant_name = variant_name.as_str(); + let serialize_variant = quote! { &_serde::__private::ser::AdjacentlyTaggedEnumVariant { enum_name: #type_name, variant_index: #variant_index, - variant_name: #variant_name, + variant_name: &#variant_identifier, } }; @@ -880,7 +997,7 @@ enum StructVariant<'a> { }, InternallyTagged { tag: &'a str, - variant_name: &'a str, + variant_name: &'a syn::Expr, }, Untagged, } @@ -939,7 +1056,9 @@ fn serialize_struct_variant( _serde::ser::SerializeStructVariant::end(__serde_state) } } - StructVariant::InternallyTagged { tag, variant_name } => { + StructVariant::InternallyTagged { + tag, variant_name, .. + } => { quote_block! { let mut __serde_state = _serde::Serializer::serialize_struct( __serializer, @@ -949,7 +1068,7 @@ fn serialize_struct_variant( _serde::ser::SerializeStruct::serialize_field( &mut __serde_state, #tag, - #variant_name, + &#variant_name, )?; #(#serialize_fields)* _serde::ser::SerializeStruct::end(__serde_state) @@ -1030,7 +1149,9 @@ fn serialize_struct_variant_with_flatten( }) } } - StructVariant::InternallyTagged { tag, variant_name } => { + StructVariant::InternallyTagged { + tag, variant_name, .. + } => { quote_block! { let #let_mut __serde_state = _serde::Serializer::serialize_map( __serializer, diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 566f7d43f..8de60e87d 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -3232,3 +3232,179 @@ mod flatten { } } } + +#[test] +fn test_externally_tagged_enum_with_repr() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(use_repr)] + #[repr(u8)] + enum ReprExternalEnum { + A { a: u32 } = 0x42, + B(u32), + C = 0x52, + D(u32, u32) + } + + assert_tokens( + &ReprExternalEnum::A { a: 42 }, + &[ + Token::StructVariant { + name: "ReprExternalEnum", + variant: "66", + len: 1, + }, + Token::Str("a"), + Token::U32(42), + Token::StructVariantEnd, + ], + ); + assert_tokens( + &ReprExternalEnum::B(0), + &[ + Token::NewtypeVariant { + name: "ReprExternalEnum", + variant: "67", + }, + Token::U32(0), + ], + ); + assert_tokens( + &ReprExternalEnum::C, + &[Token::UnitVariant { + name: "ReprExternalEnum", + variant: "82", + }], + ); +} + +#[test] +fn test_internally_tagged_enum_with_repr() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(use_repr, tag = "tag")] + #[repr(u8)] + enum ReprExternalEnum { + A { a: u32 } = 0x42, + B(u32) = 0x43, + C = 0x52, + } + + assert_tokens( + &ReprExternalEnum::A { a: 0 }, + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 2, + }, + Token::Str("tag"), + Token::U8(0x42), + Token::Str("a"), + Token::U32(0), + Token::StructEnd, + ], + ); + assert_tokens( + &ReprExternalEnum::B(1337), + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 2, + }, + Token::Str("tag"), + Token::U8(0x43), + Token::Seq { len: Some(1) }, + Token::U32(1337), + Token::SeqEnd, + Token::StructEnd, + ], + ); + assert_tokens( + &ReprExternalEnum::C, + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 1, + }, + Token::Str("tag"), + Token::U8(0x52), + Token::StructEnd, + ], + ); +} + +#[test] +fn test_adjacently_tagged_enum_with_repr() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(use_repr, tag = "tag", content = "content")] + #[repr(u8)] + enum ReprExternalEnum { + A { a: u32 } = 0x42, + B(u32), + C = 0x52, + D(u32, u32), + } + + assert_tokens( + &ReprExternalEnum::A { a: 0 }, + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 2, + }, + Token::Str("tag"), + Token::Str("66"), + + Token::Str("content"), + Token::Struct { name: "A", len: 1 }, + Token::Str("a"), + Token::U32(0), + Token::StructEnd, + + Token::StructEnd, + ], + ); + assert_tokens( + &ReprExternalEnum::B(0), + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 2, + }, + Token::Str("tag"), + Token::Str("67"), + Token::Str("content"), + Token::U32(0), + Token::StructEnd, + ], + ); + assert_tokens( + &ReprExternalEnum::C, + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 1, + }, + Token::Str("tag"), + Token::Str("82"), + Token::StructEnd, + ], + ); + assert_tokens( + &ReprExternalEnum::D(1, 2), + &[ + Token::Struct { + name: "ReprExternalEnum", + len: 2, + }, + Token::Str("tag"), + Token::Str("83"), + + Token::Str("content"), + Token::Seq { len: Some(2) }, + Token::U32(1), + Token::U32(2), + Token::SeqEnd, + + Token::StructEnd, + ], + ); +}