Skip to content

Commit

Permalink
Implement a safe API wrapping PyEval_SetProfile
Browse files Browse the repository at this point in the history
Fixes PyO3#4008.
  • Loading branch information
LilyFoote committed Apr 3, 2024
1 parent a4aea23 commit f449eb1
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 0 deletions.
98 changes: 98 additions & 0 deletions src/instrumentation.rs
@@ -0,0 +1,98 @@
use crate::ffi;
use crate::pyclass::boolean_struct::False;
use crate::types::PyFrame;
use crate::{Bound, PyAny, PyClass, PyObject, PyRefMut, PyResult, Python};
use std::ffi::c_int;

pub trait Event<'py>: Sized {
fn from_raw(what: c_int, arg: Option<Bound<'py, PyAny>>) -> PyResult<Self>;
}

pub enum ProfileEvent<'py> {
Call,
Return(Option<Bound<'py, PyAny>>),
CCall(Bound<'py, PyAny>),
CException(Bound<'py, PyAny>),
CReturn(Bound<'py, PyAny>),
}

impl<'py> Event<'py> for ProfileEvent<'py> {
fn from_raw(what: c_int, arg: Option<Bound<'py, PyAny>>) -> PyResult<ProfileEvent<'py>> {
let event = match what {
ffi::PyTrace_CALL => ProfileEvent::Call,
ffi::PyTrace_RETURN => ProfileEvent::Return(arg),
ffi::PyTrace_C_CALL => ProfileEvent::CCall(arg.unwrap()),
ffi::PyTrace_C_EXCEPTION => ProfileEvent::CException(arg.unwrap()),
ffi::PyTrace_C_RETURN => ProfileEvent::CReturn(arg.unwrap()),
_ => unreachable!(),
};
Ok(event)
}
}

pub trait Profiler: PyClass<Frozen = False> {
fn profile<'py>(
&mut self,
frame: Bound<'py, PyFrame>,
event: ProfileEvent<'py>,
) -> PyResult<()>;
}

pub fn register_profiler<P: Profiler>(profiler: Bound<'_, P>) {
unsafe { ffi::PyEval_SetProfile(Some(profile_callback::<P>), profiler.into_ptr()) };
}

extern "C" fn profile_callback<P>(
obj: *mut ffi::PyObject,
frame: *mut ffi::PyFrameObject,
what: c_int,
arg: *mut ffi::PyObject,
) -> c_int
where
P: Profiler,
{
// Safety:
//
// `frame` is an `ffi::PyFrameObject` which can be converted safely to a `PyObject`.
let frame = frame as *mut ffi::PyObject;
Python::with_gil(|py| {
// Safety:
//
// `obj` is a reference to our `Profiler` wrapped up in a Python object, so
// we can safely convert it from an `ffi::PyObject` to a `PyObject`.
//
// We borrow the object so we don't break reference counting.
//
// https://docs.python.org/3/c-api/init.html#c.Py_tracefunc
let obj = unsafe { PyObject::from_borrowed_ptr(py, obj) };
let mut profiler = obj.extract::<PyRefMut<'_, P>>(py).unwrap();

// Safety:
//
// We borrow the object so we don't break reference counting.
//
// https://docs.python.org/3/c-api/init.html#c.Py_tracefunc
let frame = unsafe { PyObject::from_borrowed_ptr(py, frame) };
let frame = frame.extract(py).unwrap();

// Safety:
//
// `arg` is either a `Py_None` (PyTrace_CALL) or any PyObject (PyTrace_RETURN) or
// NULL (PyTrace_RETURN).
//
// We borrow the object so we don't break reference counting.
//
// https://docs.python.org/3/c-api/init.html#c.Py_tracefunc
let arg = unsafe { Bound::from_borrowed_ptr_or_opt(py, arg) };

let event = ProfileEvent::from_raw(what, arg).unwrap();

match profiler.profile(frame, event) {
Ok(_) => 0,
Err(err) => {
err.restore(py);
-1
}
}
})
}
2 changes: 2 additions & 0 deletions src/lib.rs
Expand Up @@ -447,6 +447,8 @@ mod gil;
#[doc(hidden)]
pub mod impl_;
mod instance;
//#[cfg(feature = "instrumentation")]
pub mod instrumentation;
pub mod marker;
pub mod marshal;
#[macro_use]
Expand Down
51 changes: 51 additions & 0 deletions tests/test_instrumentation.rs
@@ -0,0 +1,51 @@
use pyo3::instrumentation::{register_profiler, ProfileEvent, Profiler};
use pyo3::prelude::*;
use pyo3::pyclass;
use pyo3::types::{PyFrame, PyList};

#[pyclass]
struct BasicProfiler {
events: Py<PyList>,
}

impl Profiler for BasicProfiler {
fn profile(&mut self, frame: Bound<'_, PyFrame>, event: ProfileEvent<'_>) -> PyResult<()> {
let py = frame.py();
let events = self.events.bind(py);
match event {
ProfileEvent::Call => events.append("call")?,
ProfileEvent::Return(_) => events.append("return")?,
_ => {}
};
Ok(())
}
}

const PYTHON_CODE: &str = r#"
def foo():
return "foo"
foo()
"#;

#[test]
fn test_profiler() {
Python::with_gil(|py| {
let events = PyList::empty_bound(py);
let profiler = Bound::new(
py,
BasicProfiler {
events: events.clone().into(),
},
)
.unwrap();
register_profiler(profiler);

py.run_bound(PYTHON_CODE, None, None).unwrap();

assert_eq!(
events.extract::<Vec<String>>().unwrap(),
vec!["call", "call", "return", "return"]
);
})
}

0 comments on commit f449eb1

Please sign in to comment.