Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow #[classattr] methods to be fallible #2385

Merged
merged 1 commit into from May 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Deprecate `ToBorrowedObject` trait (it is only used as a wrapper for `ToPyObject`). [#2333](https://github.com/PyO3/pyo3/pull/2333)
- `impl<T, const N: usize> IntoPy<PyObject> for [T; N]` now requires `T: IntoPy` rather than `T: ToPyObject`. [#2326](https://github.com/PyO3/pyo3/pull/2326)
- Correct `wrap_pymodule` to match normal namespacing rules: it no longer "sees through" glob imports of `use submodule::*` when `submodule::submodule` is a `#[pymodule]`. [#2363](https://github.com/PyO3/pyo3/pull/2363)
- Allow `#[classattr]` methods to be fallible. [#2385](https://github.com/PyO3/pyo3/pull/2385)

### Fixed

Expand Down
6 changes: 4 additions & 2 deletions guide/src/class.md
Expand Up @@ -583,8 +583,7 @@ impl MyClass {
## Class attributes

To create a class attribute (also called [class variable][classattr]), a method without
any arguments can be annotated with the `#[classattr]` attribute. The return type must be `T` for
some `T` that implements `IntoPy<PyObject>`.
any arguments can be annotated with the `#[classattr]` attribute.

```rust
# use pyo3::prelude::*;
Expand All @@ -604,6 +603,9 @@ Python::with_gil(|py| {
});
```

> Note: if the method has a `Result` return type and returns an `Err`, PyO3 will panic during
class creation.

If the class attribute is defined with `const` code only, one can also annotate associated
constants:

Expand Down
4 changes: 2 additions & 2 deletions pyo3-macros-backend/src/pyimpl.rs
Expand Up @@ -171,9 +171,9 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
_pyo3::class::PyClassAttributeDef::new(
#python_name,
_pyo3::impl_::pymethods::PyClassAttributeFactory({
fn __wrap(py: _pyo3::Python<'_>) -> _pyo3::PyObject {
fn __wrap(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
#deprecations
_pyo3::IntoPy::into_py(#cls::#member, py)
::std::result::Result::Ok(_pyo3::IntoPy::into_py(#cls::#member, py))
}
__wrap
})
Expand Down
9 changes: 7 additions & 2 deletions pyo3-macros-backend/src/pymethod.rs
Expand Up @@ -349,9 +349,14 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
_pyo3::class::PyClassAttributeDef::new(
#python_name,
_pyo3::impl_::pymethods::PyClassAttributeFactory({
fn __wrap(py: _pyo3::Python<'_>) -> _pyo3::PyObject {
fn __wrap(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
#deprecations
_pyo3::IntoPy::into_py(#cls::#name(), py)
let mut ret = #cls::#name();
if false {
use _pyo3::impl_::ghost::IntoPyResult;
ret.assert_into_py_result();
}
_pyo3::callback::convert(py, ret)
}
__wrap
})
Expand Down
2 changes: 1 addition & 1 deletion src/impl_/pymethods.rs
Expand Up @@ -76,7 +76,7 @@ pub struct PyGetter(pub ffi::getter);
#[derive(Clone, Copy, Debug)]
pub struct PySetter(pub ffi::setter);
#[derive(Clone, Copy)]
pub struct PyClassAttributeFactory(pub for<'p> fn(Python<'p>) -> PyObject);
pub struct PyClassAttributeFactory(pub for<'p> fn(Python<'p>) -> PyResult<PyObject>);

// TODO: it would be nice to use CStr in these types, but then the constructors can't be const fn
// until `CStr::from_bytes_with_nul_unchecked` is const fn.
Expand Down
36 changes: 29 additions & 7 deletions src/type_object.rs
Expand Up @@ -133,8 +133,8 @@ impl LazyStaticType {
return;
}

let thread_id = thread::current().id();
{
let thread_id = thread::current().id();
let mut threads = self.initializing_threads.lock();
if threads.contains(&thread_id) {
// Reentrant call: just return the type object, even if the
Expand All @@ -144,26 +144,47 @@ impl LazyStaticType {
threads.push(thread_id);
}

struct InitializationGuard<'a> {
initializing_threads: &'a Mutex<Vec<ThreadId>>,
thread_id: ThreadId,
}
impl Drop for InitializationGuard<'_> {
fn drop(&mut self) {
let mut threads = self.initializing_threads.lock();
threads.retain(|id| *id != self.thread_id);
}
}

let guard = InitializationGuard {
initializing_threads: &self.initializing_threads,
thread_id,
};

// Pre-compute the class attribute objects: this can temporarily
// release the GIL since we're calling into arbitrary user code. It
// means that another thread can continue the initialization in the
// meantime: at worst, we'll just make a useless computation.
let mut items = vec![];
for_all_items(&mut |class_items| {
items.extend(class_items.methods.iter().filter_map(|def| {
for def in class_items.methods {
if let PyMethodDefType::ClassAttribute(attr) = def {
let key = extract_cstr_or_leak_cstring(
attr.name,
"class attribute name cannot contain nul bytes",
)
.unwrap();

let val = (attr.meth.0)(py);
Some((key, val))
} else {
None
match (attr.meth.0)(py) {
Ok(val) => items.push((key, val)),
Err(e) => panic!(
"An error occurred while initializing `{}.{}`: {}",
name,
attr.name.trim_end_matches('\0'),
e
),
}
}
}));
}
});

// Now we hold the GIL and we can assume it won't be released until we
Expand All @@ -173,6 +194,7 @@ impl LazyStaticType {

// Initialization successfully complete, can clear the thread list.
// (No further calls to get_or_init() will try to init, on any thread.)
std::mem::forget(guard);
*self.initializing_threads.lock() = Vec::new();
result
});
Expand Down
24 changes: 23 additions & 1 deletion tests/test_class_attributes.rs
@@ -1,6 +1,6 @@
#![cfg(feature = "macros")]

use pyo3::prelude::*;
use pyo3::{exceptions::PyValueError, prelude::*};

mod common;

Expand Down Expand Up @@ -89,3 +89,25 @@ fn recursive_class_attributes() {
py_assert!(py, foo_obj, "foo_obj.bar.x == 2");
py_assert!(py, bar_obj, "bar_obj.a_foo.x == 3");
}

#[test]
#[should_panic(
expected = "An error occurred while initializing `BrokenClass.fails_to_init`: \
ValueError: failed to create class attribute"
)]
fn test_fallible_class_attribute() {
#[pyclass]
struct BrokenClass;

#[pymethods]
impl BrokenClass {
#[classattr]
fn fails_to_init() -> PyResult<i32> {
Err(PyValueError::new_err("failed to create class attribute"))
}
}

Python::with_gil(|py| {
py.get_type::<BrokenClass>();
})
}