Skip to content

Commit

Permalink
Merge pull request #330 from PyO3/recovering-polymorphism
Browse files Browse the repository at this point in the history
RFC: Extend simple exmaple to include a function with limited polymorphism based on enums and FromPyObject.
  • Loading branch information
adamreichold committed Jun 25, 2022
2 parents 477c9d4 + 5f79d24 commit 93a3aaf
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
53 changes: 51 additions & 2 deletions examples/simple/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, Zip};
use std::ops::Add;

use numpy::ndarray::{Array1, ArrayD, ArrayView1, ArrayViewD, ArrayViewMutD, Zip};
use numpy::{
datetime::{units, Timedelta},
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArray1, PyReadonlyArrayDyn,
Expand All @@ -7,7 +9,7 @@ use numpy::{
use pyo3::{
pymodule,
types::{PyDict, PyModule},
PyResult, Python,
FromPyObject, PyAny, PyResult, Python,
};

#[pymodule]
Expand All @@ -27,6 +29,11 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
x.map(|c| c.conj())
}

// example using generics
fn generic_add<T: Copy + Add<Output = T>>(x: ArrayView1<T>, y: ArrayView1<T>) -> Array1<T> {
&x + &y
}

// wrapper of `axpy`
#[pyfn(m)]
#[pyo3(name = "axpy")]
Expand Down Expand Up @@ -84,5 +91,47 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
.apply(|x, y| *x = (i64::from(*x) + 60 * i64::from(*y)).into());
}

// This crate follows a strongly-typed approach to wrapping NumPy arrays
// while Python API are often expected to work with multiple element types.
//
// That kind of limited polymorphis can be recovered by accepting an enumerated type
// covering the supported element types and dispatching into a generic implementation.
#[derive(FromPyObject)]
enum SupportedArray<'py> {
F64(&'py PyArray1<f64>),
I64(&'py PyArray1<i64>),
}

#[pyfn(m)]
fn polymorphic_add<'py>(
x: SupportedArray<'py>,
y: SupportedArray<'py>,
) -> PyResult<&'py PyAny> {
match (x, y) {
(SupportedArray::F64(x), SupportedArray::F64(y)) => Ok(generic_add(
x.readonly().as_array(),
y.readonly().as_array(),
)
.into_pyarray(x.py())
.into()),
(SupportedArray::I64(x), SupportedArray::I64(y)) => Ok(generic_add(
x.readonly().as_array(),
y.readonly().as_array(),
)
.into_pyarray(x.py())
.into()),
(SupportedArray::F64(x), SupportedArray::I64(y))
| (SupportedArray::I64(y), SupportedArray::F64(x)) => {
let y = y.cast::<f64>(false)?;

Ok(
generic_add(x.readonly().as_array(), y.readonly().as_array())
.into_pyarray(x.py())
.into(),
)
}
}
}

Ok(())
}
19 changes: 18 additions & 1 deletion examples/simple/tests/test_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from rust_ext import axpy, conj, mult, extract, add_minutes_to_seconds
from rust_ext import axpy, conj, mult, extract, add_minutes_to_seconds, polymorphic_add


def test_axpy():
Expand Down Expand Up @@ -33,3 +33,20 @@ def test_add_minutes_to_seconds():
add_minutes_to_seconds(x, y)

assert np.all(x == np.array([70, 140, 210], dtype="timedelta64[s]"))


def test_polymorphic_add():
x = np.array([1.0, 2.0, 3.0], dtype=np.double)
y = np.array([3.0, 3.0, 3.0], dtype=np.double)
z = polymorphic_add(x, y)
np.testing.assert_array_almost_equal(z, np.array([4.0, 5.0, 6.0], dtype=np.double))

x = np.array([1, 2, 3], dtype=np.int64)
y = np.array([3, 3, 3], dtype=np.int64)
z = polymorphic_add(x, y)
assert np.all(z == np.array([4, 5, 6], dtype=np.int64))

x = np.array([1.0, 2.0, 3.0], dtype=np.double)
y = np.array([3, 3, 3], dtype=np.int64)
z = polymorphic_add(x, y)
np.testing.assert_array_almost_equal(z, np.array([4.0, 5.0, 6.0], dtype=np.double))

0 comments on commit 93a3aaf

Please sign in to comment.