diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b8f040b4a2..3b295aa3d41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add FFI definitions from `cpython/import.h`.[#1475](https://github.com/PyO3/pyo3/pull/1475) - Add tuple and unit struct support for `#[pyclass]` macro. [#1504](https://github.com/PyO3/pyo3/pull/1504) - Add FFI definition `PyDateTime_TimeZone_UTC`. [#1572](https://github.com/PyO3/pyo3/pull/1572) +- Add support for `#[pyclass(extends=Exception)]`. [#1591](https://github.com/PyO3/pyo3/pull/1591) ### Changed - Change `PyTimeAcces::get_fold()` to return a `bool` instead of a `u8`. [#1397](https://github.com/PyO3/pyo3/pull/1397) diff --git a/tests/test_inheritance.rs b/tests/test_inheritance.rs index 056e3d8ccdf..c9d1df2bf69 100644 --- a/tests/test_inheritance.rs +++ b/tests/test_inheritance.rs @@ -153,7 +153,8 @@ except Exception as e: #[cfg(not(Py_LIMITED_API))] mod inheriting_native_type { use super::*; - use pyo3::types::{PyDict, PySet}; + use pyo3::exceptions::PyException; + use pyo3::types::{IntoPyDict, PyDict, PySet}; #[pyclass(extends=PySet)] #[derive(Debug)] @@ -208,6 +209,49 @@ mod inheriting_native_type { r#"dict_sub[0] = 1; assert dict_sub[0] == 1; assert dict_sub._name == "Hello :)""# ); } + + #[pyclass(extends=PyException)] + struct CustomException { + #[pyo3(get)] + context: &'static str, + } + + #[pymethods] + impl CustomException { + #[new] + fn new() -> Self { + CustomException { + context: "Hello :)", + } + } + } + + #[test] + fn custom_exception() { + Python::with_gil(|py| { + let cls = py.get_type::(); + let dict = [("cls", cls)].into_py_dict(py); + let res = py.run( + "e = cls('hello'); assert str(e) == 'hello'; assert e.context == 'Hello :)'; raise e", + None, + Some(dict) + ); + let err = res.unwrap_err(); + assert!(err.matches(py, cls), "{}", err); + + // catching the exception in Python also works: + py_run!( + py, + cls, + r#" + try: + raise cls("foo") + except cls: + pass + "# + ) + }) + } } #[pyclass(subclass)]