Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jul 29, 2018
1 parent 5b80925 commit 4e5a8fe
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 28 deletions.
23 changes: 15 additions & 8 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,20 @@ impl<'a, T> ::std::convert::From<&'a PyArray<T>> for &'a PyObjectRef {
}

impl<'a, T: TypeNum> FromPyObject<'a> for &'a PyArray<T> {
// here we do type-check twice
// 1. Checks if the object is PyArray
// 2. Checks if the data type of the array is T
fn extract(ob: &'a PyObjectRef) -> PyResult<Self> {
unsafe {
let array = unsafe {
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
return Err(PyDowncastError.into());
}
let array = &*(ob as *const PyObjectRef as *const PyArray<T>);
println!(">_<");
array
.type_check()
.map(|_| array)
.map_err(|err| err.into_pyerr("FromPyObject::extract failed"))
}
&*(ob as *const PyObjectRef as *const PyArray<T>)
};
array
.type_check()
.map(|_| array)
.map_err(|err| err.into_pyerr("FromPyObject::extract failed"))
}
}

Expand Down Expand Up @@ -140,6 +142,7 @@ impl<T> PyArray<T> {
self.shape()
}

/// Calcurates the total number of elements in the array.
pub fn len(&self) -> usize {
self.shape().iter().fold(1, |a, b| a * b)
}
Expand Down Expand Up @@ -296,6 +299,10 @@ impl<T: TypeNum> PyArray<T> {
}
}

pub fn data_type(&self) -> NpyDataType {
NpyDataType::from_i32(self.typenum())
}

fn type_check(&self) -> Result<(), ArrayCastError> {
let test = T::typenum();
let truth = self.typenum();
Expand Down
10 changes: 5 additions & 5 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use pyo3::*;
use std::error;
use std::fmt;
use types::NpyDataTypes;
use types::NpyDataType;

pub trait IntoPyErr {
fn into_pyerr(self, msg: &str) -> PyErr;
Expand All @@ -25,17 +25,17 @@ impl<T, E: IntoPyErr> IntoPyResult for Result<T, E> {
#[derive(Debug)]
pub enum ArrayCastError {
ToRust {
from: NpyDataTypes,
to: NpyDataTypes,
from: NpyDataType,
to: NpyDataType,
},
FromVec,
}

impl ArrayCastError {
pub fn to_rust(from: i32, to: i32) -> Self {
ArrayCastError::ToRust {
from: NpyDataTypes::from_i32(from),
to: NpyDataTypes::from_i32(to),
from: NpyDataType::from_i32(from),
to: NpyDataType::from_i32(to),
}
}
}
Expand Down
30 changes: 15 additions & 15 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::npyffi::NPY_TYPES;
///
/// This type is mainly for displaying error, and user don't have to use it directly.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NpyDataTypes {
pub enum NpyDataType {
Bool,
Int32,
Int64,
Expand All @@ -24,19 +24,19 @@ pub enum NpyDataTypes {
Unsupported,
}

impl NpyDataTypes {
impl NpyDataType {
pub(crate) fn from_i32(npy_t: i32) -> Self {
match npy_t {
x if x == NPY_TYPES::NPY_BOOL as i32 => NpyDataTypes::Bool,
x if x == NPY_TYPES::NPY_INT as i32 => NpyDataTypes::Int32,
x if x == NPY_TYPES::NPY_LONG as i32 => NpyDataTypes::Int64,
x if x == NPY_TYPES::NPY_UINT as i32 => NpyDataTypes::Uint32,
x if x == NPY_TYPES::NPY_ULONG as i32 => NpyDataTypes::Uint64,
x if x == NPY_TYPES::NPY_FLOAT as i32 => NpyDataTypes::Float32,
x if x == NPY_TYPES::NPY_DOUBLE as i32 => NpyDataTypes::Float64,
x if x == NPY_TYPES::NPY_CFLOAT as i32 => NpyDataTypes::Complex32,
x if x == NPY_TYPES::NPY_CDOUBLE as i32 => NpyDataTypes::Complex64,
_ => NpyDataTypes::Unsupported,
x if x == NPY_TYPES::NPY_BOOL as i32 => NpyDataType::Bool,
x if x == NPY_TYPES::NPY_INT as i32 => NpyDataType::Int32,
x if x == NPY_TYPES::NPY_LONG as i32 => NpyDataType::Int64,
x if x == NPY_TYPES::NPY_UINT as i32 => NpyDataType::Uint32,
x if x == NPY_TYPES::NPY_ULONG as i32 => NpyDataType::Uint64,
x if x == NPY_TYPES::NPY_FLOAT as i32 => NpyDataType::Float32,
x if x == NPY_TYPES::NPY_DOUBLE as i32 => NpyDataType::Float64,
x if x == NPY_TYPES::NPY_CFLOAT as i32 => NpyDataType::Complex32,
x if x == NPY_TYPES::NPY_CDOUBLE as i32 => NpyDataType::Complex64,
_ => NpyDataType::Unsupported,
}
}
}
Expand All @@ -46,7 +46,7 @@ pub trait TypeNum: Clone {
fn typenum() -> i32 {
Self::typenum_enum() as i32
}
fn to_npy_data_type(self) -> NpyDataTypes;
fn to_npy_data_type(self) -> NpyDataType;
}

macro_rules! impl_type_num {
Expand All @@ -55,8 +55,8 @@ macro_rules! impl_type_num {
fn typenum_enum() -> NPY_TYPES {
NPY_TYPES::$npy_t
}
fn to_npy_data_type(self) -> NpyDataTypes {
NpyDataTypes::$npy_dat_t
fn to_npy_data_type(self) -> NpyDataType {
NpyDataType::$npy_dat_t
}
}
};
Expand Down

0 comments on commit 4e5a8fe

Please sign in to comment.