Skip to content

Commit

Permalink
[review] kngwyu
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed May 7, 2021
1 parent 4d46abd commit 48e9881
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 27 deletions.
8 changes: 5 additions & 3 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ impl Parse for NameAttribute {
}
}

pub fn get_pyo3_attribute<T: Parse>(attr: &syn::Attribute) -> Result<Option<Punctuated<T, Comma>>> {
if attribute_ident_is(attr, "pyo3") {
pub fn get_pyo3_attributes<T: Parse>(
attr: &syn::Attribute,
) -> Result<Option<Punctuated<T, Comma>>> {
if is_attribute_ident(attr, "pyo3") {
attr.parse_args_with(Punctuated::parse_terminated).map(Some)
} else {
Ok(None)
}
}

pub fn attribute_ident_is(attr: &syn::Attribute, name: &str) -> bool {
pub fn is_attribute_ident(attr: &syn::Attribute, name: &str) -> bool {
if let Some(path_segment) = attr.path.segments.last() {
attr.path.segments.len() == 1 && path_segment.ident == name
} else {
Expand Down
8 changes: 5 additions & 3 deletions pyo3-macros-backend/src/from_pyobject.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::attributes::{self, get_pyo3_attribute, FromPyWithAttribute};
use crate::attributes::{self, get_pyo3_attributes, FromPyWithAttribute};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
Expand Down Expand Up @@ -252,7 +252,9 @@ impl<'a> Container<'a> {
}

struct ContainerOptions {
/// Treat the Container as a Wrapper, directly extract its fields from the input object.
transparent: bool,
/// Change the name of an enum variant in the generated error message.
annotation: Option<syn::LitStr>,
}

Expand Down Expand Up @@ -288,7 +290,7 @@ impl ContainerOptions {
annotation: None,
};
for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_attribute(attr)? {
if let Some(pyo3_attrs) = get_pyo3_attributes(attr)? {
for pyo3_attr in pyo3_attrs {
match pyo3_attr {
ContainerPyO3Attribute::Transparent(kw) => {
Expand Down Expand Up @@ -386,7 +388,7 @@ impl FieldPyO3Attributes {
let mut from_py_with = None;

for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_attribute(attr)? {
if let Some(pyo3_attrs) = get_pyo3_attributes(attr)? {
for pyo3_attr in pyo3_attrs {
match pyo3_attr {
FieldPyO3Attribute::Getter(field_getter) => {
Expand Down
6 changes: 3 additions & 3 deletions pyo3-macros-backend/src/konst.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::attributes::{
self, attribute_ident_is, get_deprecated_name_attribute, get_pyo3_attribute, take_attributes,
self, get_deprecated_name_attribute, get_pyo3_attributes, is_attribute_ident, take_attributes,
NameAttribute,
};
use crate::utils;
Expand Down Expand Up @@ -62,14 +62,14 @@ impl ConstAttributes {
};

take_attributes(attrs, |attr| {
if attribute_ident_is(attr, "classattr") {
if is_attribute_ident(attr, "classattr") {
ensure_spanned!(
attr.tokens.is_empty(),
attr.span() => "`#[classattr]` does not take any arguments"
);
attributes.is_class_attr = true;
Ok(true)
} else if let Some(pyo3_attributes) = get_pyo3_attribute(attr)? {
} else if let Some(pyo3_attributes) = get_pyo3_attributes(attr)? {
for pyo3_attr in pyo3_attributes {
match pyo3_attr {
PyO3ConstAttribute::Name(name) => attributes.set_name(name)?,
Expand Down
4 changes: 2 additions & 2 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
//! Code generation for the function that initializes a python module and adds classes and function.

use crate::attributes::{attribute_ident_is, take_attributes, NameAttribute};
use crate::attributes::{is_attribute_ident, take_attributes, NameAttribute};
use crate::pyfunction::{impl_wrap_pyfunction, PyFunctionOptions};
use proc_macro2::{Span, TokenStream};
use quote::quote;
Expand Down Expand Up @@ -80,7 +80,7 @@ fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs
let mut pyfn_args: Option<PyFnArgs> = None;

take_attributes(attrs, |attr| {
if attribute_ident_is(attr, "pyfn") {
if is_attribute_ident(attr, "pyfn") {
ensure_spanned!(
pyfn_args.is_none(),
attr.span() => "`#[pyfn] may only be specified once"
Expand Down
34 changes: 18 additions & 16 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use crate::{
attributes::{
self, get_deprecated_name_attribute, get_pyo3_attribute, take_attributes,
self, get_deprecated_name_attribute, get_pyo3_attributes, take_attributes,
FromPyWithAttribute, NameAttribute,
},
method::{self, FnArg, FnSpec},
Expand Down Expand Up @@ -61,7 +61,7 @@ impl PyFunctionArgPyO3Attributes {
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
take_attributes(attrs, |attr| {
if let Some(pyo3_attrs) = get_pyo3_attribute(attr)? {
if let Some(pyo3_attrs) = get_pyo3_attributes(attr)? {
for attr in pyo3_attrs {
match attr {
PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
Expand Down Expand Up @@ -275,7 +275,7 @@ impl PyFunctionOptions {

pub fn take_pyo3_attributes(&mut self, attrs: &mut Vec<syn::Attribute>) -> syn::Result<()> {
take_attributes(attrs, |attr| {
if let Some(pyo3_attributes) = get_pyo3_attribute(attr)? {
if let Some(pyo3_attributes) = get_pyo3_attributes(attr)? {
self.add_attributes(pyo3_attributes)?;
Ok(true)
} else if let Some(name) = get_deprecated_name_attribute(attr)? {
Expand Down Expand Up @@ -425,20 +425,22 @@ fn function_c_wrapper(
pass_module: bool,
) -> Result<TokenStream> {
let names: Vec<Ident> = get_arg_names(&spec);
let cb;
let slf_module;
if pass_module {
cb = quote! {
pyo3::callback::convert(_py, #name(_slf, #(#names),*))
};
slf_module = Some(quote! {
let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
});
let (cb, slf_module) = if pass_module {
(
quote! {
pyo3::callback::convert(_py, #name(_slf, #(#names),*))
},
Some(quote! {
let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
}),
)
} else {
cb = quote! {
pyo3::callback::convert(_py, #name(#(#names),*))
};
slf_module = None;
(
quote! {
pyo3::callback::convert(_py, #name(#(#names),*))
},
None,
)
};
let py = syn::Ident::new("_py", Span::call_site());
let body = impl_arg_params(spec, None, cb, &py)?;
Expand Down

0 comments on commit 48e9881

Please sign in to comment.