Skip to content

Commit

Permalink
Handle bitflags bits method calls
Browse files Browse the repository at this point in the history
  • Loading branch information
glandium authored and emilio committed Sep 8, 2023
1 parent f72e447 commit 43af1eb
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -33,7 +33,7 @@ heck = "0.4"
[dependencies.syn]
version = "1.0.88"
default-features = false
features = ["clone-impls", "extra-traits", "full", "parsing", "printing"]
features = ["clone-impls", "extra-traits", "fold", "full", "parsing", "printing"]

[dev-dependencies]
serial_test = "0.5.0"
Expand Down
82 changes: 79 additions & 3 deletions src/bindgen/bitflags.rs
Expand Up @@ -3,6 +3,8 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */

use proc_macro2::TokenStream;
use std::collections::HashSet;
use syn::fold::Fold;
use syn::parse::{Parse, ParseStream, Parser, Result as ParseResult};

// $(#[$outer:meta])*
Expand Down Expand Up @@ -84,17 +86,86 @@ struct Flag {
semicolon_token: Token![;],
}

struct FlagValueFold<'a> {
struct_name: &'a syn::Ident,
flag_names: &'a HashSet<String>,
}

impl<'a> FlagValueFold<'a> {
fn is_self(&self, ident: &syn::Ident) -> bool {
ident == self.struct_name || ident == "Self"
}
}

impl<'a> Fold for FlagValueFold<'a> {
fn fold_expr(&mut self, node: syn::Expr) -> syn::Expr {
// bitflags 2 doesn't expose `bits` publically anymore, and the documented way to
// combine flags is using the `bits` method, e.g.
// ```
// bitflags! {
// struct Flags: u8 {
// const A = 1;
// const B = 1 << 1;
// const AB = Flags::A.bits() | Flags::B.bits();
// }
// }
// ```
// As we're transforming the struct definition into `struct StructName { bits: T }`
// as far as our bindings generation is concerned, `bits` is available as a field,
// so by replacing `StructName::FLAG.bits()` with `StructName::FLAG.bits`, we make
// e.g. `Flags::AB` available in the generated bindings.
match node {
syn::Expr::MethodCall(syn::ExprMethodCall {
attrs,
receiver,
dot_token,
method,
args,
..
}) if method == "bits"
&& args.is_empty()
&& matches!(&*receiver,
syn::Expr::Path(syn::ExprPath { path, .. })
if path.segments.len() == 2
&& self.is_self(&path.segments.first().unwrap().ident)
&& self
.flag_names
.contains(&path.segments.last().unwrap().ident.to_string())) =>
{
return syn::Expr::Field(syn::ExprField {
attrs,
base: receiver,
dot_token,
member: syn::Member::Named(method),
});
}
_ => {}
}
syn::fold::fold_expr(self, node)
}
}

impl Flag {
fn expand(&self, struct_name: &syn::Ident, repr: &syn::Type) -> TokenStream {
fn expand(
&self,
struct_name: &syn::Ident,
repr: &syn::Type,
flag_names: &HashSet<String>,
) -> TokenStream {
let Flag {
ref attrs,
ref name,
ref value,
..
} = *self;
let folded_value = FlagValueFold {
struct_name,
flag_names,
}
.fold_expr(value.clone());
quote! {
#(#attrs)*
pub const #name : #struct_name = #struct_name { bits: (#value) as #repr };
pub const #name : #struct_name = #struct_name { bits: (#folded_value) as #repr };
}
}
}
Expand Down Expand Up @@ -130,8 +201,13 @@ impl Parse for Flags {
impl Flags {
fn expand(&self, struct_name: &syn::Ident, repr: &syn::Type) -> TokenStream {
let mut ts = quote! {};
let flag_names = self
.0
.iter()
.map(|flag| flag.name.to_string())
.collect::<HashSet<_>>();
for flag in &self.0 {
ts.extend(flag.expand(struct_name, repr));
ts.extend(flag.expand(struct_name, repr, &flag_names));
}
ts
}
Expand Down
8 changes: 4 additions & 4 deletions tests/rust/bitflags.rs
Expand Up @@ -13,11 +13,11 @@ bitflags! {
const START = 1 << 1;
/// 'end'
const END = 1 << 2;
const ALIAS = Self::END.bits;
const ALIAS = Self::END.bits();
/// 'flex-start'
const FLEX_START = 1 << 3;
const MIXED = 1 << 4 | AlignFlags::FLEX_START.bits | AlignFlags::END.bits;
const MIXED_SELF = 1 << 5 | AlignFlags::FLEX_START.bits | AlignFlags::END.bits;
const MIXED = 1 << 4 | AlignFlags::FLEX_START.bits() | AlignFlags::END.bits();
const MIXED_SELF = 1 << 5 | AlignFlags::FLEX_START.bits() | AlignFlags::END.bits();
}
}

Expand All @@ -34,7 +34,7 @@ bitflags! {
pub struct LargeFlags: u64 {
/// Flag with a very large shift that usually would be narrowed.
const LARGE_SHIFT = 1u64 << 44;
const INVERTED = !Self::LARGE_SHIFT.bits;
const INVERTED = !Self::LARGE_SHIFT.bits();
}
}

Expand Down

0 comments on commit 43af1eb

Please sign in to comment.