diff --git a/Cargo.toml b/Cargo.toml index 11bbac89e..ddfd8e4f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,5 @@ license-file = "LICENSE" libc = "0.2" num-complex = "0.1" ndarray = "0.11" -pyo3 = "0.3.1" +pyo3 = { git = "https://github.com/pyo3/pyo3", rev = "547fa35604a61d2d750390940ed5a0d34ed09821"} + diff --git a/example/extensions/Cargo.toml b/example/extensions/Cargo.toml index a84fbf8fe..7010b1796 100644 --- a/example/extensions/Cargo.toml +++ b/example/extensions/Cargo.toml @@ -9,5 +9,5 @@ crate-type = ["cdylib"] [dependencies] numpy = { path = "../.." } +pyo3 = { git = "https://github.com/pyo3/pyo3", rev = "547fa35604a61d2d750390940ed5a0d34ed09821"} ndarray = "0.11" -pyo3 = "0.3.1" diff --git a/example/extensions/src/lib.rs b/example/extensions/src/lib.rs index 81b2639b2..36ada76c2 100644 --- a/example/extensions/src/lib.rs +++ b/example/extensions/src/lib.rs @@ -27,7 +27,7 @@ fn rust_ext(py: Python, m: &PyModule) -> PyResult<()> { // wrapper of `axpy` #[pyfn(m, "axpy")] - fn axpy_py(py: Python, a: f64, x: &PyArray, y: &PyArray) -> PyResult { + fn axpy_py(py: Python, a: f64, x: &PyArray, y: &PyArray) -> PyResult> { let np = PyArrayModule::import(py)?; let x = x.as_array().into_pyresult("x must be f64 array")?; let y = y.as_array().into_pyresult("y must be f64 array")?; @@ -36,7 +36,7 @@ fn rust_ext(py: Python, m: &PyModule) -> PyResult<()> { // wrapper of `mult` #[pyfn(m, "mult")] - fn mult_py(_py: Python, a: f64, x: &PyArray) -> PyResult<()> { + fn mult_py(_py: Python, a: f64, x: &PyArray) -> PyResult<()> { let x = x.as_array_mut().into_pyresult("x must be f64 array")?; mult(a, x); Ok(()) diff --git a/example/setup.py b/example/setup.py index c1290be63..e88f71857 100644 --- a/example/setup.py +++ b/example/setup.py @@ -4,7 +4,7 @@ import sys from setuptools import find_packages, setup from setuptools.command.test import test as TestCommand -from setuptools_rust import RustExtension, Binding +from setuptools_rust import RustExtension class CmdTest(TestCommand): diff --git a/src/array.rs b/src/array.rs index 434a6c59e..670212044 100644 --- a/src/array.rs +++ b/src/array.rs @@ -3,7 +3,7 @@ use ndarray::*; use npyffi; use pyo3::*; - +use std::marker::PhantomData; use std::os::raw::c_void; use std::ptr::null_mut; @@ -11,21 +11,144 @@ use super::error::ArrayCastError; use super::*; /// Untyped safe interface for NumPy ndarray. -pub struct PyArray(PyObject); -pyobject_native_type!(PyArray, *npyffi::PyArray_Type_Ptr, npyffi::PyArray_Check); +pub struct PyArray(PyObject, PhantomData); + +pyobject_native_type_convert!( + PyArray, + *npyffi::PyArray_Type_Ptr, + npyffi::PyArray_Check, + T +); + +pyobject_native_type_named!(PyArray, T); + +impl<'a, T> ::std::convert::From<&'a PyArray> for &'a PyObjectRef { + fn from(ob: &'a PyArray) -> Self { + unsafe { &*(ob as *const PyArray as *const PyObjectRef) } + } +} -impl IntoPyObject for PyArray { +impl<'a, T: TypeNum> FromPyObject<'a> for &'a PyArray { + // here we do type-check twice + // 1. Checks if the object is PyArray + // 2. Checks if the data type of the array is T + fn extract(ob: &'a PyObjectRef) -> PyResult { + let array = unsafe { + if npyffi::PyArray_Check(ob.as_ptr()) == 0 { + return Err(PyDowncastError.into()); + } + &*(ob as *const PyObjectRef as *const PyArray) + }; + array + .type_check() + .map(|_| array) + .map_err(|err| err.into_pyerr("FromPyObject::extract failed")) + } +} + +impl IntoPyObject for PyArray { fn into_object(self, _py: Python) -> PyObject { self.0 } } -impl PyArray { +impl PyArray { /// Get raw pointer for PyArrayObject pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject { self.as_ptr() as _ } + pub unsafe fn from_owned_ptr(py: Python, ptr: *mut pyo3::ffi::PyObject) -> Self { + let obj = PyObject::from_owned_ptr(py, ptr); + PyArray(obj, PhantomData) + } + + pub unsafe fn from_borrowed_ptr(py: Python, ptr: *mut pyo3::ffi::PyObject) -> Self { + let obj = PyObject::from_borrowed_ptr(py, ptr); + PyArray(obj, PhantomData) + } + + /// Returns the number of dimensions in the array. + /// + /// Same as [numpy.ndarray.ndim](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ndim.html) + /// + /// # Example + /// ``` + /// # extern crate pyo3; extern crate numpy; fn main() { + /// use numpy::{PyArray, PyArrayModule}; + /// let gil = pyo3::Python::acquire_gil(); + /// let np = PyArrayModule::import(gil.python()).unwrap(); + /// let arr = PyArray::::new(gil.python(), &np, &[4, 5, 6]); + /// assert_eq!(arr.ndim(), 3); + /// # } + /// ``` + // C API: https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_NDIM + pub fn ndim(&self) -> usize { + let ptr = self.as_array_ptr(); + unsafe { (*ptr).nd as usize } + } + + /// Returns a slice which contains how many bytes you need to jump to the next row. + /// + /// Same as [numpy.ndarray.strides](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.strides.html) + /// # Example + /// ``` + /// # extern crate pyo3; extern crate numpy; fn main() { + /// use numpy::{PyArray, PyArrayModule}; + /// let gil = pyo3::Python::acquire_gil(); + /// let np = PyArrayModule::import(gil.python()).unwrap(); + /// let arr = PyArray::::new(gil.python(), &np, &[4, 5, 6]); + /// assert_eq!(arr.strides(), &[240, 48, 8]); + /// # } + /// ``` + // C API: https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_STRIDES + pub fn strides(&self) -> &[isize] { + let n = self.ndim(); + let ptr = self.as_array_ptr(); + unsafe { + let p = (*ptr).strides; + ::std::slice::from_raw_parts(p, n) + } + } + + /// Returns a slice which contains dimmensions of the array. + /// + /// Same as [numpy.ndarray.shape](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.shape.html) + /// # Example + /// ``` + /// # extern crate pyo3; extern crate numpy; fn main() { + /// use numpy::{PyArray, PyArrayModule}; + /// let gil = pyo3::Python::acquire_gil(); + /// let np = PyArrayModule::import(gil.python()).unwrap(); + /// let arr = PyArray::::new(gil.python(), &np, &[4, 5, 6]); + /// assert_eq!(arr.shape(), &[4, 5, 6]); + /// # } + /// ``` + // C API: https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_DIMS + pub fn shape(&self) -> &[usize] { + let n = self.ndim(); + let ptr = self.as_array_ptr(); + unsafe { + let p = (*ptr).dimensions as *mut usize; + ::std::slice::from_raw_parts(p, n) + } + } + + /// Same as [shape](./struct.PyArray.html#method.shape) + /// + /// Reserved for backward compatibility. + #[inline] + pub fn dims(&self) -> &[usize] { + self.shape() + } + + /// Calcurates the total number of elements in the array. + pub fn len(&self) -> usize { + self.shape().iter().fold(1, |a, b| a * b) + } +} + +impl PyArray { /// Construct one-dimension PyArray from boxed slice. /// /// # Example @@ -35,11 +158,11 @@ impl PyArray { /// let gil = pyo3::Python::acquire_gil(); /// let np = PyArrayModule::import(gil.python()).unwrap(); /// let slice = vec![1, 2, 3, 4, 5].into_boxed_slice(); - /// let pyarray = PyArray::from_boxed_slice::(gil.python(), &np, slice); - /// assert_eq!(pyarray.as_slice::().unwrap(), &[1, 2, 3, 4, 5]); + /// let pyarray = PyArray::from_boxed_slice(gil.python(), &np, slice); + /// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]); /// # } /// ``` - pub fn from_boxed_slice(py: Python, np: &PyArrayModule, v: Box<[T]>) -> PyArray { + pub fn from_boxed_slice(py: Python, np: &PyArrayModule, v: Box<[T]>) -> PyArray { IntoPyArray::into_pyarray(v, py, np) } @@ -51,11 +174,11 @@ impl PyArray { /// use numpy::{PyArray, PyArrayModule}; /// let gil = pyo3::Python::acquire_gil(); /// let np = PyArrayModule::import(gil.python()).unwrap(); - /// let pyarray = PyArray::from_vec::(gil.python(), &np, vec![1, 2, 3, 4, 5]); - /// assert_eq!(pyarray.as_slice::().unwrap(), &[1, 2, 3, 4, 5]); + /// let pyarray = PyArray::from_vec(gil.python(), &np, vec![1, 2, 3, 4, 5]); + /// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]); /// # } /// ``` - pub fn from_vec(py: Python, np: &PyArrayModule, v: Vec) -> PyArray { + pub fn from_vec(py: Python, np: &PyArrayModule, v: Vec) -> PyArray { IntoPyArray::into_pyarray(v, py, np) } @@ -71,16 +194,16 @@ impl PyArray { /// let gil = pyo3::Python::acquire_gil(); /// let np = PyArrayModule::import(gil.python()).unwrap(); /// let vec2 = vec![vec![1, 2, 3]; 2]; - /// let pyarray = PyArray::from_vec2::(gil.python(), &np, &vec2).unwrap(); - /// assert_eq!(pyarray.as_array::().unwrap(), array![[1, 2, 3], [1, 2, 3]].into_dyn()); - /// assert!(PyArray::from_vec2::(gil.python(), &np, &vec![vec![1], vec![2, 3]]).is_err()); + /// let pyarray = PyArray::from_vec2(gil.python(), &np, &vec2).unwrap(); + /// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2, 3], [1, 2, 3]].into_dyn()); + /// assert!(PyArray::from_vec2(gil.python(), &np, &vec![vec![1], vec![2, 3]]).is_err()); /// # } /// ``` - pub fn from_vec2( + pub fn from_vec2( py: Python, np: &PyArrayModule, v: &Vec>, - ) -> Result { + ) -> Result, ArrayCastError> { let last_len = v.last().map_or(0, |v| v.len()); if v.iter().any(|v| v.len() != last_len) { return Err(ArrayCastError::FromVec); @@ -89,7 +212,7 @@ impl PyArray { let flattend: Vec<_> = v.iter().cloned().flatten().collect(); unsafe { let data = convert::into_raw(flattend); - Ok(PyArray::new_::(py, np, &dims, null_mut(), data)) + Ok(PyArray::new_(py, np, &dims, null_mut(), data)) } } @@ -105,19 +228,19 @@ impl PyArray { /// let gil = pyo3::Python::acquire_gil(); /// let np = PyArrayModule::import(gil.python()).unwrap(); /// let vec2 = vec![vec![vec![1, 2]; 2]; 2]; - /// let pyarray = PyArray::from_vec3::(gil.python(), &np, &vec2).unwrap(); + /// let pyarray = PyArray::from_vec3(gil.python(), &np, &vec2).unwrap(); /// assert_eq!( - /// pyarray.as_array::().unwrap(), + /// pyarray.as_array().unwrap(), /// array![[[1, 2], [1, 2]], [[1, 2], [1, 2]]].into_dyn() /// ); - /// assert!(PyArray::from_vec3::(gil.python(), &np, &vec![vec![vec![1], vec![]]]).is_err()); + /// assert!(PyArray::from_vec3(gil.python(), &np, &vec![vec![vec![1], vec![]]]).is_err()); /// # } /// ``` - pub fn from_vec3( + pub fn from_vec3( py: Python, np: &PyArrayModule, v: &Vec>>, - ) -> Result { + ) -> Result, ArrayCastError> { let dim2 = v.last().map_or(0, |v| v.len()); if v.iter().any(|v| v.len() != dim2) { return Err(ArrayCastError::FromVec); @@ -130,7 +253,7 @@ impl PyArray { let flattend: Vec<_> = v.iter().flat_map(|v| v.iter().cloned().flatten()).collect(); unsafe { let data = convert::into_raw(flattend); - Ok(PyArray::new_::(py, np, &dims, null_mut(), data)) + Ok(PyArray::new_(py, np, &dims, null_mut(), data)) } } @@ -142,118 +265,29 @@ impl PyArray { /// use numpy::{PyArray, PyArrayModule}; /// let gil = pyo3::Python::acquire_gil(); /// let np = PyArrayModule::import(gil.python()).unwrap(); - /// let pyarray = PyArray::from_ndarray::(gil.python(), &np, array![[1, 2], [3, 4]]); - /// assert_eq!(pyarray.as_array::().unwrap(), array![[1, 2], [3, 4]].into_dyn()); + /// let pyarray = PyArray::from_ndarray(gil.python(), &np, array![[1, 2], [3, 4]]); + /// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2], [3, 4]].into_dyn()); /// # } /// ``` - pub fn from_ndarray(py: Python, np: &PyArrayModule, arr: Array) -> PyArray + pub fn from_ndarray(py: Python, np: &PyArrayModule, arr: Array) -> PyArray where - A: TypeNum, D: Dimension, { IntoPyArray::into_pyarray(arr, py, np) } - pub unsafe fn from_owned_ptr(py: Python, ptr: *mut pyo3::ffi::PyObject) -> Self { - let obj = PyObject::from_owned_ptr(py, ptr); - PyArray(obj) - } - - pub unsafe fn from_borrowed_ptr(py: Python, ptr: *mut pyo3::ffi::PyObject) -> Self { - let obj = PyObject::from_borrowed_ptr(py, ptr); - PyArray(obj) - } - - /// Returns the number of dimensions in the array. - /// - /// Same as [numpy.ndarray.ndim](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ndim.html) - /// - /// # Example - /// ``` - /// # extern crate pyo3; extern crate numpy; fn main() { - /// use numpy::{PyArray, PyArrayModule}; - /// let gil = pyo3::Python::acquire_gil(); - /// let np = PyArrayModule::import(gil.python()).unwrap(); - /// let arr = PyArray::new::(gil.python(), &np, &[4, 5, 6]); - /// assert_eq!(arr.ndim(), 3); - /// # } - /// ``` - // C API: https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_NDIM - pub fn ndim(&self) -> usize { - let ptr = self.as_array_ptr(); - unsafe { (*ptr).nd as usize } - } - - /// Same as [shape](./struct.PyArray.html#method.shape) - /// - /// Reserved for backward compatibility. - #[inline] - pub fn dims(&self) -> &[usize] { - self.shape() - } - - pub fn len(&self) -> usize { - self.shape().iter().fold(1, |a, b| a * b) - } - - /// Returns a slice which contains dimmensions of the array. - /// - /// Same as [numpy.ndarray.shape](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.shape.html) - /// # Example - /// ``` - /// # extern crate pyo3; extern crate numpy; fn main() { - /// use numpy::{PyArray, PyArrayModule}; - /// let gil = pyo3::Python::acquire_gil(); - /// let np = PyArrayModule::import(gil.python()).unwrap(); - /// let arr = PyArray::new::(gil.python(), &np, &[4, 5, 6]); - /// assert_eq!(arr.shape(), &[4, 5, 6]); - /// # } - /// ``` - // C API: https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_DIMS - pub fn shape(&self) -> &[usize] { - let n = self.ndim(); - let ptr = self.as_array_ptr(); - unsafe { - let p = (*ptr).dimensions as *mut usize; - ::std::slice::from_raw_parts(p, n) - } - } - - /// Returns a slice which contains how many bytes you need to jump to the next row. - /// - /// Same as [numpy.ndarray.strides](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.strides.html) - /// # Example - /// ``` - /// # extern crate pyo3; extern crate numpy; fn main() { - /// use numpy::{PyArray, PyArrayModule}; - /// let gil = pyo3::Python::acquire_gil(); - /// let np = PyArrayModule::import(gil.python()).unwrap(); - /// let arr = PyArray::new::(gil.python(), &np, &[4, 5, 6]); - /// assert_eq!(arr.strides(), &[240, 48, 8]); - /// # } - /// ``` - // C API: https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_STRIDES - pub fn strides(&self) -> &[isize] { - let n = self.ndim(); - let ptr = self.as_array_ptr(); - unsafe { - let p = (*ptr).strides; - ::std::slice::from_raw_parts(p, n) - } - } - - unsafe fn data(&self) -> *mut T { + unsafe fn data(&self) -> *mut T { let ptr = self.as_array_ptr(); (*ptr).data as *mut T } - fn ndarray_shape(&self) -> StrideShape { + fn ndarray_shape(&self) -> StrideShape { // FIXME may be done more simply let shape: Shape<_> = Dim(self.shape()).into(); let st: Vec = self .strides() .iter() - .map(|&x| x as usize / ::std::mem::size_of::()) + .map(|&x| x as usize / ::std::mem::size_of::()) .collect(); shape.strides(Dim(st)) } @@ -265,51 +299,50 @@ impl PyArray { } } - fn type_check(&self) -> Result<(), ArrayCastError> { - let test = A::typenum(); + pub fn data_type(&self) -> NpyDataType { + NpyDataType::from_i32(self.typenum()) + } + + fn type_check(&self) -> Result<(), ArrayCastError> { + let test = T::typenum(); let truth = self.typenum(); - if A::typenum() == self.typenum() { + if test == truth { Ok(()) } else { - Err(ArrayCastError::to_rust(test, truth)) + Err(ArrayCastError::to_rust(truth, test)) } } /// Get data as a ndarray::ArrayView - pub fn as_array(&self) -> Result, ArrayCastError> { - self.type_check::()?; - unsafe { - Ok(ArrayView::from_shape_ptr( - self.ndarray_shape::(), - self.data(), - )) - } + pub fn as_array(&self) -> Result, ArrayCastError> { + self.type_check()?; + unsafe { Ok(ArrayView::from_shape_ptr(self.ndarray_shape(), self.data())) } } /// Get data as a ndarray::ArrayViewMut - pub fn as_array_mut(&self) -> Result, ArrayCastError> { - self.type_check::()?; + pub fn as_array_mut(&self) -> Result, ArrayCastError> { + self.type_check()?; unsafe { Ok(ArrayViewMut::from_shape_ptr( - self.ndarray_shape::(), + self.ndarray_shape(), self.data(), )) } } /// Get data as a Rust immutable slice - pub fn as_slice(&self) -> Result<&[A], ArrayCastError> { - self.type_check::()?; + pub fn as_slice(&self) -> Result<&[T], ArrayCastError> { + self.type_check()?; unsafe { Ok(::std::slice::from_raw_parts(self.data(), self.len())) } } /// Get data as a Rust mutable slice - pub fn as_slice_mut(&self) -> Result<&mut [A], ArrayCastError> { - self.type_check::()?; + pub fn as_slice_mut(&self) -> Result<&mut [T], ArrayCastError> { + self.type_check()?; unsafe { Ok(::std::slice::from_raw_parts_mut(self.data(), self.len())) } } - pub unsafe fn new_( + pub unsafe fn new_( py: Python, np: &PyArrayModule, dims: &[usize], @@ -332,17 +365,12 @@ impl PyArray { } /// a wrapper of [PyArray_SimpleNew](https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_SimpleNew) - pub fn new(py: Python, np: &PyArrayModule, dims: &[usize]) -> Self { - unsafe { Self::new_::(py, np, dims, null_mut(), null_mut()) } + pub fn new(py: Python, np: &PyArrayModule, dims: &[usize]) -> Self { + unsafe { Self::new_(py, np, dims, null_mut(), null_mut()) } } /// a wrapper of [PyArray_ZEROS](https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_ZEROS) - pub fn zeros( - py: Python, - np: &PyArrayModule, - dims: &[usize], - order: NPY_ORDER, - ) -> Self { + pub fn zeros(py: Python, np: &PyArrayModule, dims: &[usize], order: NPY_ORDER) -> Self { let dims: Vec = dims.iter().map(|d| *d as npy_intp).collect(); unsafe { let descr = np.PyArray_DescrFromType(T::typenum()); @@ -357,13 +385,7 @@ impl PyArray { } /// a wrapper of [PyArray_Arange](https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_Arange) - pub fn arange( - py: Python, - np: &PyArrayModule, - start: f64, - stop: f64, - step: f64, - ) -> Self { + pub fn arange(py: Python, np: &PyArrayModule, start: f64, stop: f64, step: f64) -> Self { unsafe { let ptr = np.PyArray_Arange(start, stop, step, T::typenum()); Self::from_owned_ptr(py, ptr) diff --git a/src/convert.rs b/src/convert.rs index 85b8fbdcd..6ed79fb6e 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -9,34 +9,39 @@ use std::ptr::null_mut; use super::*; pub trait IntoPyArray { - fn into_pyarray(self, Python, &PyArrayModule) -> PyArray; + type Item: TypeNum; + fn into_pyarray(self, Python, &PyArrayModule) -> PyArray; } impl IntoPyArray for Box<[T]> { - fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { + type Item = T; + fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { let dims = [self.len()]; let ptr = Box::into_raw(self); - unsafe { PyArray::new_::(py, np, &dims, null_mut(), ptr as *mut c_void) } + unsafe { PyArray::new_(py, np, &dims, null_mut(), ptr as *mut c_void) } } } impl IntoPyArray for Vec { - fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { + type Item = T; + fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { let dims = [self.len()]; - unsafe { PyArray::new_::(py, np, &dims, null_mut(), into_raw(self)) } + unsafe { PyArray::new_(py, np, &dims, null_mut(), into_raw(self)) } } } impl IntoPyArray for Array { - fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { + type Item = A; + fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { let dims: Vec<_> = self.shape().iter().cloned().collect(); - let mut strides: Vec<_> = self.strides() + let mut strides: Vec<_> = self + .strides() .into_iter() .map(|n| n * size_of::() as npy_intp) .collect(); unsafe { let data = into_raw(self.into_raw_vec()); - PyArray::new_::(py, np, &dims, strides.as_mut_ptr(), data) + PyArray::new_(py, np, &dims, strides.as_mut_ptr(), data) } } } @@ -45,11 +50,12 @@ macro_rules! array_impls { ($($N: expr)+) => { $( impl IntoPyArray for [T; $N] { - fn into_pyarray(mut self, py: Python, np: &PyArrayModule) -> PyArray { + type Item = T; + fn into_pyarray(mut self, py: Python, np: &PyArrayModule) -> PyArray { let dims = [$N]; let ptr = &mut self as *mut [T; $N]; unsafe { - PyArray::new_::(py, np, &dims, null_mut(), ptr as *mut c_void) + PyArray::new_(py, np, &dims, null_mut(), ptr as *mut c_void) } } } @@ -64,23 +70,23 @@ array_impls! { 30 31 32 } - pub(crate) unsafe fn into_raw(x: Vec) -> *mut c_void { let ptr = Box::into_raw(x.into_boxed_slice()); ptr as *mut c_void } pub trait ToPyArray { - fn to_pyarray(self, Python, &PyArrayModule) -> PyArray; + type Item: TypeNum; + fn to_pyarray(self, Python, &PyArrayModule) -> PyArray; } impl ToPyArray for Iter where Iter: Iterator + Sized, { - fn to_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { + type Item = T; + fn to_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray { let vec: Vec = self.collect(); vec.into_pyarray(py, np) } } - diff --git a/src/error.rs b/src/error.rs index dd24f23da..23c8091a3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,7 @@ use pyo3::*; use std::error; use std::fmt; +use types::NpyDataType; pub trait IntoPyErr { fn into_pyerr(self, msg: &str) -> PyErr; @@ -23,21 +24,27 @@ impl IntoPyResult for Result { /// Error for casting `PyArray` into `ArrayView` or `ArrayViewMut` #[derive(Debug)] pub enum ArrayCastError { - ToRust { test: i32, truth: i32 }, + ToRust { + from: NpyDataType, + to: NpyDataType, + }, FromVec, } impl ArrayCastError { - pub fn to_rust(test: i32, truth: i32) -> Self { - ArrayCastError::ToRust { test, truth } + pub fn to_rust(from: i32, to: i32) -> Self { + ArrayCastError::ToRust { + from: NpyDataType::from_i32(from), + to: NpyDataType::from_i32(to), + } } } impl fmt::Display for ArrayCastError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - ArrayCastError::ToRust { test, truth } => { - write!(f, "Cast failed: from={}, to={}", test, truth) + ArrayCastError::ToRust { from, to } => { + write!(f, "Cast failed: from={:?}, to={:?}", from, to) } ArrayCastError::FromVec => write!(f, "Cast failed: FromVec (maybe invalid dimension)"), } @@ -56,10 +63,11 @@ impl error::Error for ArrayCastError { impl IntoPyErr for ArrayCastError { fn into_pyerr(self, msg: &str) -> PyErr { let msg = match self { - ArrayCastError::ToRust { .. } => { - format!("rust_numpy::ArrayCastError::IntoArray: {}", msg) - } - ArrayCastError::FromVec => format!("rust_numpy::ArrayCastError::FromVec: {}", msg), + ArrayCastError::ToRust { from, to } => format!( + "ArrayCastError::ToRust: from: {:?}, to: {:?}, msg: {}", + from, to, msg + ), + ArrayCastError::FromVec => format!("ArrayCastError::FromVec: {}", msg), }; PyErr::new::(msg) } diff --git a/src/npyffi/array.rs b/src/npyffi/array.rs index 3ab740f55..06d244177 100644 --- a/src/npyffi/array.rs +++ b/src/npyffi/array.rs @@ -60,6 +60,32 @@ impl<'py> PyArrayModule<'py> { Ok(mod_) } + /// Returns internal `PyModule` type, which includes `numpy.core.multiarray`, + /// so that you can use `PyArrayModule` with some pyo3 functions. + /// + /// # Example + /// + /// ``` + /// # extern crate numpy; extern crate pyo3; fn main() { + /// use numpy::*; + /// use pyo3::prelude::*; + /// let gil = pyo3::Python::acquire_gil(); + /// let np = PyArrayModule::import(gil.python()).unwrap(); + /// let dict = PyDict::new(gil.python()); + /// dict.set_item("np", np.as_pymodule()).unwrap(); + /// let pyarray: &PyArray = gil + /// .python() + /// .eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None) + /// .unwrap() + /// .extract() + /// .unwrap(); + /// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3]); + /// # } + /// ``` + pub fn as_pymodule(&self) -> &'py PyModule { + self.numpy + } + pyarray_api![0; PyArray_GetNDArrayCVersion() -> c_uint]; pyarray_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int]; pyarray_api![41; PyArray_GetNumericOps() -> *mut PyObject]; diff --git a/src/npyffi/ufunc.rs b/src/npyffi/ufunc.rs index ba9c2f233..4b825853d 100644 --- a/src/npyffi/ufunc.rs +++ b/src/npyffi/ufunc.rs @@ -57,6 +57,10 @@ impl<'py> PyUFuncModule<'py> { *self.api as *mut PyTypeObject } + pub fn as_pymodule(&self) -> &'py PyModule { + self.numpy + } + pyufunc_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject]; pyufunc_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int]; pyufunc_api![3; PyUFunc_GenericFunction(ufunc: *mut PyUFuncObject, args: *mut PyObject, kwds: *mut PyObject, op: *mut *mut PyArrayObject) -> c_int]; diff --git a/src/types.rs b/src/types.rs index 5ba0c2fe3..7694a0763 100644 --- a/src/types.rs +++ b/src/types.rs @@ -7,29 +7,67 @@ pub use super::npyffi::NPY_ORDER::{NPY_CORDER, NPY_FORTRANORDER}; use super::npyffi::NPY_TYPES; +/// An enum type represents numpy data type. +/// +/// This type is mainly for displaying error, and user don't have to use it directly. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum NpyDataType { + Bool, + Int32, + Int64, + Uint32, + Uint64, + Float32, + Float64, + Complex32, + Complex64, + Unsupported, +} + +impl NpyDataType { + pub(crate) fn from_i32(npy_t: i32) -> Self { + match npy_t { + x if x == NPY_TYPES::NPY_BOOL as i32 => NpyDataType::Bool, + x if x == NPY_TYPES::NPY_INT as i32 => NpyDataType::Int32, + x if x == NPY_TYPES::NPY_LONG as i32 => NpyDataType::Int64, + x if x == NPY_TYPES::NPY_UINT as i32 => NpyDataType::Uint32, + x if x == NPY_TYPES::NPY_ULONG as i32 => NpyDataType::Uint64, + x if x == NPY_TYPES::NPY_FLOAT as i32 => NpyDataType::Float32, + x if x == NPY_TYPES::NPY_DOUBLE as i32 => NpyDataType::Float64, + x if x == NPY_TYPES::NPY_CFLOAT as i32 => NpyDataType::Complex32, + x if x == NPY_TYPES::NPY_CDOUBLE as i32 => NpyDataType::Complex64, + _ => NpyDataType::Unsupported, + } + } +} + pub trait TypeNum: Clone { fn typenum_enum() -> NPY_TYPES; fn typenum() -> i32 { Self::typenum_enum() as i32 } + fn to_npy_data_type(self) -> NpyDataType; } macro_rules! impl_type_num { - ($t:ty, $npy_t:ident) => { + ($t:ty, $npy_t:ident, $npy_dat_t:ident) => { impl TypeNum for $t { fn typenum_enum() -> NPY_TYPES { NPY_TYPES::$npy_t } + fn to_npy_data_type(self) -> NpyDataType { + NpyDataType::$npy_dat_t + } } }; } // impl_type_num! -impl_type_num!(bool, NPY_BOOL); -impl_type_num!(i32, NPY_INT); -impl_type_num!(i64, NPY_LONG); -impl_type_num!(u32, NPY_UINT); -impl_type_num!(u64, NPY_ULONG); -impl_type_num!(f32, NPY_FLOAT); -impl_type_num!(f64, NPY_DOUBLE); -impl_type_num!(c32, NPY_CFLOAT); -impl_type_num!(c64, NPY_CDOUBLE); +impl_type_num!(bool, NPY_BOOL, Bool); +impl_type_num!(i32, NPY_INT, Int32); +impl_type_num!(i64, NPY_LONG, Int64); +impl_type_num!(u32, NPY_UINT, Uint32); +impl_type_num!(u64, NPY_ULONG, Uint64); +impl_type_num!(f32, NPY_FLOAT, Float32); +impl_type_num!(f64, NPY_DOUBLE, Float64); +impl_type_num!(c32, NPY_CFLOAT, Complex32); +impl_type_num!(c64, NPY_CDOUBLE, Complex64); diff --git a/tests/array.rs b/tests/array.rs index 853aaf9e1..aecda2ca9 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -4,6 +4,7 @@ extern crate pyo3; use ndarray::*; use numpy::*; +use pyo3::prelude::*; #[test] fn new() { @@ -11,7 +12,7 @@ fn new() { let np = PyArrayModule::import(gil.python()).unwrap(); let n = 3; let m = 5; - let arr = PyArray::new::(gil.python(), &np, &[n, m]); + let arr = PyArray::::new(gil.python(), &np, &[n, m]); assert!(arr.ndim() == 2); assert!(arr.dims() == [n, m]); assert!(arr.strides() == [m as isize * 8, 8]); @@ -23,12 +24,12 @@ fn zeros() { let np = PyArrayModule::import(gil.python()).unwrap(); let n = 3; let m = 5; - let arr = PyArray::zeros::(gil.python(), &np, &[n, m], NPY_CORDER); + let arr = PyArray::::zeros(gil.python(), &np, &[n, m], NPY_CORDER); assert!(arr.ndim() == 2); assert!(arr.dims() == [n, m]); assert!(arr.strides() == [m as isize * 8, 8]); - let arr = PyArray::zeros::(gil.python(), &np, &[n, m], NPY_FORTRANORDER); + let arr = PyArray::::zeros(gil.python(), &np, &[n, m], NPY_FORTRANORDER); assert!(arr.ndim() == 2); assert!(arr.dims() == [n, m]); assert!(arr.strides() == [8, n as isize * 8]); @@ -38,18 +39,18 @@ fn zeros() { fn arange() { let gil = pyo3::Python::acquire_gil(); let np = PyArrayModule::import(gil.python()).unwrap(); - let arr = PyArray::arange::(gil.python(), &np, 0.0, 1.0, 0.1); + let arr = PyArray::::arange(gil.python(), &np, 0.0, 1.0, 0.1); println!("ndim = {:?}", arr.ndim()); println!("dims = {:?}", arr.dims()); - println!("array = {:?}", arr.as_slice::().unwrap()); + println!("array = {:?}", arr.as_slice().unwrap()); } #[test] fn as_array() { let gil = pyo3::Python::acquire_gil(); let np = PyArrayModule::import(gil.python()).unwrap(); - let arr = PyArray::zeros::(gil.python(), &np, &[3, 2, 4], NPY_CORDER); - let a = arr.as_array::().unwrap(); + let arr = PyArray::::zeros(gil.python(), &np, &[3, 2, 4], NPY_CORDER); + let a = arr.as_array().unwrap(); assert_eq!(arr.shape(), a.shape()); assert_eq!( arr.strides().iter().map(|x| x / 8).collect::>(), @@ -57,15 +58,6 @@ fn as_array() { ); } -#[test] -#[should_panic] -fn as_array_panic() { - let gil = pyo3::Python::acquire_gil(); - let np = PyArrayModule::import(gil.python()).unwrap(); - let arr = PyArray::zeros::(gil.python(), &np, &[3, 2, 4], NPY_CORDER); - let _a = arr.as_array::().unwrap(); -} - #[test] fn into_pyarray_vec() { let gil = pyo3::Python::acquire_gil(); @@ -74,7 +66,7 @@ fn into_pyarray_vec() { let a = vec![1, 2, 3]; let arr = a.into_pyarray(gil.python(), &np); println!("arr.shape = {:?}", arr.shape()); - println!("arr = {:?}", arr.as_slice::().unwrap()); + println!("arr = {:?}", arr.as_slice().unwrap()); assert_eq!(arr.shape(), [3]); } @@ -101,7 +93,7 @@ fn iter_to_pyarray() { let np = PyArrayModule::import(gil.python()).unwrap(); let arr = (0..10).map(|x| x * x).to_pyarray(gil.python(), &np); println!("arr.shape = {:?}", arr.shape()); - println!("arr = {:?}", arr.as_slice::().unwrap()); + println!("arr = {:?}", arr.as_slice().unwrap()); assert_eq!(arr.shape(), [10]); } @@ -110,8 +102,8 @@ fn is_instance() { let gil = pyo3::Python::acquire_gil(); let py = gil.python(); let np = PyArrayModule::import(py).unwrap(); - let arr = PyArray::new::(gil.python(), &np, &[3, 5]); - assert!(py.is_instance::(&arr).unwrap()); + let arr = PyArray::::new(gil.python(), &np, &[3, 5]); + assert!(py.is_instance::, _>(&arr).unwrap()); assert!(!py.is_instance::(&arr).unwrap()); } @@ -120,9 +112,12 @@ fn from_vec2() { let vec2 = vec![vec![1, 2, 3]; 2]; let gil = pyo3::Python::acquire_gil(); let np = PyArrayModule::import(gil.python()).unwrap(); - let pyarray = PyArray::from_vec2::(gil.python(), &np, &vec2).unwrap(); - assert_eq!(pyarray.as_array::().unwrap(), array![[1, 2, 3], [1, 2, 3]].into_dyn()); - assert!(PyArray::from_vec2::(gil.python(), &np, &vec![vec![1], vec![2, 3]]).is_err()); + let pyarray = PyArray::from_vec2(gil.python(), &np, &vec2).unwrap(); + assert_eq!( + pyarray.as_array().unwrap(), + array![[1, 2, 3], [1, 2, 3]].into_dyn() + ); + assert!(PyArray::from_vec2(gil.python(), &np, &vec![vec![1], vec![2, 3]]).is_err()); } #[test] @@ -130,19 +125,48 @@ fn from_vec3() { let gil = pyo3::Python::acquire_gil(); let np = PyArrayModule::import(gil.python()).unwrap(); let vec3 = vec![vec![vec![1, 2]; 2]; 2]; - let pyarray = PyArray::from_vec3::(gil.python(), &np, &vec3).unwrap(); + let pyarray = PyArray::from_vec3(gil.python(), &np, &vec3).unwrap(); assert_eq!( - pyarray.as_array::().unwrap(), + pyarray.as_array().unwrap(), array![[[1, 2], [1, 2]], [[1, 2], [1, 2]]].into_dyn() ); } - #[test] fn from_small_array() { let gil = pyo3::Python::acquire_gil(); let np = PyArrayModule::import(gil.python()).unwrap(); let array: [i32; 5] = [1, 2, 3, 4, 5]; let pyarray = array.into_pyarray(gil.python(), &np); - assert_eq!(pyarray.as_slice::().unwrap(), &[1, 2, 3, 4, 5]); + assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]); } + +#[test] +fn from_eval() { + let gil = pyo3::Python::acquire_gil(); + let np = PyArrayModule::import(gil.python()).unwrap(); + let dict = PyDict::new(gil.python()); + dict.set_item("np", np.as_pymodule()).unwrap(); + let pyarray: &PyArray = gil + .python() + .eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None) + .unwrap() + .extract() + .unwrap(); + assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3]); +} + +#[test] +fn from_eval_fail() { + let gil = pyo3::Python::acquire_gil(); + let np = PyArrayModule::import(gil.python()).unwrap(); + let dict = PyDict::new(gil.python()); + dict.set_item("np", np.as_pymodule()).unwrap(); + let converted: Result<&PyArray, _> = gil + .python() + .eval("np.array([1, 2, 3], dtype='float64')", Some(&dict), None) + .unwrap() + .extract(); + assert!(converted.is_err()); +} +