Skip to content

Commit

Permalink
Refactor derive
Browse files Browse the repository at this point in the history
Summary:
Move some common code to "util".

Need to do more derives, back up these changes.

Reviewed By: JakobDegen

Differential Revision: D56123520

fbshipit-source-id: 37d5ea992003d1f1384c331ef98801bde5a1ebba
  • Loading branch information
stepancheg authored and facebook-github-bot committed Apr 15, 2024
1 parent e800fcc commit e2dd0d6
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 106 deletions.
26 changes: 10 additions & 16 deletions starlark_derive/src/any_lifetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::DeriveInput;

use crate::util::DeriveInputUtil;
use crate::util::GenericsUtil;

fn punctuated_try_map<A, B, P: Clone>(
punctuated: &Punctuated<A, P>,
f: impl Fn(&A) -> syn::Result<B>,
Expand Down Expand Up @@ -126,21 +129,11 @@ pub(crate) fn derive_provides_static_type(
}

/// Single lifetime parameter for `ProvidesStaticType`
fn pst_lifetime<'a>(
params: impl Iterator<Item = &'a syn::GenericParam>,
) -> syn::Result<syn::Lifetime> {
let mut lifetime = None;
for param in params {
if let syn::GenericParam::Lifetime(param) = param {
if lifetime.is_some() {
return Err(syn::Error::new_spanned(
param,
"only one lifetime parameter is supported",
));
}
lifetime = Some(param.lifetime.clone());
}
}
fn pst_lifetime<'a>(generics: &'a syn::Generics) -> syn::Result<syn::Lifetime> {
let generics = GenericsUtil::new(generics);
let lifetime = generics
.assert_at_most_one_lifetime_param()?
.map(|p| p.lifetime.clone());
Ok(match lifetime {
Some(lifetime) => lifetime,
None => syn::parse_quote_spanned! { Span::call_site() => 'pst },
Expand All @@ -149,13 +142,14 @@ fn pst_lifetime<'a>(

fn derive_provides_static_type_impl(input: proc_macro::TokenStream) -> syn::Result<syn::ItemImpl> {
let input: DeriveInput = syn::parse(input)?;
let input = DeriveInputUtil::new(&input)?;

let span = input.ident.span();

let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let lifetime = pst_lifetime(input.generics.params.iter())?;
let lifetime = pst_lifetime(&input.generics)?;

let mut lifetimes: Vec<syn::Lifetime> = Vec::new();
let mut static_lifetimes: Vec<syn::Lifetime> = Vec::new();
Expand Down
51 changes: 18 additions & 33 deletions starlark_derive/src/module/parse/fun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use syn::Attribute;
use syn::Expr;
use syn::FnArg;
use syn::GenericArgument;
use syn::GenericParam;
use syn::Generics;
use syn::ItemFn;
use syn::Lifetime;
Expand All @@ -48,6 +47,7 @@ use crate::module::typ::StarAttr;
use crate::module::typ::StarFun;
use crate::module::typ::StarFunSource;
use crate::module::typ::StarStmt;
use crate::util::GenericsUtil;

#[derive(Default)]
struct FnAttrs {
Expand Down Expand Up @@ -500,40 +500,25 @@ fn parse_fn_output(return_type: &ReturnType, span: Span, has_v: bool) -> syn::Re
}

fn parse_fn_generics(generics: &Generics) -> syn::Result<bool> {
let generics = GenericsUtil::new(generics);
let mut seen_v = false;
for param in &generics.params {
match param {
GenericParam::Type(..) => {
return Err(syn::Error::new(
param.span(),
"Function cannot have type parameters",
));
}
GenericParam::Const(..) => {
return Err(syn::Error::new(
param.span(),
"Function cannot have const parameters",
));
}
GenericParam::Lifetime(lifetime) => {
if lifetime.lifetime.ident != "v" {
return Err(syn::Error::new(
lifetime.lifetime.span(),
"Function cannot have lifetime parameters other than `v",
));
}
if !lifetime.bounds.is_empty() {
return Err(syn::Error::new(
lifetime.span(),
"Function lifetime params must not have bounds",
));
}
if seen_v {
return Err(syn::Error::new(lifetime.span(), "Duplicate `v parameters"));
}
seen_v = true;
}
for lifetime in generics.assert_only_lifetime_params()? {
if lifetime.lifetime.ident != "v" {
return Err(syn::Error::new(
lifetime.lifetime.span(),
"Function cannot have lifetime parameters other than `v",
));
}
if !lifetime.bounds.is_empty() {
return Err(syn::Error::new(
lifetime.span(),
"Function lifetime params must not have bounds",
));
}
if seen_v {
return Err(syn::Error::new(lifetime.span(), "Duplicate `v parameters"));
}
seen_v = true;
}
Ok(seen_v)
}
Expand Down
4 changes: 2 additions & 2 deletions starlark_derive/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ use syn::Lifetime;
use syn::LifetimeParam;

pub fn derive_no_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut input = parse_macro_input!(input as DeriveInput);
let input = parse_macro_input!(input as DeriveInput);
let tick_v = GenericParam::Lifetime(LifetimeParam::new(Lifetime::new("'v", Span::call_site())));

let mut has_tick_v = false;
for param in &mut input.generics.params {
for param in &input.generics.params {
if let GenericParam::Lifetime(t) = param {
if t.lifetime.ident == "v" {
has_tick_v = true;
Expand Down
13 changes: 2 additions & 11 deletions starlark_derive/src/starlark_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use quote::quote_spanned;
use syn::spanned::Spanned;

use crate::util::GenericsUtil;
use crate::v_lifetime::find_v_lifetime;
use crate::vtable::vtable_has_field_name;

Expand Down Expand Up @@ -134,17 +135,7 @@ impl ImplStarlarkValue {
));
}

for param in &self.input.generics.params {
match param {
syn::GenericParam::Lifetime(_) => {}
_ => {
return Err(syn::Error::new_spanned(
param,
"only lifetime parameters are supported to implement `UnpackValue` or `StarlarkTypeRepr`",
));
}
}
}
GenericsUtil::new(&self.input.generics).assert_only_lifetime_params()?;

let lt = &self.lifetime_param;
let params = &self.input.generics.params;
Expand Down
12 changes: 4 additions & 8 deletions starlark_derive/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,13 @@ fn is_ignore(attrs: &[Attribute]) -> bool {
}

fn trace_impl(derive_input: &DeriveInput, generics: &Generics) -> syn::Result<syn::Expr> {
let derive_input = DeriveInputUtil::new(derive_input)?;

let generic_types = generics
.params
.iter()
.filter_map(|x| match x {
GenericParam::Type(x) => Some(x.ident.to_string()),
_ => None,
})
.type_params()
.map(|x| x.ident.to_string())
.collect();

let derive_input = DeriveInputUtil::new(derive_input)?;

derive_input.for_each_field(|name, field| {
if is_ignore(&field.attrs) {
Ok(quote! {})
Expand Down
74 changes: 74 additions & 0 deletions starlark_derive/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,4 +374,78 @@ impl<'a> DeriveInputUtil<'a> {
})
})
}

pub(crate) fn generics(self) -> GenericsUtil<'a> {
match self {
DeriveInputUtil::Struct(data) => GenericsUtil::new(&data.derive_input.generics),
DeriveInputUtil::Enum(data) => GenericsUtil::new(&data.derive_input.generics),
}
}
}

#[derive(Copy, Clone)]
pub(crate) struct GenericsUtil<'a> {
pub(crate) generics: &'a syn::Generics,
}

impl<'a> GenericsUtil<'a> {
pub(crate) fn new(generics: &'a syn::Generics) -> Self {
GenericsUtil { generics }
}

pub(crate) fn assert_only_lifetime_params(self) -> syn::Result<Vec<&'a syn::LifetimeParam>> {
let mut lifetimes = Vec::new();
for param in &self.generics.params {
match param {
syn::GenericParam::Lifetime(param) => lifetimes.push(param),
_ => {
return Err(syn::Error::new_spanned(
param,
"only lifetime parameters are supported (no type or const parameters)",
));
}
}
}
Ok(lifetimes)
}

pub(crate) fn assert_only_type_params(self) -> syn::Result<Vec<&'a syn::TypeParam>> {
let mut type_params = Vec::new();
for param in &self.generics.params {
match param {
syn::GenericParam::Type(param) => type_params.push(param),
_ => {
return Err(syn::Error::new_spanned(
param,
"only type parameters are supported (no lifetime or const parameters)",
));
}
}
}
Ok(type_params)
}

pub(crate) fn assert_at_most_one_lifetime_param(
self,
) -> syn::Result<Option<&'a syn::LifetimeParam>> {
let mut lifetime_params = self.generics.lifetimes();
let Some(lt) = lifetime_params.next() else {
return Ok(None);
};
if lifetime_params.next().is_some() {
return Err(syn::Error::new_spanned(
lt,
"expecting at most one lifetime parameter",
));
}
Ok(Some(lt))
}
}

impl<'a> Deref for GenericsUtil<'a> {
type Target = syn::Generics;

fn deref(&self) -> &Self::Target {
self.generics
}
}
40 changes: 18 additions & 22 deletions starlark_derive/src/v_lifetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,25 @@
* limitations under the License.
*/

use crate::util::GenericsUtil;

/// Find at most one lifetime parameter, which must be named `'v`.
pub(crate) fn find_v_lifetime(generics: &syn::Generics) -> syn::Result<Option<&syn::Lifetime>> {
let mut found_lifetime = None;
for lifetime in generics.lifetimes() {
if found_lifetime.is_some() {
return Err(syn::Error::new_spanned(
lifetime,
"Only one lifetime parameter is allowed",
));
}
if !lifetime.bounds.is_empty() {
return Err(syn::Error::new_spanned(
lifetime,
"Lifetime parameter cannot have bounds",
));
}
if lifetime.lifetime.ident != "v" {
return Err(syn::Error::new_spanned(
lifetime,
"Lifetime parameter must be named 'v'",
));
}
found_lifetime = Some(&lifetime.lifetime);
let generics = GenericsUtil::new(generics);
let Some(lifetime) = generics.assert_at_most_one_lifetime_param()? else {
return Ok(None);
};
if !lifetime.bounds.is_empty() {
return Err(syn::Error::new_spanned(
lifetime,
"Lifetime parameter cannot have bounds",
));
}
if lifetime.lifetime.ident != "v" {
return Err(syn::Error::new_spanned(
lifetime,
"Lifetime parameter must be named 'v'",
));
}
Ok(found_lifetime)
Ok(Some(&lifetime.lifetime))
}
23 changes: 9 additions & 14 deletions starlark_derive/src/visit_span.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ fn derive_body(input: &DeriveInput) -> syn::Result<syn::Expr> {
}

fn derive(input: DeriveInput) -> syn::Result<TokenStream> {
let input = DeriveInputUtil::new(&input)?;
let (_impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
let name = &input.ident;
let body = derive_body(&input)?;
Expand All @@ -43,22 +44,16 @@ fn derive(input: DeriveInput) -> syn::Result<TokenStream> {
quote! {}
} else {
let params = input
.generics
.params
.iter()
.map(|p| match p {
syn::GenericParam::Type(t) => {
let t = &t.ident;
Ok(quote! {
#t: crate::eval::runtime::visit_span::VisitSpanMut
})
.generics()
.assert_only_type_params()?
.into_iter()
.map(|t| {
let t = &t.ident;
quote! {
#t: crate::eval::runtime::visit_span::VisitSpanMut
}
_ => Err(syn::Error::new_spanned(
p,
"VisitSpanMut cannot be derived for generics with non-type params",
)),
})
.collect::<syn::Result<Vec<_>>>()?;
.collect::<Vec<_>>();
quote! {
< #(#params,)* >
}
Expand Down

0 comments on commit e2dd0d6

Please sign in to comment.