Skip to content

Commit

Permalink
buffer: tidy up exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Apr 1, 2021
1 parent 4713b46 commit 71dd405
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- The `auto-initialize` feature is no longer enabled by default. [#1443](https://github.com/PyO3/pyo3/pull/1443)
- Change `PyCFunction::new()` and `PyCFunction::new_with_keywords()` to take `&'static str` arguments rather than implicitly copying (and leaking) them. [#1450](https://github.com/PyO3/pyo3/pull/1450)
- Deprecate `PyModule` methods `call`, `call0`, `call1` and `get`. [#1492](https://github.com/PyO3/pyo3/pull/1492)
- Add length information to `PyBufferError`s raised from `PyBuffer::copy_to_slice` and `PyBuffer::copy_from_slice`. [#1534](https://github.com/PyO3/pyo3/pull/1534)

### Removed
- Remove deprecated exception names `BaseException` etc. [#1426](https://github.com/PyO3/pyo3/pull/1426)
Expand Down
91 changes: 44 additions & 47 deletions src/buffer.rs
Expand Up @@ -17,8 +17,10 @@
// DEALINGS IN THE SOFTWARE.

//! `PyBuffer` implementation
use crate::err::{self, PyResult};
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
use crate::{
err, exceptions::PyBufferError, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, PyResult,
Python,
};
use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw;
Expand Down Expand Up @@ -147,19 +149,6 @@ pub unsafe trait Element: Copy {
fn is_compatible_format(format: &CStr) -> bool;
}

fn validate(b: &ffi::Py_buffer) -> PyResult<()> {
// shape and stride information must be provided when we use PyBUF_FULL_RO
if b.shape.is_null() {
return Err(exceptions::PyBufferError::new_err("Shape is Null"));
}
if b.strides.is_null() {
return Err(exceptions::PyBufferError::new_err(
"PyBuffer: Strides is Null",
));
}
Ok(())
}

impl<'source, T: Element> FromPyObject<'source> for PyBuffer<T> {
fn extract(obj: &PyAny) -> PyResult<PyBuffer<T>> {
Self::get(obj)
Expand All @@ -169,25 +158,37 @@ impl<'source, T: Element> FromPyObject<'source> for PyBuffer<T> {
impl<T: Element> PyBuffer<T> {
/// Get the underlying buffer from the specified python object.
pub fn get(obj: &PyAny) -> PyResult<PyBuffer<T>> {
unsafe {
let mut buf = Box::pin(ffi::Py_buffer::new());
// TODO: use nightly API Box::new_uninit() once stable
let mut buf = Box::new(mem::MaybeUninit::uninit());
let buf: Box<ffi::Py_buffer> = unsafe {
err::error_on_minusone(
obj.py(),
ffi::PyObject_GetBuffer(obj.as_ptr(), &mut *buf, ffi::PyBUF_FULL_RO),
ffi::PyObject_GetBuffer(obj.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_FULL_RO),
)?;
validate(&buf)?;
let buf = PyBuffer(buf, PhantomData);
// Type Check
if mem::size_of::<T>() == buf.item_size()
&& (buf.0.buf as usize) % mem::align_of::<T>() == 0
&& T::is_compatible_format(buf.format())
{
Ok(buf)
} else {
Err(exceptions::PyBufferError::new_err(
"Incompatible type as buffer",
))
}
// Safety: buf is initialized by PyObject_GetBuffer.
// TODO: use nightly API Box::assume_init() once stable
mem::transmute(buf)
};
// Create PyBuffer immediately so that if validation checks fail, the PyBuffer::drop code
// will call PyBuffer_Release (thus avoiding any leaks).
let buf = PyBuffer(Pin::from(buf), PhantomData);

if buf.0.shape.is_null() {
Err(PyBufferError::new_err("shape is null"))
} else if buf.0.strides.is_null() {
Err(PyBufferError::new_err("strides is null"))
} else if mem::size_of::<T>() != buf.item_size() || !T::is_compatible_format(buf.format()) {
Err(PyBufferError::new_err(format!(
"buffer contents are not compatible with {}",
std::any::type_name::<T>()
)))
} else if buf.0.buf.align_offset(mem::align_of::<T>()) != 0 {
Err(PyBufferError::new_err(format!(
"buffer contents are insufficiently aligned for {}",
std::any::type_name::<T>()
)))
} else {
Ok(buf)
}
}

Expand Down Expand Up @@ -441,9 +442,11 @@ impl<T: Element> PyBuffer<T> {

fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
if mem::size_of_val(target) != self.len_bytes() {
return Err(exceptions::PyBufferError::new_err(
"Slice length does not match buffer length.",
));
return Err(PyBufferError::new_err(format!(
"slice to copy to (of length {}) does not match buffer length of {}",
target.len(),
self.item_count()
)));
}
unsafe {
err::error_on_minusone(
Expand Down Expand Up @@ -525,12 +528,13 @@ impl<T: Element> PyBuffer<T> {

fn copy_from_slice_impl(&self, py: Python, source: &[T], fort: u8) -> PyResult<()> {
if self.readonly() {
return buffer_readonly_error();
}
if mem::size_of_val(source) != self.len_bytes() {
return Err(exceptions::PyBufferError::new_err(
"Slice length does not match buffer length.",
));
return Err(PyBufferError::new_err("cannot write to read-only buffer"));
} else if mem::size_of_val(source) != self.len_bytes() {
return Err(PyBufferError::new_err(format!(
"slice to copy from (of length {}) does not match buffer length of {}",
source.len(),
self.item_count()
)));
}
unsafe {
err::error_on_minusone(
Expand Down Expand Up @@ -562,13 +566,6 @@ impl<T: Element> PyBuffer<T> {
}
}

#[inline(always)]
fn buffer_readonly_error() -> PyResult<()> {
Err(exceptions::PyBufferError::new_err(
"Cannot write to read-only buffer.",
))
}

impl<T> Drop for PyBuffer<T> {
fn drop(&mut self) {
Python::with_gil(|_| unsafe { ffi::PyBuffer_Release(&mut *self.0) });
Expand Down

0 comments on commit 71dd405

Please sign in to comment.