Skip to content

Commit

Permalink
Merge pull request #5 from kornia/feat/numpy
Browse files Browse the repository at this point in the history
implement __dlpack__, __dlpack_device__, for numpy conversions
  • Loading branch information
edgarriba committed Mar 13, 2022
2 parents a5fe268 + c301d30 commit d00ee61
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 34 deletions.
12 changes: 6 additions & 6 deletions README.md
Expand Up @@ -24,15 +24,15 @@ The visualisation API is based on `vviz`: https://github.com/strasdat/vviz
import kornia_rs as K
from kornia_rs import Tensor as cvTensor

img_path: Path = DATA_DIR / "dog.jpeg"
cv_tensor: cvTensor = K.read_image_jpeg(str(img_path.absolute()))
cv_tensor: cvTensor = K.read_image_jpeg("dog.jpeg")
assert cv_tensor.shape == [195, 258, 3]

# convert to dlpack to import to torch
# NOTE: later we will support to numpy, jax and mxnet.
dlpack = K.cvtensor_to_dlpack(cv_tensor)
th_tensor = torch.utils.dlpack.from_dlpack(dlpack)
# convert to dlpack to import to torch and numpy
# NOTE: later we will support to jax and mxnet.
th_tensor = torch.utils.dlpack.from_dlpack(cv_tensor)
np_tensor = np._from_dlpack(cv_tensor)
assert th_tensor.shape == (195, 258, 3)
assert np_tensor.shape == (195, 258, 3)
```

## TODO
Expand Down
25 changes: 3 additions & 22 deletions src/dlpack_py.rs
Expand Up @@ -24,17 +24,7 @@ unsafe extern "C" fn destructor(o: *mut pyo3::ffi::PyObject) {
// println!("Delete by Python");
}

unsafe extern "C" fn deleter(x: *mut dlpack::DLManagedTensor) {
// println!("DLManagedTensor deleter");

let ctx = (*x).manager_ctx as *mut cv::Tensor;
ctx.drop_in_place();
(*x).dl_tensor.shape.drop_in_place();
(*x).dl_tensor.strides.drop_in_place();
x.drop_in_place();
}

fn cvtensor_to_dltensor(x: &cv::Tensor) -> dlpack::DLTensor {
pub fn cvtensor_to_dltensor(x: &cv::Tensor) -> dlpack::DLTensor {
dlpack::DLTensor {
data: x.data.as_ptr() as *mut c_void,
device: dlpack::DLDevice {
Expand All @@ -54,17 +44,8 @@ fn cvtensor_to_dltensor(x: &cv::Tensor) -> dlpack::DLTensor {
}

#[pyfunction]
pub fn cvtensor_to_dlpack(x: cv::Tensor) -> PyResult<*mut pyo3::ffi::PyObject> {
let tensor_bx = Box::new(x);
let dl_tensor = cvtensor_to_dltensor(&tensor_bx);

// create dlpack managed tensor
let dlm_tensor = dlpack::DLManagedTensor {
dl_tensor,
manager_ctx: Box::into_raw(tensor_bx) as *mut c_void,
deleter: Some(deleter),
};

pub fn cvtensor_to_dlpack(x: &cv::Tensor) -> PyResult<*mut pyo3::ffi::PyObject> {
let dlm_tensor: dlpack::DLManagedTensor = x.to_dlpack();
let dlm_tensor_bx = Box::new(dlm_tensor);

let name = CString::new("dltensor").unwrap();
Expand Down
41 changes: 41 additions & 0 deletions src/tensor.rs
@@ -1,6 +1,20 @@
pub mod cv {

use crate::dlpack;
use crate::dlpack_py::{cvtensor_to_dlpack, cvtensor_to_dltensor};
use pyo3::prelude::*;
use std::ffi::c_void;

// in our case we don not want to delete the data
unsafe extern "C" fn deleter(_: *mut dlpack::DLManagedTensor) {
// println!("DLManagedTensor deleter");

//let ctx = (*x).manager_ctx as *mut Tensor;
//ctx.drop_in_place();
//(*x).dl_tensor.shape.drop_in_place();
//(*x).dl_tensor.strides.drop_in_place();
//x.drop_in_place();
}

fn get_strides_from_shape(shape: &[i64]) -> Vec<i64> {
let mut strides = vec![0i64; shape.len()];
Expand Down Expand Up @@ -37,6 +51,33 @@ pub mod cv {
strides,
}
}

#[pyo3(name = "__dlpack__")]
pub fn to_dlpack_py(&self) -> PyResult<*mut pyo3::ffi::PyObject> {
cvtensor_to_dlpack(self)
}

#[pyo3(name = "__dlpack_device__")]
pub fn to_dlpack_device_py(&self) -> (u32, i32) {
let tensor_bx = Box::new(self);
let dl_tensor = cvtensor_to_dltensor(&tensor_bx);
(dl_tensor.device.device_type, dl_tensor.device.device_id)
}
}

impl Tensor {
pub fn to_dlpack(&self) -> dlpack::DLManagedTensor {
let tensor_bx = Box::new(self);
let dl_tensor = cvtensor_to_dltensor(&tensor_bx);

// create dlpack managed tensor

dlpack::DLManagedTensor {
dl_tensor,
manager_ctx: Box::into_raw(tensor_bx) as *mut c_void,
deleter: Some(deleter),
}
}
}
} // namespace cv

Expand Down
12 changes: 7 additions & 5 deletions test/test_io.py
Expand Up @@ -20,11 +20,13 @@ def test_read_image_jpeg():
th_tensor = torch.utils.dlpack.from_dlpack(dlpack)
assert th_tensor.shape == (195, 258, 3)

# TODO: needs to be fixed
# test __dlpack__()
torch.testing.assert_close(
th_tensor, torch.utils.dlpack.from_dlpack(cv_tensor))

# convert to dlpack to import to numpy
#dlpack = K.cvtensor_to_dlpack(cv_tensor)
#np_array = np._from_dlpack(dlpack)
#assert np_array.shape == (195, 258, 3)
np_array = np._from_dlpack(cv_tensor)
assert np_array.shape == (195, 258, 3)

def test_read_image_rs():
# load an image with image-rs
Expand All @@ -35,4 +37,4 @@ def test_read_image_rs():
# convert to dlpack to import to torch
dlpack = K.cvtensor_to_dlpack(cv_tensor)
th_tensor = torch.utils.dlpack.from_dlpack(dlpack)
assert th_tensor.shape == (195, 258, 3)
assert th_tensor.shape == (195, 258, 3)
18 changes: 17 additions & 1 deletion test/test_tensor.py
Expand Up @@ -2,6 +2,7 @@
from kornia_rs import Tensor as cvTensor

import torch
import numpy as np

def test_smoke():
# dumy test
Expand All @@ -12,7 +13,22 @@ def test_smoke():
assert len(data) == len(cv_tensor.data)
assert cv_tensor.strides == [6, 3, 1]

# to dlpack / torch
def test_conversions():
H, W, C = 2, 2, 3
data = [i for i in range(H * W * C)]
cv_tensor = cvTensor([H, W, C], data)

# to dlpack / torch / numpy
dlpack = K.cvtensor_to_dlpack(cv_tensor)
th_tensor = torch.utils.dlpack.from_dlpack(dlpack)
assert [x for x in th_tensor.shape] == cv_tensor.shape

def test_conversions2():
H, W, C = 2, 2, 3
data = [i for i in range(H * W * C)]
cv_tensor = cvTensor([H, W, C], data)

# to dlpack / torch / numpy
th_tensor = torch.utils.dlpack.from_dlpack(cv_tensor)
np_array = np._from_dlpack(cv_tensor)
np.testing.assert_array_equal(np_array, th_tensor.numpy())

0 comments on commit d00ee61

Please sign in to comment.