diff --git a/README.md b/README.md index 95c063e3..efb20a42 100644 --- a/README.md +++ b/README.md @@ -24,15 +24,15 @@ The visualisation API is based on `vviz`: https://github.com/strasdat/vviz import kornia_rs as K from kornia_rs import Tensor as cvTensor - img_path: Path = DATA_DIR / "dog.jpeg" - cv_tensor: cvTensor = K.read_image_jpeg(str(img_path.absolute())) + cv_tensor: cvTensor = K.read_image_jpeg("dog.jpeg") assert cv_tensor.shape == [195, 258, 3] - # convert to dlpack to import to torch - # NOTE: later we will support to numpy, jax and mxnet. - dlpack = K.cvtensor_to_dlpack(cv_tensor) - th_tensor = torch.utils.dlpack.from_dlpack(dlpack) + # convert to dlpack to import to torch and numpy + # NOTE: later we will support to jax and mxnet. + th_tensor = torch.utils.dlpack.from_dlpack(cv_tensor) + np_tensor = np._from_dlpack(cv_tensor) assert th_tensor.shape == (195, 258, 3) + assert np_tensor.shape == (195, 258, 3) ``` ## TODO diff --git a/src/dlpack_py.rs b/src/dlpack_py.rs index db9bc5cd..73e9dba5 100644 --- a/src/dlpack_py.rs +++ b/src/dlpack_py.rs @@ -24,17 +24,7 @@ unsafe extern "C" fn destructor(o: *mut pyo3::ffi::PyObject) { // println!("Delete by Python"); } -unsafe extern "C" fn deleter(x: *mut dlpack::DLManagedTensor) { - // println!("DLManagedTensor deleter"); - - let ctx = (*x).manager_ctx as *mut cv::Tensor; - ctx.drop_in_place(); - (*x).dl_tensor.shape.drop_in_place(); - (*x).dl_tensor.strides.drop_in_place(); - x.drop_in_place(); -} - -fn cvtensor_to_dltensor(x: &cv::Tensor) -> dlpack::DLTensor { +pub fn cvtensor_to_dltensor(x: &cv::Tensor) -> dlpack::DLTensor { dlpack::DLTensor { data: x.data.as_ptr() as *mut c_void, device: dlpack::DLDevice { @@ -54,17 +44,8 @@ fn cvtensor_to_dltensor(x: &cv::Tensor) -> dlpack::DLTensor { } #[pyfunction] -pub fn cvtensor_to_dlpack(x: cv::Tensor) -> PyResult<*mut pyo3::ffi::PyObject> { - let tensor_bx = Box::new(x); - let dl_tensor = cvtensor_to_dltensor(&tensor_bx); - - // create dlpack managed tensor - let dlm_tensor = dlpack::DLManagedTensor { - dl_tensor, - manager_ctx: Box::into_raw(tensor_bx) as *mut c_void, - deleter: Some(deleter), - }; - +pub fn cvtensor_to_dlpack(x: &cv::Tensor) -> PyResult<*mut pyo3::ffi::PyObject> { + let dlm_tensor: dlpack::DLManagedTensor = x.to_dlpack(); let dlm_tensor_bx = Box::new(dlm_tensor); let name = CString::new("dltensor").unwrap(); diff --git a/src/tensor.rs b/src/tensor.rs index 001c3195..bf2aad58 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,6 +1,20 @@ pub mod cv { + use crate::dlpack; + use crate::dlpack_py::{cvtensor_to_dlpack, cvtensor_to_dltensor}; use pyo3::prelude::*; + use std::ffi::c_void; + + // in our case we don not want to delete the data + unsafe extern "C" fn deleter(_: *mut dlpack::DLManagedTensor) { + // println!("DLManagedTensor deleter"); + + //let ctx = (*x).manager_ctx as *mut Tensor; + //ctx.drop_in_place(); + //(*x).dl_tensor.shape.drop_in_place(); + //(*x).dl_tensor.strides.drop_in_place(); + //x.drop_in_place(); + } fn get_strides_from_shape(shape: &[i64]) -> Vec { let mut strides = vec![0i64; shape.len()]; @@ -37,6 +51,33 @@ pub mod cv { strides, } } + + #[pyo3(name = "__dlpack__")] + pub fn to_dlpack_py(&self) -> PyResult<*mut pyo3::ffi::PyObject> { + cvtensor_to_dlpack(self) + } + + #[pyo3(name = "__dlpack_device__")] + pub fn to_dlpack_device_py(&self) -> (u32, i32) { + let tensor_bx = Box::new(self); + let dl_tensor = cvtensor_to_dltensor(&tensor_bx); + (dl_tensor.device.device_type, dl_tensor.device.device_id) + } + } + + impl Tensor { + pub fn to_dlpack(&self) -> dlpack::DLManagedTensor { + let tensor_bx = Box::new(self); + let dl_tensor = cvtensor_to_dltensor(&tensor_bx); + + // create dlpack managed tensor + + dlpack::DLManagedTensor { + dl_tensor, + manager_ctx: Box::into_raw(tensor_bx) as *mut c_void, + deleter: Some(deleter), + } + } } } // namespace cv diff --git a/test/test_io.py b/test/test_io.py index 8c66f91d..1121918d 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -20,11 +20,13 @@ def test_read_image_jpeg(): th_tensor = torch.utils.dlpack.from_dlpack(dlpack) assert th_tensor.shape == (195, 258, 3) - # TODO: needs to be fixed + # test __dlpack__() + torch.testing.assert_close( + th_tensor, torch.utils.dlpack.from_dlpack(cv_tensor)) + # convert to dlpack to import to numpy - #dlpack = K.cvtensor_to_dlpack(cv_tensor) - #np_array = np._from_dlpack(dlpack) - #assert np_array.shape == (195, 258, 3) + np_array = np._from_dlpack(cv_tensor) + assert np_array.shape == (195, 258, 3) def test_read_image_rs(): # load an image with image-rs @@ -35,4 +37,4 @@ def test_read_image_rs(): # convert to dlpack to import to torch dlpack = K.cvtensor_to_dlpack(cv_tensor) th_tensor = torch.utils.dlpack.from_dlpack(dlpack) - assert th_tensor.shape == (195, 258, 3) \ No newline at end of file + assert th_tensor.shape == (195, 258, 3) diff --git a/test/test_tensor.py b/test/test_tensor.py index 18e81be8..908461c3 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -2,6 +2,7 @@ from kornia_rs import Tensor as cvTensor import torch +import numpy as np def test_smoke(): # dumy test @@ -12,7 +13,22 @@ def test_smoke(): assert len(data) == len(cv_tensor.data) assert cv_tensor.strides == [6, 3, 1] - # to dlpack / torch +def test_conversions(): + H, W, C = 2, 2, 3 + data = [i for i in range(H * W * C)] + cv_tensor = cvTensor([H, W, C], data) + + # to dlpack / torch / numpy dlpack = K.cvtensor_to_dlpack(cv_tensor) th_tensor = torch.utils.dlpack.from_dlpack(dlpack) assert [x for x in th_tensor.shape] == cv_tensor.shape + +def test_conversions2(): + H, W, C = 2, 2, 3 + data = [i for i in range(H * W * C)] + cv_tensor = cvTensor([H, W, C], data) + + # to dlpack / torch / numpy + th_tensor = torch.utils.dlpack.from_dlpack(cv_tensor) + np_array = np._from_dlpack(cv_tensor) + np.testing.assert_array_equal(np_array, th_tensor.numpy()) \ No newline at end of file