Skip to content

Commit

Permalink
Merge pull request #2657 from mejrs/decorator_fix
Browse files Browse the repository at this point in the history
Update decorator to use Cell counter
  • Loading branch information
mejrs committed Oct 10, 2022
2 parents f68781e + d8fa6be commit c9b26f5
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 8 deletions.
26 changes: 19 additions & 7 deletions examples/decorator/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use std::cell::Cell;

/// A function decorator that keeps track how often it is called.
///
/// It otherwise doesn't do anything special.
#[pyclass(name = "Counter")]
pub struct PyCounter {
// We use `#[pyo3(get)]` so that python can read the count but not mutate it.
#[pyo3(get)]
count: u64,
// Keeps track of how many calls have gone through.
//
// See the discussion at the end for why `Cell` is used.
count: Cell<u64>,

// This is the actual function being wrapped.
wraps: Py<PyAny>,
Expand All @@ -23,20 +25,30 @@ impl PyCounter {
// 2. We still need to handle any exceptions that the function might raise
#[new]
fn __new__(wraps: Py<PyAny>) -> Self {
PyCounter { count: 0, wraps }
PyCounter {
count: Cell::new(0),
wraps,
}
}

#[getter]
fn count(&self) -> u64 {
self.count.get()
}

#[args(args = "*", kwargs = "**")]
fn __call__(
&mut self,
&self,
py: Python<'_>,
args: &PyTuple,
kwargs: Option<&PyDict>,
) -> PyResult<Py<PyAny>> {
self.count += 1;
let old_count = self.count.get();
let new_count = old_count + 1;
self.count.set(new_count);
let name = self.wraps.getattr(py, "__name__")?;

println!("{} has been called {} time(s).", name, self.count);
println!("{} has been called {} time(s).", name, new_count);

// After doing something, we finally forward the call to the wrapped function
let ret = self.wraps.call(py, args, kwargs)?;
Expand Down
11 changes: 11 additions & 0 deletions examples/decorator/tests/test_.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,14 @@ def say_hello(name="default"):
say_hello()

assert say_hello.count == 4


# https://github.com/PyO3/pyo3/discussions/2598
def test_discussion_2598():
@Counter
def say_hello():
if say_hello.count < 2:
print(f"hello from decorator")

say_hello()
say_hello()
51 changes: 50 additions & 1 deletion guide/src/class/call.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ say_hello has been called 4 time(s).
hello
```

#### Pure Python implementation
### Pure Python implementation

A Python implementation of this looks similar to the Rust version:

Expand Down Expand Up @@ -65,3 +65,52 @@ def Counter(wraps):
return wraps(*args, **kwargs)
return call
```

### What is the `Cell` for?

A [previous implementation] used a normal `u64`, which meant it required a `&mut self` receiver to update the count:

```rust,ignore
#[args(args = "*", kwargs = "**")]
fn __call__(&mut self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Py<PyAny>> {
self.count += 1;
let name = self.wraps.getattr(py, "__name__")?;

println!("{} has been called {} time(s).", name, self.count);

// After doing something, we finally forward the call to the wrapped function
let ret = self.wraps.call(py, args, kwargs)?;

// We could do something with the return value of
// the function before returning it
Ok(ret)
}
```

The problem with this is that the `&mut self` receiver means PyO3 has to borrow it exclusively,
and hold this borrow across the`self.wraps.call(py, args, kwargs)` call. This call returns control to the user's Python code
which is free to call arbitrary things, *including* the decorated function. If that happens PyO3 is unable to create a second unique borrow and will be forced to raise an exception.

As a result, something innocent like this will raise an exception:

```py
@Counter
def say_hello():
if say_hello.count < 2:
print(f"hello from decorator")

say_hello()
# RuntimeError: Already borrowed
```

The implementation in this chapter fixes that by never borrowing exclusively; all the methods take `&self` as receivers, of which multiple may exist simultaneously. This requires a shared counter and the easiest way to do that is to use [`Cell`], so that's what is used here.

This shows the dangers of running arbitrary Python code - note that "running arbitrary Python code" can be far more subtle than the example above:
- Python's asynchronous executor may park the current thread in the middle of Python code, even in Python code that *you* control, and let other Python code run.
- Dropping arbitrary Python objects may invoke destructors defined in Python (`__del__` methods).
- Calling Python's C-api (most PyO3 apis call C-api functions internally) may raise exceptions, which may allow Python code in signal handlers to run.

This is especially important if you are writing unsafe code; Python code must never be able to cause undefined behavior. You must ensure that your Rust code is in a consistent state before doing any of the above things.

[previous implementation]: https://github.com/PyO3/pyo3/discussions/2598 "Thread Safe Decorator <Help Wanted> · Discussion #2598 · PyO3/pyo3"
[`Cell`]: https://doc.rust-lang.org/std/cell/struct.Cell.html "Cell in std::cell - Rust"

0 comments on commit c9b26f5

Please sign in to comment.