Skip to content

Commit

Permalink
pymethods: support gc protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Feb 15, 2022
1 parent 7851e86 commit 676295b
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 121 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091)
- Add `PyAny::contains` method (`in` operator for `PyAny`). [#2115](https://github.com/PyO3/pyo3/pull/2115)
- Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133)
- Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159)

### Changed

Expand Down
3 changes: 2 additions & 1 deletion guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ given signatures should be interpreted as follows:

#### Garbage Collector Integration

TODO; see [#1884](https://github.com/PyO3/pyo3/issues/1884)
- `__traverse__(<self>, visit: pyo3::class::gc::PyVisit) -> Result<(), pyo3::class::gc::PyTraverseError>`
- `__clear__(<self>) -> ()`

### `#[pyproto]` traits

Expand Down
21 changes: 0 additions & 21 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ impl<'a> PyClassImplsBuilder<'a> {
self.impl_into_py(),
self.impl_pyclassimpl(),
self.impl_freelist(),
self.impl_gc(),
]
.into_iter()
.collect()
Expand Down Expand Up @@ -981,26 +980,6 @@ impl<'a> PyClassImplsBuilder<'a> {
Vec::new()
}
}

/// Enforce at compile time that PyGCProtocol is implemented
fn impl_gc(&self) -> TokenStream {
let cls = self.cls;
let attr = self.attr;
if attr.is_gc {
let closure_name = format!("__assertion_closure_{}", cls);
let closure_token = syn::Ident::new(&closure_name, Span::call_site());
quote! {
fn #closure_token() {
use _pyo3::class;

fn _assert_implements_protocol<'p, T: _pyo3::class::PyGCProtocol<'p>>() {}
_assert_implements_protocol::<#cls>();
}
}
} else {
quote! {}
}
}
}

