-
Notifications
You must be signed in to change notification settings - Fork 7
/
test_io.py
40 lines (30 loc) · 1.2 KB
/
test_io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pathlib import Path
import kornia_rs as K
from kornia_rs import Tensor as cvTensor
import torch
import numpy as np
# Directory holding the image fixtures used by these tests (e.g. ``dog.jpeg``);
# it lives in a ``data`` folder next to this test module.
DATA_DIR = Path(__file__).parent.joinpath("data")
def test_read_image_jpeg():
    """Decode ``dog.jpeg`` with libjpeg-turbo and check DLPack interop.

    Verifies the decoded ``cvTensor`` shape, that it can be imported into
    torch both through an explicit DLPack capsule and through the
    ``__dlpack__`` protocol, and that numpy imports the same pixel data.
    """
    expected_shape = (195, 258, 3)

    # load an image with libjpeg-turbo
    img_path: Path = DATA_DIR / "dog.jpeg"
    cv_tensor: cvTensor = K.read_image_jpeg(str(img_path.absolute()))
    # cvTensor reports its shape as a list, not a tuple
    assert cv_tensor.shape == list(expected_shape)

    # convert to an explicit dlpack capsule to import to torch
    dlpack = K.cvtensor_to_dlpack(cv_tensor)
    th_tensor = torch.utils.dlpack.from_dlpack(dlpack)
    assert th_tensor.shape == expected_shape

    # test __dlpack__(): torch consumes the cvTensor object directly
    torch.testing.assert_close(
        th_tensor, torch.utils.dlpack.from_dlpack(cv_tensor))

    # import to numpy via the __dlpack__ protocol. NumPy 1.22 only exposed
    # the experimental private ``np._from_dlpack``; NumPy 1.23 renamed it to
    # the public ``np.from_dlpack``, so prefer the public name when present.
    from_dlpack = getattr(np, "from_dlpack", None) or np._from_dlpack
    np_array = from_dlpack(cv_tensor)
    assert np_array.shape == expected_shape
    # both imports view the same buffer, so the pixels must agree exactly
    np.testing.assert_array_equal(np_array, th_tensor.numpy())
def test_read_image_rs():
    """Decode ``dog.jpeg`` with image-rs and check DLPack interop.

    Verifies the decoded ``cvTensor`` shape, that it can be imported into
    torch both through an explicit DLPack capsule and through the
    ``__dlpack__`` protocol, and that numpy imports the same pixel data.
    """
    expected_shape = (195, 258, 3)

    # load an image with image-rs
    img_path: Path = DATA_DIR / "dog.jpeg"
    cv_tensor: cvTensor = K.read_image_rs(str(img_path.absolute()))
    # cvTensor reports its shape as a list, not a tuple
    assert cv_tensor.shape == list(expected_shape)

    # convert to an explicit dlpack capsule to import to torch
    dlpack = K.cvtensor_to_dlpack(cv_tensor)
    th_tensor = torch.utils.dlpack.from_dlpack(dlpack)
    assert th_tensor.shape == expected_shape

    # test __dlpack__(): torch consumes the cvTensor object directly
    torch.testing.assert_close(
        th_tensor, torch.utils.dlpack.from_dlpack(cv_tensor))

    # import to numpy via the __dlpack__ protocol; ``np.from_dlpack`` is the
    # public name since NumPy 1.23, ``np._from_dlpack`` the 1.22 private one.
    from_dlpack = getattr(np, "from_dlpack", None) or np._from_dlpack
    np_array = from_dlpack(cv_tensor)
    assert np_array.shape == expected_shape
    # both imports view the same buffer, so the pixels must agree exactly
    np.testing.assert_array_equal(np_array, th_tensor.numpy())