Skip to content

Commit

Permalink
Merge pull request #2456 from ikrivosheev/feature/issues-2383_classat…
Browse files Browse the repository at this point in the history
…tribute

Allow #[classattr] take Python argument
  • Loading branch information
davidhewitt committed Jun 16, 2022
2 parents cbdd2e3 + f19561c commit 1978712
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `PyCode` and `PyFrame` high level objects. [#2408](https://github.com/PyO3/pyo3/pull/2408)
- Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423)
- Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428)
- Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383)

### Changed

Expand Down
8 changes: 1 addition & 7 deletions pyo3-macros-backend/src/method.rs
Expand Up @@ -374,13 +374,7 @@ impl<'a> FnSpec<'a> {

let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None),
Some(MethodTypeAttribute::ClassAttribute) => {
ensure_spanned!(
sig.inputs.is_empty(),
sig.inputs.span() => "class attribute methods cannot take arguments"
);
(FnType::ClassAttribute, false, None)
}
Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None),
Some(MethodTypeAttribute::New) => {
if let Some(name) = &python_name {
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
Expand Down
24 changes: 19 additions & 5 deletions pyo3-macros-backend/src/pymethod.rs
Expand Up @@ -178,7 +178,7 @@ pub fn gen_py_method(
// Class attributes go before protos so that class attributes can be used to set proto
// method to None.
(_, FnType::ClassAttribute) => {
GeneratedPyMethod::Method(impl_py_class_attribute(cls, spec))
GeneratedPyMethod::Method(impl_py_class_attribute(cls, spec)?)
}
(PyMethodKind::Proto(proto_kind), _) => {
ensure_no_forbidden_protocol_attributes(spec, &method.method_name)?;
Expand Down Expand Up @@ -348,12 +348,25 @@ fn impl_traverse_slot(cls: &syn::Type, spec: FnSpec<'_>) -> TokenStream {
}}
}

fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<TokenStream> {
let (py_arg, args) = split_off_python_arg(&spec.args);
ensure_spanned!(
args.is_empty(),
args[0].ty.span() => "#[classattr] can only have one argument (of type pyo3::Python)"
);

let name = &spec.name;
let fncall = if py_arg.is_some() {
quote!(#cls::#name(py))
} else {
quote!(#cls::#name())
};

let wrapper_ident = format_ident!("__pymethod_{}__", name);
let deprecations = &spec.deprecations;
let python_name = spec.null_terminated_python_name();
quote! {

let classattr = quote! {
_pyo3::class::PyMethodDefType::ClassAttribute({
_pyo3::class::PyClassAttributeDef::new(
#python_name,
Expand All @@ -363,7 +376,7 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
#[allow(non_snake_case)]
fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
#deprecations
let mut ret = #cls::#name();
let mut ret = #fncall;
if false {
use _pyo3::impl_::ghost::IntoPyResult;
ret.assert_into_py_result();
Expand All @@ -375,7 +388,8 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
})
)
})
}
};
Ok(classattr)
}

fn impl_call_setter(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<TokenStream> {
Expand Down
6 changes: 6 additions & 0 deletions tests/test_class_attributes.rs
Expand Up @@ -45,6 +45,11 @@ impl Foo {
fn a_foo() -> Foo {
Foo { x: 1 }
}

#[classattr]
fn a_foo_with_py(py: Python<'_>) -> Py<Foo> {
Py::new(py, Foo { x: 1 }).unwrap()
}
}

#[test]
Expand All @@ -57,6 +62,7 @@ fn class_attributes() {
py_assert!(py, foo_obj, "foo_obj.a == 5");
py_assert!(py, foo_obj, "foo_obj.B == 'bar'");
py_assert!(py, foo_obj, "foo_obj.a_foo.x == 1");
py_assert!(py, foo_obj, "foo_obj.a_foo_with_py.x == 1");
}

// Ignored because heap types are not immutable:
Expand Down
6 changes: 3 additions & 3 deletions tests/ui/invalid_pymethods.stderr
@@ -1,8 +1,8 @@
error: class attribute methods cannot take arguments
--> tests/ui/invalid_pymethods.rs:9:29
error: #[classattr] can only have one argument (of type pyo3::Python)
--> tests/ui/invalid_pymethods.rs:9:34
|
9 | fn class_attr_with_args(foo: i32) {}
| ^^^
| ^^^

error: `#[classattr]` does not take any arguments
--> tests/ui/invalid_pymethods.rs:14:5
Expand Down

0 comments on commit 1978712

Please sign in to comment.