fn define_inventory_class(inventory_class_name: &syn::Ident) -> TokenStream {
Expand Down
208 changes: 117 additions & 91 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,94 @@ enum PyMethodKind {

impl PyMethodKind {
fn from_name(name: &str) -> Self {
if let Some(slot_def) = pyproto(name) {
PyMethodKind::Proto(PyMethodProtoKind::Slot(slot_def))
} else if name == "__call__" {
PyMethodKind::Proto(PyMethodProtoKind::Call)
} else if let Some(slot_fragment_def) = pyproto_fragment(name) {
PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(slot_fragment_def))
} else {
PyMethodKind::Fn
match name {
// Protocol implemented through slots
"__getattr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETATTR__)),
"__str__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__STR__)),
"__repr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPR__)),
"__hash__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__HASH__)),
"__richcmp__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RICHCMP__)),
"__get__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GET__)),
"__iter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITER__)),
"__next__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEXT__)),
"__await__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AWAIT__)),
"__aiter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AITER__)),
"__anext__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ANEXT__)),
"__len__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__LEN__)),
"__contains__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONTAINS__)),
"__getitem__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETITEM__)),
"__pos__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__POS__)),
"__neg__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEG__)),
"__abs__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ABS__)),
"__invert__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INVERT__)),
"__index__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INDEX__)),
"__int__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INT__)),
"__float__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__FLOAT__)),
"__bool__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__BOOL__)),
"__iadd__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IADD__)),
"__isub__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ISUB__)),
"__imul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMUL__)),
"__imatmul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMATMUL__)),
"__itruediv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITRUEDIV__)),
"__ifloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IFLOORDIV__)),
"__imod__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMOD__)),
"__ipow__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IPOW__)),
"__ilshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ILSHIFT__)),
"__irshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IRSHIFT__)),
"__iand__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IAND__)),
"__ixor__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IXOR__)),
"__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)),
"__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)),
"__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)),
// Protocols implemented through traits
"__setattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETATTR__)),
"__delattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELATTR__)),
"__set__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SET__)),
"__delete__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELETE__)),
"__setitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETITEM__)),
"__delitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELITEM__)),
"__add__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ADD__)),
"__radd__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RADD__)),
"__sub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SUB__)),
"__rsub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSUB__)),
"__mul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MUL__)),
"__rmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMUL__)),
"__matmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MATMUL__)),
"__rmatmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMATMUL__)),
"__floordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__FLOORDIV__)),
"__rfloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RFLOORDIV__)),
"__truediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__TRUEDIV__)),
"__rtruediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RTRUEDIV__)),
"__divmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DIVMOD__)),
"__rdivmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RDIVMOD__)),
"__mod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MOD__)),
"__rmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMOD__)),
"__lshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LSHIFT__)),
"__rlshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RLSHIFT__)),
"__rshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSHIFT__)),
"__rrshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RRSHIFT__)),
"__and__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__AND__)),
"__rand__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RAND__)),
"__xor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__XOR__)),
"__rxor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RXOR__)),
"__or__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__OR__)),
"__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)),
"__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)),
"__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)),
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
// Not a proto
_ => PyMethodKind::Fn,
}
}
}

enum PyMethodProtoKind {
Slot(&'static SlotDef),
Call,
Traverse,
SlotFragment(&'static SlotFragmentDef),
}

Expand Down Expand Up @@ -108,6 +181,9 @@ pub fn gen_py_method(
PyMethodProtoKind::Call => {
GeneratedPyMethod::Proto(impl_call_slot(cls, method.spec)?)
}
PyMethodProtoKind::Traverse => {
GeneratedPyMethod::Proto(impl_traverse_slot(cls, method.spec)?)
}
PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec)?;
GeneratedPyMethod::SlotTraitImpl(method.method_name, proto)
Expand Down Expand Up @@ -220,6 +296,36 @@ fn impl_call_slot(cls: &syn::Type, mut spec: FnSpec) -> Result<TokenStream> {
}})
}

fn impl_traverse_slot(cls: &syn::Type, spec: FnSpec) -> Result<TokenStream> {
let ident = spec.name;
Ok(quote! {{
pub unsafe extern "C" fn __wrap_(
slf: *mut _pyo3::ffi::PyObject,
visit: _pyo3::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int
{
let pool = _pyo3::GILPool::new();
let py = pool.python();
_pyo3::callback::abort_on_traverse_panic(::std::panic::catch_unwind(move || {
let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf);

let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py);
let borrow = slf.try_borrow();
if let ::std::result::Result::Ok(borrow) = borrow {
_pyo3::class::gc::unwrap_traverse_result(borrow.#ident(visit))
} else {
0
}
}))
}
_pyo3::ffi::PyType_Slot {
slot: _pyo3::ffi::Py_tp_traverse,
pfunc: __wrap_ as _pyo3::ffi::traverseproc as _
}
}})
}

fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec) -> TokenStream {
let name = &spec.name;
let deprecations = &spec.deprecations;
Expand Down Expand Up @@ -567,49 +673,9 @@ const __RELEASEBUFFER__: SlotDef = SlotDef::new("Py_bf_releasebuffer", "releaseb
.arguments(&[Ty::PyBuffer])
.ret_ty(Ty::Void)
.require_unsafe();

fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
match method_name {
"__getattr__" => Some(&__GETATTR__),
"__str__" => Some(&__STR__),
"__repr__" => Some(&__REPR__),
"__hash__" => Some(&__HASH__),
"__richcmp__" => Some(&__RICHCMP__),
"__get__" => Some(&__GET__),
"__iter__" => Some(&__ITER__),
"__next__" => Some(&__NEXT__),
"__await__" => Some(&__AWAIT__),
"__aiter__" => Some(&__AITER__),
"__anext__" => Some(&__ANEXT__),
"__len__" => Some(&__LEN__),
"__contains__" => Some(&__CONTAINS__),
"__getitem__" => Some(&__GETITEM__),
"__pos__" => Some(&__POS__),
"__neg__" => Some(&__NEG__),
"__abs__" => Some(&__ABS__),
"__invert__" => Some(&__INVERT__),
"__index__" => Some(&__INDEX__),
"__int__" => Some(&__INT__),
"__float__" => Some(&__FLOAT__),
"__bool__" => Some(&__BOOL__),
"__iadd__" => Some(&__IADD__),
"__isub__" => Some(&__ISUB__),
"__imul__" => Some(&__IMUL__),
"__imatmul__" => Some(&__IMATMUL__),
"__itruediv__" => Some(&__ITRUEDIV__),
"__ifloordiv__" => Some(&__IFLOORDIV__),
"__imod__" => Some(&__IMOD__),
"__ipow__" => Some(&__IPOW__),
"__ilshift__" => Some(&__ILSHIFT__),
"__irshift__" => Some(&__IRSHIFT__),
"__iand__" => Some(&__IAND__),
"__ixor__" => Some(&__IXOR__),
"__ior__" => Some(&__IOR__),
"__getbuffer__" => Some(&__GETBUFFER__),
"__releasebuffer__" => Some(&__RELEASEBUFFER__),
_ => None,
}
}
const __CLEAR__: SlotDef = SlotDef::new("Py_tp_clear", "inquiry")
.arguments(&[])
.ret_ty(Ty::Int);

#[derive(Clone, Copy)]
enum Ty {
Expand Down Expand Up @@ -1045,46 +1111,6 @@ const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object,
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);

fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> {
match method_name {
"__setattr__" => Some(&__SETATTR__),
"__delattr__" => Some(&__DELATTR__),
"__set__" => Some(&__SET__),
"__delete__" => Some(&__DELETE__),
"__setitem__" => Some(&__SETITEM__),
"__delitem__" => Some(&__DELITEM__),
"__add__" => Some(&__ADD__),
"__radd__" => Some(&__RADD__),
"__sub__" => Some(&__SUB__),
"__rsub__" => Some(&__RSUB__),
"__mul__" => Some(&__MUL__),
"__rmul__" => Some(&__RMUL__),
"__matmul__" => Some(&__MATMUL__),
"__rmatmul__" => Some(&__RMATMUL__),
"__floordiv__" => Some(&__FLOORDIV__),
"__rfloordiv__" => Some(&__RFLOORDIV__),
"__truediv__" => Some(&__TRUEDIV__),
"__rtruediv__" => Some(&__RTRUEDIV__),
"__divmod__" => Some(&__DIVMOD__),
"__rdivmod__" => Some(&__RDIVMOD__),
"__mod__" => Some(&__MOD__),
"__rmod__" => Some(&__RMOD__),
"__lshift__" => Some(&__LSHIFT__),
"__rlshift__" => Some(&__RLSHIFT__),
"__rshift__" => Some(&__RSHIFT__),
"__rrshift__" => Some(&__RRSHIFT__),
"__and__" => Some(&__AND__),
"__rand__" => Some(&__RAND__),
"__xor__" => Some(&__XOR__),
"__rxor__" => Some(&__RXOR__),
"__or__" => Some(&__OR__),
"__ror__" => Some(&__ROR__),
"__pow__" => Some(&__POW__),
"__rpow__" => Some(&__RPOW__),
_ => None,
}
}

fn extract_proto_arguments(
cls: &syn::Type,
py: &syn::Ident,
Expand Down
15 changes: 15 additions & 0 deletions src/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,18 @@ where
R::ERR_VALUE
})
}

/// Aborts if panic has occurred. Used inside `__traverse__` implementations, where panicking is not possible.
#[doc(hidden)]
#[inline]
pub fn abort_on_traverse_panic(
panic_result: Result<c_int, Box<dyn Any + Send + 'static>>,
) -> c_int {
match panic_result {
Ok(traverse_result) => traverse_result,
Err(_payload) => {
eprintln!("FATAL: panic inside __traverse__ handler; aborting.");
::std::process::abort()
}
}
}
16 changes: 16 additions & 0 deletions src/class/gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,20 @@ impl<'p> PyVisit<'p> {
Err(PyTraverseError(r))
}
}

/// Creates the PyVisit from the arguments to tp_traverse
#[doc(hidden)]
pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, _py: Python<'p>) -> Self {
Self { visit, arg, _py }
}
}

/// Unwraps the result of __traverse__ for tp_traverse
#[doc(hidden)]
#[inline]
pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int {
match result {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
}
}
14 changes: 6 additions & 8 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#![cfg(feature = "macros")]
#![cfg(feature = "pyproto")] // FIXME: #[pymethods] to support gc protocol

use pyo3::class::PyGCProtocol;
use pyo3::class::PyTraverseError;
use pyo3::class::PyVisit;
use pyo3::prelude::*;
Expand Down Expand Up @@ -90,8 +88,8 @@ struct GcIntegration {
dropped: TestDropCall,
}

#[pyproto]
impl PyGCProtocol for GcIntegration {
#[pymethods]
impl GcIntegration {
fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
visit.call(&self.self_ref)
}
Expand Down Expand Up @@ -133,8 +131,8 @@ fn gc_integration() {
#[pyclass(gc)]
struct GcIntegration2 {}

#[pyproto]
impl PyGCProtocol for GcIntegration2 {
#[pymethods]
impl GcIntegration2 {
fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> {
Ok(())
}
Expand Down Expand Up @@ -230,8 +228,8 @@ impl TraversableClass {
}
}

#[pyproto]
impl PyGCProtocol for TraversableClass {
#[pymethods]
impl TraversableClass {
fn __clear__(&mut self) {}
fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> {
self.traversed.store(true, Ordering::Relaxed);
Expand Down

0 comments on commit 676295b

Please sign in to comment.