From 75839517b0c8087a85a3b29695f58e1ee63bdcc0 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 4 Oct 2022 02:05:55 +0100 Subject: [PATCH 1/8] Remove test related to prototype --- test/builtin_dataset_mocks.py | 1568 --------------- test/prototype_common_utils.py | 529 ----- test/prototype_transforms_dispatcher_infos.py | 259 --- test/prototype_transforms_kernel_infos.py | 1594 --------------- test/test_prototype_datasets_builtin.py | 220 -- test/test_prototype_datasets_utils.py | 302 --- test/test_prototype_features.py | 113 -- test/test_prototype_models.py | 84 - test/test_prototype_transforms.py | 1780 ----------------- test/test_prototype_transforms_consistency.py | 1097 ---------- test/test_prototype_transforms_functional.py | 956 --------- test/test_prototype_transforms_utils.py | 83 - 12 files changed, 8585 deletions(-) delete mode 100644 test/builtin_dataset_mocks.py delete mode 100644 test/prototype_common_utils.py delete mode 100644 test/prototype_transforms_dispatcher_infos.py delete mode 100644 test/prototype_transforms_kernel_infos.py delete mode 100644 test/test_prototype_datasets_builtin.py delete mode 100644 test/test_prototype_datasets_utils.py delete mode 100644 test/test_prototype_features.py delete mode 100644 test/test_prototype_models.py delete mode 100644 test/test_prototype_transforms.py delete mode 100644 test/test_prototype_transforms_consistency.py delete mode 100644 test/test_prototype_transforms_functional.py delete mode 100644 test/test_prototype_transforms_utils.py diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py deleted file mode 100644 index 8c5484a2823..00000000000 --- a/test/builtin_dataset_mocks.py +++ /dev/null @@ -1,1568 +0,0 @@ -import bz2 -import collections.abc -import csv -import functools -import gzip -import io -import itertools -import json -import lzma -import pathlib -import pickle -import random -import shutil -import unittest.mock -import warnings -import xml.etree.ElementTree as ET -from collections import Counter, defaultdict - -import numpy as np -import pytest -import torch -from datasets_utils import combinations_grid, create_image_file, create_image_folder, make_tar, make_zip -from torch.nn.functional import one_hot -from torch.testing import make_tensor as _make_tensor -from torchvision.prototype import datasets - -make_tensor = functools.partial(_make_tensor, device="cpu") -make_scalar = functools.partial(make_tensor, ()) - - -__all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"] - - -class DatasetMock: - def __init__(self, name, *, mock_data_fn, configs): - # FIXME: error handling for unknown names - self.name = name - self.mock_data_fn = mock_data_fn - self.configs = configs - - def _parse_mock_info(self, mock_info): - if mock_info is None: - raise pytest.UsageError( - f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an " - f"integer indicating the number of samples for the current `config`." - ) - elif isinstance(mock_info, int): - mock_info = dict(num_samples=mock_info) - elif not isinstance(mock_info, dict): - raise pytest.UsageError( - f"The mock data function for dataset '{self.name}' returned a {type(mock_info)}. The returned object " - f"should be a dictionary containing at least the number of samples for the key `'num_samples'`. If no " - f"additional information is required for specific tests, the number of samples can also be returned as " - f"an integer." 
- ) - elif "num_samples" not in mock_info: - raise pytest.UsageError( - f"The dictionary returned by the mock data function for dataset '{self.name}' has to contain a " - f"`'num_samples'` entry indicating the number of samples." - ) - - return mock_info - - def load(self, config): - # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in - # test/test_prototype_builtin_datasets.py - root = pathlib.Path(datasets.home()) / self.name - # We cannot place the mock data upfront in `root`. Loading a dataset calls `OnlineResource.load`. In turn, - # this will only download **and** preprocess if the file is not present. In other words, if we already place - # the file in `root` before the resource is loaded, we are effectively skipping the preprocessing. - # To avoid that we first place the mock data in a temporary directory and patch the download logic to move it to - # `root` only when it is requested. - tmp_mock_data_folder = root / "__mock__" - tmp_mock_data_folder.mkdir(parents=True) - - mock_info = self._parse_mock_info(self.mock_data_fn(tmp_mock_data_folder, config)) - - def patched_download(resource, root, **kwargs): - src = tmp_mock_data_folder / resource.file_name - if not src.exists(): - raise pytest.UsageError( - f"Dataset '{self.name}' requires the file {resource.file_name} for {config}" - f"but it was not created by the mock data function." - ) - - dst = root / resource.file_name - shutil.move(str(src), str(root)) - - return dst - - with unittest.mock.patch( - "torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=patched_download - ): - dataset = datasets.load(self.name, **config) - - extra_files = list(tmp_mock_data_folder.glob("**/*")) - if extra_files: - raise pytest.UsageError( - ( - f"Dataset '{self.name}' created the following files for {config} in the mock data function, " - f"but they were not loaded:\n\n" - ) - + "\n".join(str(file.relative_to(tmp_mock_data_folder)) for file in extra_files) - ) - - tmp_mock_data_folder.rmdir() - - return dataset, mock_info - - -def config_id(name, config): - parts = [name] - for name, value in config.items(): - if isinstance(value, bool): - part = ("" if value else "no_") + name - else: - part = str(value) - parts.append(part) - return "-".join(parts) - - -def parametrize_dataset_mocks(*dataset_mocks, marks=None): - mocks = {} - for mock in dataset_mocks: - if isinstance(mock, DatasetMock): - mocks[mock.name] = mock - elif isinstance(mock, collections.abc.Mapping): - mocks.update(mock) - else: - raise pytest.UsageError( - f"The positional arguments passed to `parametrize_dataset_mocks` can either be a `DatasetMock`, " - f"a sequence of `DatasetMock`'s, or a mapping of names to `DatasetMock`'s, " - f"but got {mock} instead." 
- ) - dataset_mocks = mocks - - if marks is None: - marks = {} - elif not isinstance(marks, collections.abc.Mapping): - raise pytest.UsageError() - - return pytest.mark.parametrize( - ("dataset_mock", "config"), - [ - pytest.param(dataset_mock, config, id=config_id(name, config), marks=marks.get(name, ())) - for name, dataset_mock in dataset_mocks.items() - for config in dataset_mock.configs - ], - ) - - -DATASET_MOCKS = {} - - -def register_mock(name=None, *, configs): - def wrapper(mock_data_fn): - nonlocal name - if name is None: - name = mock_data_fn.__name__ - DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs) - - return mock_data_fn - - return wrapper - - -class MNISTMockData: - _DTYPES_ID = { - torch.uint8: 8, - torch.int8: 9, - torch.int16: 11, - torch.int32: 12, - torch.float32: 13, - torch.float64: 14, - } - - @classmethod - def _magic(cls, dtype, ndim): - return cls._DTYPES_ID[dtype] * 256 + ndim + 1 - - @staticmethod - def _encode(t): - return torch.tensor(t, dtype=torch.int32).numpy().tobytes()[::-1] - - @staticmethod - def _big_endian_dtype(dtype): - np_dtype = getattr(np, str(dtype).replace("torch.", ""))().dtype - return np.dtype(f">{np_dtype.kind}{np_dtype.itemsize}") - - @classmethod - def _create_binary_file(cls, root, filename, *, num_samples, shape, dtype, compressor, low=0, high): - with compressor(root / filename, "wb") as fh: - for meta in (cls._magic(dtype, len(shape)), num_samples, *shape): - fh.write(cls._encode(meta)) - - data = make_tensor((num_samples, *shape), dtype=dtype, low=low, high=high) - - fh.write(data.numpy().astype(cls._big_endian_dtype(dtype)).tobytes()) - - @classmethod - def generate( - cls, - root, - *, - num_categories, - num_samples=None, - images_file, - labels_file, - image_size=(28, 28), - image_dtype=torch.uint8, - label_size=(), - label_dtype=torch.uint8, - compressor=None, - ): - if num_samples is None: - num_samples = num_categories - if compressor is None: - compressor = gzip.open - - cls._create_binary_file( - root, - images_file, - num_samples=num_samples, - shape=image_size, - dtype=image_dtype, - compressor=compressor, - high=float("inf"), - ) - cls._create_binary_file( - root, - labels_file, - num_samples=num_samples, - shape=label_size, - dtype=label_dtype, - compressor=compressor, - high=num_categories, - ) - - return num_samples - - -def mnist(root, config): - prefix = "train" if config["split"] == "train" else "t10k" - return MNISTMockData.generate( - root, - num_categories=10, - images_file=f"{prefix}-images-idx3-ubyte.gz", - labels_file=f"{prefix}-labels-idx1-ubyte.gz", - ) - - -DATASET_MOCKS.update( - { - name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test"))) - for name in ["mnist", "fashionmnist", "kmnist"] - } -) - - -@register_mock( - configs=combinations_grid( - split=("train", "test"), - image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), - ) -) -def emnist(root, config): - num_samples_map = {} - file_names = set() - for split, image_set in itertools.product( - ("train", "test"), - ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), - ): - prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}" - images_file = f"{prefix}-images-idx3-ubyte.gz" - labels_file = f"{prefix}-labels-idx1-ubyte.gz" - file_names.update({images_file, labels_file}) - num_samples_map[(split, image_set)] = MNISTMockData.generate( - root, - # The image sets that merge some lower case letters in their respective upper case 
variant, still use dense - # labels in the data files. Thus, num_categories != len(categories) there. - num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62, - images_file=images_file, - labels_file=labels_file, - ) - - make_zip(root, "emnist-gzip.zip", *file_names) - - return num_samples_map[(config["split"], config["image_set"])] - - -@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist"))) -def qmnist(root, config): - num_categories = 10 - if config["split"] == "train": - num_samples = num_samples_gen = num_categories + 2 - prefix = "qmnist-train" - suffix = ".gz" - compressor = gzip.open - elif config["split"].startswith("test"): - # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create - # more than 10000 images for the dataset to not be empty. - num_samples_gen = 10001 - num_samples = { - "test": num_samples_gen, - "test10k": min(num_samples_gen, 10_000), - "test50k": num_samples_gen - 10_000, - }[config["split"]] - prefix = "qmnist-test" - suffix = ".gz" - compressor = gzip.open - else: # config["split"] == "nist" - num_samples = num_samples_gen = num_categories + 3 - prefix = "xnist" - suffix = ".xz" - compressor = lzma.open - - MNISTMockData.generate( - root, - num_categories=num_categories, - num_samples=num_samples_gen, - images_file=f"{prefix}-images-idx3-ubyte{suffix}", - labels_file=f"{prefix}-labels-idx2-int{suffix}", - label_size=(8,), - label_dtype=torch.int32, - compressor=compressor, - ) - return num_samples - - -class CIFARMockData: - NUM_PIXELS = 32 * 32 * 3 - - @classmethod - def _create_batch_file(cls, root, name, *, num_categories, labels_key, num_samples=1): - content = { - "data": make_tensor((num_samples, cls.NUM_PIXELS), dtype=torch.uint8).numpy(), - labels_key: torch.randint(0, num_categories, size=(num_samples,)).tolist(), - } - with open(pathlib.Path(root) / name, "wb") as fh: - pickle.dump(content, fh) - - @classmethod - def generate( - cls, - root, - name, - *, - folder, - train_files, - test_files, - num_categories, - labels_key, - ): - folder = root / folder - folder.mkdir() - files = (*train_files, *test_files) - for file in files: - cls._create_batch_file( - folder, - file, - num_categories=num_categories, - labels_key=labels_key, - ) - - make_tar(root, name, folder, compression="gz") - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def cifar10(root, config): - train_files = [f"data_batch_{idx}" for idx in range(1, 6)] - test_files = ["test_batch"] - - CIFARMockData.generate( - root=root, - name="cifar-10-python.tar.gz", - folder=pathlib.Path("cifar-10-batches-py"), - train_files=train_files, - test_files=test_files, - num_categories=10, - labels_key="labels", - ) - - return len(train_files if config["split"] == "train" else test_files) - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def cifar100(root, config): - train_files = ["train"] - test_files = ["test"] - - CIFARMockData.generate( - root=root, - name="cifar-100-python.tar.gz", - folder=pathlib.Path("cifar-100-python"), - train_files=train_files, - test_files=test_files, - num_categories=100, - labels_key="fine_labels", - ) - - return len(train_files if config["split"] == "train" else test_files) - - -@register_mock(configs=[dict()]) -def caltech101(root, config): - def create_ann_file(root, name): - import scipy.io - - box_coord = make_tensor((1, 4), dtype=torch.int32, low=0).numpy().astype(np.uint16) - obj_contour = make_tensor((2, 
int(torch.randint(3, 6, size=()))), dtype=torch.float64, low=0).numpy() - - scipy.io.savemat(str(pathlib.Path(root) / name), dict(box_coord=box_coord, obj_contour=obj_contour)) - - def create_ann_folder(root, name, file_name_fn, num_examples): - root = pathlib.Path(root) / name - root.mkdir(parents=True) - - for idx in range(num_examples): - create_ann_file(root, file_name_fn(idx)) - - images_root = root / "101_ObjectCategories" - anns_root = root / "Annotations" - - image_category_map = { - "Faces": "Faces_2", - "Faces_easy": "Faces_3", - "Motorbikes": "Motorbikes_16", - "airplanes": "Airplanes_Side_2", - } - - categories = ["Faces", "Faces_easy", "Motorbikes", "airplanes", "yin_yang"] - - num_images_per_category = 2 - for category in categories: - create_image_folder( - root=images_root, - name=category, - file_name_fn=lambda idx: f"image_{idx + 1:04d}.jpg", - num_examples=num_images_per_category, - ) - create_ann_folder( - root=anns_root, - name=image_category_map.get(category, category), - file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat", - num_examples=num_images_per_category, - ) - - (images_root / "BACKGROUND_Goodle").mkdir() - make_tar(root, f"{images_root.name}.tar.gz", images_root, compression="gz") - - make_tar(root, f"{anns_root.name}.tar", anns_root) - - return num_images_per_category * len(categories) - - -@register_mock(configs=[dict()]) -def caltech256(root, config): - dir = root / "256_ObjectCategories" - num_images_per_category = 2 - - categories = [ - (1, "ak47"), - (127, "laptop-101"), - (198, "spider"), - (257, "clutter"), - ] - - for category_idx, category in categories: - files = create_image_folder( - dir, - name=f"{category_idx:03d}.{category}", - file_name_fn=lambda image_idx: f"{category_idx:03d}_{image_idx + 1:04d}.jpg", - num_examples=num_images_per_category, - ) - if category == "spider": - open(files[0].parent / "RENAME2", "w").close() - - make_tar(root, f"{dir.name}.tar", dir) - - return num_images_per_category * len(categories) - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def imagenet(root, config): - from scipy.io import savemat - - info = datasets.info("imagenet") - - if config["split"] == "train": - num_samples = len(info["wnids"]) - archive_name = "ILSVRC2012_img_train.tar" - - files = [] - for wnid in info["wnids"]: - create_image_folder( - root=root, - name=wnid, - file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG", - num_examples=1, - ) - files.append(make_tar(root, f"{wnid}.tar")) - elif config["split"] == "val": - num_samples = 3 - archive_name = "ILSVRC2012_img_val.tar" - files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] - - devkit_root = root / "ILSVRC2012_devkit_t12" - data_root = devkit_root / "data" - data_root.mkdir(parents=True) - - with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: - for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist(): - file.write(f"{label}\n") - - num_children = 0 - synsets = [ - (idx, wnid, category, "", num_children, [], 0, 0) - for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1) - ] - num_children = 1 - synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) - with warnings.catch_warnings(): - # The warning is not for savemat, but rather for some internals savemet is using - warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) - savemat(data_root / "meta.mat", dict(synsets=synsets)) - - 
make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") - else: # config["split"] == "test" - num_samples = 5 - archive_name = "ILSVRC2012_img_test_v10102019.tar" - files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)] - - make_tar(root, archive_name, *files) - - return num_samples - - -class CocoMockData: - @classmethod - def _make_annotations_json( - cls, - root, - name, - *, - images_meta, - fn, - ): - num_anns_per_image = torch.randint(1, 5, (len(images_meta),)) - num_anns_total = int(num_anns_per_image.sum()) - ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)]) - - anns_meta = [] - for image_meta, num_anns in zip(images_meta, num_anns_per_image): - for _ in range(num_anns): - ann_id = int(next(ann_ids_iter)) - anns_meta.append(dict(fn(ann_id, image_meta), id=ann_id, image_id=image_meta["id"])) - anns_meta.sort(key=lambda ann: ann["id"]) - - with open(root / name, "w") as file: - json.dump(dict(images=images_meta, annotations=anns_meta), file) - - return num_anns_per_image - - @staticmethod - def _make_instances_data(ann_id, image_meta): - def make_rle_segmentation(): - height, width = image_meta["height"], image_meta["width"] - numel = height * width - counts = [] - while sum(counts) <= numel: - counts.append(int(torch.randint(5, 8, ()))) - if sum(counts) > numel: - counts[-1] -= sum(counts) - numel - return dict(counts=counts, size=[height, width]) - - return dict( - segmentation=make_rle_segmentation(), - bbox=make_tensor((4,), dtype=torch.float32, low=0).tolist(), - iscrowd=True, - area=float(make_scalar(dtype=torch.float32)), - category_id=int(make_scalar(dtype=torch.int64)), - ) - - @staticmethod - def _make_captions_data(ann_id, image_meta): - return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.") - - @classmethod - def _make_annotations(cls, root, name, *, images_meta): - num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64) - for annotations, fn in ( - ("instances", cls._make_instances_data), - ("captions", cls._make_captions_data), - ): - num_anns_per_image += cls._make_annotations_json( - root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn - ) - - return int(num_anns_per_image.sum()) - - @classmethod - def generate( - cls, - root, - *, - split, - year, - num_samples, - ): - annotations_dir = root / "annotations" - annotations_dir.mkdir() - - for split_ in ("train", "val"): - config_name = f"{split_}{year}" - - images_meta = [ - dict( - file_name=f"{idx:012d}.jpg", - id=idx, - width=width, - height=height, - ) - for idx, (height, width) in enumerate( - torch.randint(3, 11, size=(num_samples, 2), dtype=torch.int).tolist() - ) - ] - - if split_ == split: - create_image_folder( - root, - config_name, - file_name_fn=lambda idx: images_meta[idx]["file_name"], - num_examples=num_samples, - size=lambda idx: (3, images_meta[idx]["height"], images_meta[idx]["width"]), - ) - make_zip(root, f"{config_name}.zip") - - cls._make_annotations( - annotations_dir, - config_name, - images_meta=images_meta, - ) - - make_zip(root, f"annotations_trainval{year}.zip", annotations_dir) - - return num_samples - - -@register_mock( - configs=combinations_grid( - split=("train", "val"), - year=("2017", "2014"), - annotations=("instances", "captions", None), - ) -) -def coco(root, config): - return CocoMockData.generate(root, split=config["split"], year=config["year"], num_samples=5) - - -class SBDMockData: - _NUM_CATEGORIES = 20 - - @classmethod - def 
_make_split_files(cls, root_map): - ids_map = { - split: [f"2008_{idx:06d}" for idx in idcs] - for split, idcs in ( - ("train", [0, 1, 2]), - ("train_noval", [0, 2]), - ("val", [3]), - ) - } - - for split, ids in ids_map.items(): - with open(root_map[split] / f"{split}.txt", "w") as fh: - fh.writelines(f"{id}\n" for id in ids) - - return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()} - - @classmethod - def _make_anns_folder(cls, root, name, ids): - from scipy.io import savemat - - anns_folder = root / name - anns_folder.mkdir() - - sizes = torch.randint(1, 9, size=(len(ids), 2)).tolist() - for id, size in zip(ids, sizes): - savemat( - anns_folder / f"{id}.mat", - { - "GTcls": { - "Boundaries": cls._make_boundaries(size), - "Segmentation": cls._make_segmentation(size), - } - }, - ) - return sizes - - @classmethod - def _make_boundaries(cls, size): - from scipy.sparse import csc_matrix - - return [ - [csc_matrix(torch.randint(0, 2, size=size, dtype=torch.uint8).numpy())] for _ in range(cls._NUM_CATEGORIES) - ] - - @classmethod - def _make_segmentation(cls, size): - return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy() - - @classmethod - def generate(cls, root): - archive_folder = root / "benchmark_RELEASE" - dataset_folder = archive_folder / "dataset" - dataset_folder.mkdir(parents=True, exist_ok=True) - - ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root})) - sizes = cls._make_anns_folder(dataset_folder, "cls", ids) - create_image_folder( - dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx] - ) - - make_tar(root, "benchmark.tgz", archive_folder, compression="gz") - - return num_samples_map - - -@register_mock(configs=combinations_grid(split=("train", "val", "train_noval"))) -def sbd(root, config): - return SBDMockData.generate(root)[config["split"]] - - -@register_mock(configs=[dict()]) -def semeion(root, config): - num_samples = 3 - num_categories = 10 - - images = torch.rand(num_samples, 256) - labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories) - with open(root / "semeion.data", "w") as fh: - for image, one_hot_label in zip(images, labels): - image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image]) - labels_columns = " ".join([str(label.item()) for label in one_hot_label]) - fh.write(f"{image_columns} {labels_columns} \n") - - return num_samples - - -class VOCMockData: - _TRAIN_VAL_FILE_NAMES = { - "2007": "VOCtrainval_06-Nov-2007.tar", - "2008": "VOCtrainval_14-Jul-2008.tar", - "2009": "VOCtrainval_11-May-2009.tar", - "2010": "VOCtrainval_03-May-2010.tar", - "2011": "VOCtrainval_25-May-2011.tar", - "2012": "VOCtrainval_11-May-2012.tar", - } - _TEST_FILE_NAMES = { - "2007": "VOCtest_06-Nov-2007.tar", - } - - @classmethod - def _make_split_files(cls, root, *, year, trainval): - split_folder = root / "ImageSets" - - if trainval: - idcs_map = { - "train": [0, 1, 2], - "val": [3, 4], - } - idcs_map["trainval"] = [*idcs_map["train"], *idcs_map["val"]] - else: - idcs_map = { - "test": [5], - } - ids_map = {split: [f"{year}_{idx:06d}" for idx in idcs] for split, idcs in idcs_map.items()} - - for task_sub_folder in ("Main", "Segmentation"): - task_folder = split_folder / task_sub_folder - task_folder.mkdir(parents=True, exist_ok=True) - for split, ids in ids_map.items(): - with open(task_folder / f"{split}.txt", "w") as fh: - fh.writelines(f"{id}\n" 
for id in ids) - - return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()} - - @classmethod - def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples): - folder = root / name - folder.mkdir(parents=True, exist_ok=True) - - for idx in range(num_examples): - cls._make_detection_ann_file(folder, file_name_fn(idx)) - - @classmethod - def _make_detection_ann_file(cls, root, name): - def add_child(parent, name, text=None): - child = ET.SubElement(parent, name) - child.text = str(text) - return child - - def add_name(obj, name="dog"): - add_child(obj, "name", name) - - def add_size(obj): - obj = add_child(obj, "size") - size = {"width": 0, "height": 0, "depth": 3} - for name, text in size.items(): - add_child(obj, name, text) - - def add_bndbox(obj): - obj = add_child(obj, "bndbox") - bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4} - for name, text in bndbox.items(): - add_child(obj, name, text) - - annotation = ET.Element("annotation") - add_size(annotation) - obj = add_child(annotation, "object") - add_name(obj) - add_bndbox(obj) - - with open(root / name, "wb") as fh: - fh.write(ET.tostring(annotation)) - - @classmethod - def generate(cls, root, *, year, trainval): - archive_folder = root - if year == "2011": - archive_folder = root / "TrainVal" - data_folder = archive_folder / "VOCdevkit" - else: - archive_folder = data_folder = root / "VOCdevkit" - data_folder = data_folder / f"VOC{year}" - data_folder.mkdir(parents=True, exist_ok=True) - - ids, num_samples_map = cls._make_split_files(data_folder, year=year, trainval=trainval) - for make_folder_fn, name, suffix in [ - (create_image_folder, "JPEGImages", ".jpg"), - (create_image_folder, "SegmentationClass", ".png"), - (cls._make_detection_anns_folder, "Annotations", ".xml"), - ]: - make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids)) - make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], archive_folder) - - return num_samples_map - - -@register_mock( - configs=[ - *combinations_grid( - split=("train", "val", "trainval"), - year=("2007", "2008", "2009", "2010", "2011", "2012"), - task=("detection", "segmentation"), - ), - *combinations_grid( - split=("test",), - year=("2007",), - task=("detection", "segmentation"), - ), - ], -) -def voc(root, config): - trainval = config["split"] != "test" - return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]] - - -class CelebAMockData: - @classmethod - def _make_ann_file(cls, root, name, data, *, field_names=None): - with open(root / name, "w") as file: - if field_names: - file.write(f"{len(data)}\r\n") - file.write(" ".join(field_names) + "\r\n") - file.writelines(" ".join(str(item) for item in row) + "\r\n" for row in data) - - _SPLIT_TO_IDX = { - "train": 0, - "val": 1, - "test": 2, - } - - @classmethod - def _make_split_file(cls, root): - num_samples_map = {"train": 4, "val": 3, "test": 2} - - data = [ - (f"{idx:06d}.jpg", cls._SPLIT_TO_IDX[split]) - for split, num_samples in num_samples_map.items() - for idx in range(num_samples) - ] - cls._make_ann_file(root, "list_eval_partition.txt", data) - - image_file_names, _ = zip(*data) - return image_file_names, num_samples_map - - @classmethod - def _make_identity_file(cls, root, image_file_names): - cls._make_ann_file( - root, "identity_CelebA.txt", [(name, int(make_scalar(low=1, dtype=torch.int))) for name in image_file_names] - ) - - @classmethod - def 
_make_attributes_file(cls, root, image_file_names): - field_names = ("5_o_Clock_Shadow", "Young") - data = [ - [name, *[" 1" if attr else "-1" for attr in make_tensor((len(field_names),), dtype=torch.bool)]] - for name in image_file_names - ] - cls._make_ann_file(root, "list_attr_celeba.txt", data, field_names=(*field_names, "")) - - @classmethod - def _make_bounding_boxes_file(cls, root, image_file_names): - field_names = ("image_id", "x_1", "y_1", "width", "height") - data = [ - [f"{name} ", *[f"{coord:3d}" for coord in make_tensor((4,), low=0, dtype=torch.int).tolist()]] - for name in image_file_names - ] - cls._make_ann_file(root, "list_bbox_celeba.txt", data, field_names=field_names) - - @classmethod - def _make_landmarks_file(cls, root, image_file_names): - field_names = ("lefteye_x", "lefteye_y", "rightmouth_x", "rightmouth_y") - data = [ - [ - name, - *[ - f"{coord:4d}" if idx else coord - for idx, coord in enumerate(make_tensor((len(field_names),), low=0, dtype=torch.int).tolist()) - ], - ] - for name in image_file_names - ] - cls._make_ann_file(root, "list_landmarks_align_celeba.txt", data, field_names=field_names) - - @classmethod - def generate(cls, root): - image_file_names, num_samples_map = cls._make_split_file(root) - - image_files = create_image_folder( - root, "img_align_celeba", file_name_fn=lambda idx: image_file_names[idx], num_examples=len(image_file_names) - ) - make_zip(root, image_files[0].parent.with_suffix(".zip").name) - - for make_ann_file_fn in ( - cls._make_identity_file, - cls._make_attributes_file, - cls._make_bounding_boxes_file, - cls._make_landmarks_file, - ): - make_ann_file_fn(root, image_file_names) - - return num_samples_map - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def celeba(root, config): - return CelebAMockData.generate(root)[config["split"]] - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def country211(root, config): - split_folder = pathlib.Path(root, "country211", "valid" if config["split"] == "val" else config["split"]) - split_folder.mkdir(parents=True, exist_ok=True) - - num_examples = { - "train": 3, - "val": 4, - "test": 5, - }[config["split"]] - - classes = ("AD", "BS", "GR") - for cls in classes: - create_image_folder( - split_folder, - name=cls, - file_name_fn=lambda idx: f"{idx}.jpg", - num_examples=num_examples, - ) - make_tar(root, f"{split_folder.parent.name}.tgz", split_folder.parent, compression="gz") - return num_examples * len(classes) - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def food101(root, config): - data_folder = root / "food-101" - - num_images_per_class = 3 - image_folder = data_folder / "images" - categories = ["apple_pie", "baby_back_ribs", "waffles"] - image_ids = [] - for category in categories: - image_files = create_image_folder( - image_folder, - category, - file_name_fn=lambda idx: f"{idx:04d}.jpg", - num_examples=num_images_per_class, - ) - image_ids.extend(path.relative_to(path.parents[1]).with_suffix("").as_posix() for path in image_files) - - meta_folder = data_folder / "meta" - meta_folder.mkdir() - - with open(meta_folder / "classes.txt", "w") as file: - for category in categories: - file.write(f"{category}\n") - - splits = ["train", "test"] - num_samples_map = {} - for offset, split in enumerate(splits): - image_ids_in_split = image_ids[offset :: len(splits)] - num_samples_map[split] = len(image_ids_in_split) - with open(meta_folder / f"{split}.txt", "w") as file: - for image_id in image_ids_in_split: - 
file.write(f"{image_id}\n") - - make_tar(root, f"{data_folder.name}.tar.gz", compression="gz") - - return num_samples_map[config["split"]] - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10))) -def dtd(root, config): - data_folder = root / "dtd" - - num_images_per_class = 3 - image_folder = data_folder / "images" - categories = {"banded", "marbled", "zigzagged"} - image_ids_per_category = { - category: [ - str(path.relative_to(path.parents[1]).as_posix()) - for path in create_image_folder( - image_folder, - category, - file_name_fn=lambda idx: f"{category}_{idx:04d}.jpg", - num_examples=num_images_per_class, - ) - ] - for category in categories - } - - meta_folder = data_folder / "labels" - meta_folder.mkdir() - - with open(meta_folder / "labels_joint_anno.txt", "w") as file: - for cls, image_ids in image_ids_per_category.items(): - for image_id in image_ids: - joint_categories = random.choices( - list(categories - {cls}), k=int(torch.randint(len(categories) - 1, ())) - ) - file.write(" ".join([image_id, *sorted([cls, *joint_categories])]) + "\n") - - image_ids = list(itertools.chain(*image_ids_per_category.values())) - splits = ("train", "val", "test") - num_samples_map = {} - for fold in range(1, 11): - random.shuffle(image_ids) - for offset, split in enumerate(splits): - image_ids_in_config = image_ids[offset :: len(splits)] - with open(meta_folder / f"{split}{fold}.txt", "w") as file: - file.write("\n".join(image_ids_in_config) + "\n") - - num_samples_map[(split, fold)] = len(image_ids_in_config) - - make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz") - - return num_samples_map[config["split"], config["fold"]] - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def fer2013(root, config): - split = config["split"] - num_samples = 5 if split == "train" else 3 - - path = root / f"{split}.csv" - with open(path, "w", newline="") as file: - field_names = ["emotion"] if split == "train" else [] - field_names.append("pixels") - - file.write(",".join(field_names) + "\n") - - writer = csv.DictWriter(file, fieldnames=field_names, quotechar='"', quoting=csv.QUOTE_NONNUMERIC) - for _ in range(num_samples): - rowdict = { - "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)]) - } - if split == "train": - rowdict["emotion"] = int(torch.randint(7, ())) - writer.writerow(rowdict) - - make_zip(root, f"{path.name}.zip", path) - - return num_samples - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def gtsrb(root, config): - num_examples_per_class = 5 if config["split"] == "train" else 3 - classes = ("00000", "00042", "00012") - num_examples = num_examples_per_class * len(classes) - - csv_columns = ["Filename", "Width", "Height", "Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2", "ClassId"] - - def _make_ann_file(path, num_examples, class_idx): - if class_idx == "random": - class_idx = torch.randint(1, len(classes) + 1, size=(1,)).item() - - with open(path, "w") as csv_file: - writer = csv.DictWriter(csv_file, fieldnames=csv_columns, delimiter=";") - writer.writeheader() - for image_idx in range(num_examples): - writer.writerow( - { - "Filename": f"{image_idx:05d}.ppm", - "Width": torch.randint(1, 100, size=()).item(), - "Height": torch.randint(1, 100, size=()).item(), - "Roi.X1": torch.randint(1, 100, size=()).item(), - "Roi.Y1": torch.randint(1, 100, size=()).item(), - "Roi.X2": torch.randint(1, 100, size=()).item(), - "Roi.Y2": torch.randint(1, 100, size=()).item(), - 
"ClassId": class_idx, - } - ) - - archive_folder = root / "GTSRB" - - if config["split"] == "train": - train_folder = archive_folder / "Training" - train_folder.mkdir(parents=True) - - for class_idx in classes: - create_image_folder( - train_folder, - name=class_idx, - file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm", - num_examples=num_examples_per_class, - ) - _make_ann_file( - path=train_folder / class_idx / f"GT-{class_idx}.csv", - num_examples=num_examples_per_class, - class_idx=int(class_idx), - ) - make_zip(root, "GTSRB-Training_fixed.zip", archive_folder) - else: - test_folder = archive_folder / "Final_Test" - test_folder.mkdir(parents=True) - - create_image_folder( - test_folder, - name="Images", - file_name_fn=lambda image_idx: f"{image_idx:05d}.ppm", - num_examples=num_examples, - ) - - make_zip(root, "GTSRB_Final_Test_Images.zip", archive_folder) - - _make_ann_file( - path=root / "GT-final_test.csv", - num_examples=num_examples, - class_idx="random", - ) - - make_zip(root, "GTSRB_Final_Test_GT.zip", "GT-final_test.csv") - - return num_examples - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def clevr(root, config): - data_folder = root / "CLEVR_v1.0" - - num_samples_map = { - "train": 3, - "val": 2, - "test": 1, - } - - images_folder = data_folder / "images" - image_files = { - split: create_image_folder( - images_folder, - split, - file_name_fn=lambda idx: f"CLEVR_{split}_{idx:06d}.jpg", - num_examples=num_samples, - ) - for split, num_samples in num_samples_map.items() - } - - scenes_folder = data_folder / "scenes" - scenes_folder.mkdir() - for split in ["train", "val"]: - with open(scenes_folder / f"CLEVR_{split}_scenes.json", "w") as file: - json.dump( - { - "scenes": [ - { - "image_filename": image_file.name, - # We currently only return the number of objects in a scene. - # Thus, it is sufficient for now to only mock the number of elements. 
- "objects": [None] * int(torch.randint(1, 5, ())), - } - for image_file in image_files[split] - ] - }, - file, - ) - - make_zip(root, f"{data_folder.name}.zip", data_folder) - - return num_samples_map[config["split"]] - - -class OxfordIIITPetMockData: - @classmethod - def _meta_to_split_and_classification_ann(cls, meta, idx): - image_id = "_".join( - [ - *[(str.title if meta["species"] == "cat" else str.lower)(part) for part in meta["cls"].split()], - str(idx), - ] - ) - class_id = str(meta["label"] + 1) - species = "1" if meta["species"] == "cat" else "2" - breed_id = "-1" - return (image_id, class_id, species, breed_id) - - @classmethod - def generate(self, root): - classification_anns_meta = ( - dict(cls="Abyssinian", label=0, species="cat"), - dict(cls="Keeshond", label=18, species="dog"), - dict(cls="Yorkshire Terrier", label=36, species="dog"), - ) - split_and_classification_anns = [ - self._meta_to_split_and_classification_ann(meta, idx) - for meta, idx in itertools.product(classification_anns_meta, (1, 2, 10)) - ] - image_ids, *_ = zip(*split_and_classification_anns) - - image_files = create_image_folder( - root, "images", file_name_fn=lambda idx: f"{image_ids[idx]}.jpg", num_examples=len(image_ids) - ) - - anns_folder = root / "annotations" - anns_folder.mkdir() - random.shuffle(split_and_classification_anns) - splits = ("trainval", "test") - num_samples_map = {} - for offset, split in enumerate(splits): - split_and_classification_anns_in_split = split_and_classification_anns[offset :: len(splits)] - with open(anns_folder / f"{split}.txt", "w") as file: - writer = csv.writer(file, delimiter=" ") - for split_and_classification_ann in split_and_classification_anns_in_split: - writer.writerow(split_and_classification_ann) - - num_samples_map[split] = len(split_and_classification_anns_in_split) - - segmentation_files = create_image_folder( - anns_folder, "trimaps", file_name_fn=lambda idx: f"{image_ids[idx]}.png", num_examples=len(image_ids) - ) - - # The dataset has some rogue files - for path in image_files[:3]: - path.with_suffix(".mat").touch() - for path in segmentation_files: - path.with_name(f".{path.name}").touch() - - make_tar(root, "images.tar.gz", compression="gz") - make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz") - - return num_samples_map - - -@register_mock(name="oxford-iiit-pet", configs=combinations_grid(split=("trainval", "test"))) -def oxford_iiit_pet(root, config): - return OxfordIIITPetMockData.generate(root)[config["split"]] - - -class _CUB200MockData: - @classmethod - def _category_folder(cls, category, idx): - return f"{idx:03d}.{category}" - - @classmethod - def _file_stem(cls, category, idx): - return f"{category}_{idx:04d}" - - @classmethod - def _make_images(cls, images_folder): - image_files = [] - for category_idx, category in [ - (1, "Black_footed_Albatross"), - (100, "Brown_Pelican"), - (200, "Common_Yellowthroat"), - ]: - image_files.extend( - create_image_folder( - images_folder, - cls._category_folder(category, category_idx), - lambda image_idx: f"{cls._file_stem(category, image_idx)}.jpg", - num_examples=5, - ) - ) - - return image_files - - -class CUB2002011MockData(_CUB200MockData): - @classmethod - def _make_archive(cls, root): - archive_folder = root / "CUB_200_2011" - - images_folder = archive_folder / "images" - image_files = cls._make_images(images_folder) - image_ids = list(range(1, len(image_files) + 1)) - - with open(archive_folder / "images.txt", "w") as file: - file.write( - "\n".join( - f"{id} 
{path.relative_to(images_folder).as_posix()}" for id, path in zip(image_ids, image_files) - ) - ) - - split_ids = torch.randint(2, (len(image_ids),)).tolist() - counts = Counter(split_ids) - num_samples_map = {"train": counts[1], "test": counts[0]} - with open(archive_folder / "train_test_split.txt", "w") as file: - file.write("\n".join(f"{image_id} {split_id}" for image_id, split_id in zip(image_ids, split_ids))) - - with open(archive_folder / "bounding_boxes.txt", "w") as file: - file.write( - "\n".join( - " ".join( - str(item) - for item in [image_id, *make_tensor((4,), dtype=torch.int, low=0).to(torch.float).tolist()] - ) - for image_id in image_ids - ) - ) - - make_tar(root, archive_folder.with_suffix(".tgz").name, compression="gz") - - return image_files, num_samples_map - - @classmethod - def _make_segmentations(cls, root, image_files): - segmentations_folder = root / "segmentations" - for image_file in image_files: - folder = segmentations_folder.joinpath(image_file.relative_to(image_file.parents[1])) - folder.mkdir(exist_ok=True, parents=True) - create_image_file( - folder, - image_file.with_suffix(".png").name, - size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()], - ) - - make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz") - - @classmethod - def generate(cls, root): - image_files, num_samples_map = cls._make_archive(root) - cls._make_segmentations(root, image_files) - return num_samples_map - - -class CUB2002010MockData(_CUB200MockData): - @classmethod - def _make_hidden_rouge_file(cls, *files): - for file in files: - (file.parent / f"._{file.name}").touch() - - @classmethod - def _make_splits(cls, root, image_files): - split_folder = root / "lists" - split_folder.mkdir() - random.shuffle(image_files) - splits = ("train", "test") - num_samples_map = {} - for offset, split in enumerate(splits): - image_files_in_split = image_files[offset :: len(splits)] - - split_file = split_folder / f"{split}.txt" - with open(split_file, "w") as file: - file.write( - "\n".join( - sorted( - str(image_file.relative_to(image_file.parents[1]).as_posix()) - for image_file in image_files_in_split - ) - ) - ) - - cls._make_hidden_rouge_file(split_file) - num_samples_map[split] = len(image_files_in_split) - - make_tar(root, split_folder.with_suffix(".tgz").name, compression="gz") - - return num_samples_map - - @classmethod - def _make_anns(cls, root, image_files): - from scipy.io import savemat - - anns_folder = root / "annotations-mat" - for image_file in image_files: - ann_file = anns_folder / image_file.with_suffix(".mat").relative_to(image_file.parents[1]) - ann_file.parent.mkdir(parents=True, exist_ok=True) - - savemat( - ann_file, - { - "seg": torch.randint( - 256, make_tensor((2,), low=3, dtype=torch.int).tolist(), dtype=torch.uint8 - ).numpy(), - "bbox": dict( - zip(("left", "top", "right", "bottom"), make_tensor((4,), dtype=torch.uint8).tolist()) - ), - }, - ) - - readme_file = anns_folder / "README.txt" - readme_file.touch() - cls._make_hidden_rouge_file(readme_file) - - make_tar(root, "annotations.tgz", anns_folder, compression="gz") - - @classmethod - def generate(cls, root): - images_folder = root / "images" - image_files = cls._make_images(images_folder) - cls._make_hidden_rouge_file(*image_files) - make_tar(root, images_folder.with_suffix(".tgz").name, compression="gz") - - num_samples_map = cls._make_splits(root, image_files) - cls._make_anns(root, image_files) - - return num_samples_map - - -@register_mock(configs=combinations_grid(split=("train", 
"test"), year=("2010", "2011"))) -def cub200(root, config): - num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root) - return num_samples_map[config["split"]] - - -@register_mock(configs=[dict()]) -def eurosat(root, config): - data_folder = root / "2750" - data_folder.mkdir(parents=True) - - num_examples_per_class = 3 - categories = ["AnnualCrop", "Forest"] - for category in categories: - create_image_folder( - root=data_folder, - name=category, - file_name_fn=lambda idx: f"{category}_{idx + 1}.jpg", - num_examples=num_examples_per_class, - ) - make_zip(root, "EuroSAT.zip", data_folder) - return len(categories) * num_examples_per_class - - -@register_mock(configs=combinations_grid(split=("train", "test", "extra"))) -def svhn(root, config): - import scipy.io as sio - - num_samples = { - "train": 2, - "test": 3, - "extra": 4, - }[config["split"]] - - sio.savemat( - root / f"{config['split']}_32x32.mat", - { - "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8), - "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8), - }, - ) - return num_samples - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def pcam(root, config): - import h5py - - num_images = {"train": 2, "test": 3, "val": 4}[config["split"]] - - split = "valid" if config["split"] == "val" else config["split"] - - images_io = io.BytesIO() - with h5py.File(images_io, "w") as f: - f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8) - - targets_io = io.BytesIO() - with h5py.File(targets_io, "w") as f: - f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8) - - # Create .gz compressed files - images_file = root / f"camelyonpatch_level_2_split_{split}_x.h5.gz" - targets_file = root / f"camelyonpatch_level_2_split_{split}_y.h5.gz" - for compressed_file_name, uncompressed_file_io in ((images_file, images_io), (targets_file, targets_io)): - compressed_data = gzip.compress(uncompressed_file_io.getbuffer()) - with open(compressed_file_name, "wb") as compressed_file: - compressed_file.write(compressed_data) - - return num_images - - -@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test"))) -def stanford_cars(root, config): - import scipy.io as io - from numpy.core.records import fromarrays - - split = config["split"] - num_samples = {"train": 5, "test": 7}[split] - num_categories = 3 - - if split == "train": - images_folder_name = "cars_train" - devkit = root / "devkit" - devkit.mkdir() - annotations_mat_path = devkit / "cars_train_annos.mat" - else: - images_folder_name = "cars_test" - annotations_mat_path = root / "cars_test_annos_withlabels.mat" - - create_image_folder( - root=root, - name=images_folder_name, - file_name_fn=lambda image_index: f"{image_index:5d}.jpg", - num_examples=num_samples, - ) - - make_tar(root, f"cars_{split}.tgz", images_folder_name) - bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8) - classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8) - fnames = [f"{i:5d}.jpg" for i in range(num_samples)] - rec_array = fromarrays( - [bbox, bbox, bbox, bbox, classes, fnames], - names=["bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2", "class", "fname"], - ) - - io.savemat(annotations_mat_path, {"annotations": rec_array}) - if split == "train": - make_tar(root, "car_devkit.tgz", devkit, compression="gz") - - return num_samples - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def 
usps(root, config): - num_samples = {"train": 15, "test": 7}[config["split"]] - - with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh: - lines = [] - for _ in range(num_samples): - label = make_tensor(1, low=1, high=11, dtype=torch.int) - values = make_tensor(256, low=-1, high=1, dtype=torch.float) - lines.append( - " ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))]) - ) - - fh.write("\n".join(lines).encode()) - - return num_samples diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py deleted file mode 100644 index e9192f44f52..00000000000 --- a/test/prototype_common_utils.py +++ /dev/null @@ -1,529 +0,0 @@ -"""This module is separated from common_utils.py to prevent the former to be dependent on torchvision.prototype""" - -import collections.abc -import dataclasses -import functools -from typing import Callable, Optional, Sequence, Tuple, Union - -import PIL.Image -import pytest -import torch -import torch.testing -from datasets_utils import combinations_grid -from torch.nn.functional import one_hot -from torch.testing._comparison import ( - assert_equal as _assert_equal, - BooleanPair, - ErrorMeta, - NonePair, - NumberPair, - TensorLikePair, - UnsupportedInputs, -) -from torchvision.prototype import features -from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor -from torchvision.transforms.functional_tensor import _max_value as get_max_value - -__all__ = [ - "assert_close", - "assert_equal", - "ArgsKwargs", - "make_image_loaders", - "make_image", - "make_images", - "make_bounding_box_loaders", - "make_bounding_box", - "make_bounding_boxes", - "make_label", - "make_one_hot_labels", - "make_detection_mask_loaders", - "make_detection_mask", - "make_detection_masks", - "make_segmentation_mask_loaders", - "make_segmentation_mask", - "make_segmentation_masks", - "make_mask_loaders", - "make_masks", -] - - -class PILImagePair(TensorLikePair): - def __init__( - self, - actual, - expected, - *, - agg_method=None, - allowed_percentage_diff=None, - **other_parameters, - ): - if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)): - raise UnsupportedInputs() - - # This parameter is ignored to enable checking PIL images to tensor images no on the CPU - other_parameters["check_device"] = False - - super().__init__(actual, expected, **other_parameters) - self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method - self.allowed_percentage_diff = allowed_percentage_diff - - def _process_inputs(self, actual, expected, *, id, allow_subclasses): - actual, expected = [ - to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input) - for input in [actual, expected] - ] - # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL - # image to a tensor adds a singleton leading dimension. - # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here. - # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional - # shape check that will fail if we don't broadcast before. 
- try: - actual, expected = torch.broadcast_tensors(actual, expected) - except RuntimeError: - raise ErrorMeta( - AssertionError, - f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.", - id=id, - ) from None - return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses) - - def _equalize_attributes(self, actual, expected): - if actual.dtype != expected.dtype: - dtype = torch.promote_types(actual.dtype, expected.dtype) - actual = convert_image_dtype(actual, dtype) - expected = convert_image_dtype(expected, dtype) - - return super()._equalize_attributes(actual, expected) - - def compare(self) -> None: - actual, expected = self.actual, self.expected - - self._compare_attributes(actual, expected) - - actual, expected = self._equalize_attributes(actual, expected) - abs_diff = torch.abs(actual - expected) - - if self.allowed_percentage_diff is not None: - percentage_diff = (abs_diff != 0).to(torch.float).mean() - if percentage_diff > self.allowed_percentage_diff: - self._make_error_meta(AssertionError, "percentage mismatch") - - if self.agg_method is None: - super()._compare_values(actual, expected) - else: - err = self.agg_method(abs_diff.to(torch.float64)) - if err > self.atol: - self._make_error_meta(AssertionError, "aggregated mismatch") - - -def assert_close( - actual, - expected, - *, - allow_subclasses=True, - rtol=None, - atol=None, - equal_nan=False, - check_device=True, - check_dtype=True, - check_layout=True, - check_stride=False, - msg=None, - **kwargs, -): - """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison""" - __tracebackhide__ = True - - _assert_equal( - actual, - expected, - pair_types=( - NonePair, - BooleanPair, - NumberPair, - PILImagePair, - TensorLikePair, - ), - allow_subclasses=allow_subclasses, - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - check_device=check_device, - check_dtype=check_dtype, - check_layout=check_layout, - check_stride=check_stride, - msg=msg, - **kwargs, - ) - - -assert_equal = functools.partial(assert_close, rtol=0, atol=0) - - -class ArgsKwargs: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - def __iter__(self): - yield self.args - yield self.kwargs - - def load(self, device="cpu"): - args = tuple(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args) - kwargs = { - keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg for keyword, arg in self.kwargs.items() - } - return args, kwargs - - -DEFAULT_SQUARE_IMAGE_SIZE = 15 -DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33) -DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9) -DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random") - - -def _parse_image_size(size, *, name="size"): - if size == "random": - return tuple(torch.randint(15, 33, (2,)).tolist()) - elif isinstance(size, int) and size > 0: - return (size, size) - elif ( - isinstance(size, collections.abc.Sequence) - and len(size) == 2 - and all(isinstance(length, int) and length > 0 for length in size) - ): - return tuple(size) - else: - raise pytest.UsageError( - f"'{name}' can either be `'random'`, a positive integer, or a sequence of two positive integers," - f"but got {size} instead." 
- ) - - -DEFAULT_EXTRA_DIMS = ((), (0,), (4,), (2, 3), (5, 0), (0, 5)) - - -def from_loader(loader_fn): - def wrapper(*args, **kwargs): - loader = loader_fn(*args, **kwargs) - return loader.load(kwargs.get("device", "cpu")) - - return wrapper - - -def from_loaders(loaders_fn): - def wrapper(*args, **kwargs): - loaders = loaders_fn(*args, **kwargs) - for loader in loaders: - yield loader.load(kwargs.get("device", "cpu")) - - return wrapper - - -@dataclasses.dataclass -class TensorLoader: - fn: Callable[[Sequence[int], torch.dtype, Union[str, torch.device]], torch.Tensor] - shape: Sequence[int] - dtype: torch.dtype - - def load(self, device): - return self.fn(self.shape, self.dtype, device) - - -@dataclasses.dataclass -class ImageLoader(TensorLoader): - color_space: features.ColorSpace - image_size: Tuple[int, int] = dataclasses.field(init=False) - num_channels: int = dataclasses.field(init=False) - - def __post_init__(self): - self.image_size = self.shape[-2:] - self.num_channels = self.shape[-3] - - -def make_image_loader( - size="random", - *, - color_space=features.ColorSpace.RGB, - extra_dims=(), - dtype=torch.float32, - constant_alpha=True, -): - size = _parse_image_size(size) - - try: - num_channels = { - features.ColorSpace.GRAY: 1, - features.ColorSpace.GRAY_ALPHA: 2, - features.ColorSpace.RGB: 3, - features.ColorSpace.RGB_ALPHA: 4, - }[color_space] - except KeyError as error: - raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") from error - - def fn(shape, dtype, device): - max_value = get_max_value(dtype) - data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) - if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha: - data[..., -1, :, :] = max_value - return features.Image(data, color_space=color_space) - - return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) - - -make_image = from_loader(make_image_loader) - - -def make_image_loaders( - *, - sizes=DEFAULT_IMAGE_SIZES, - color_spaces=( - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, - ), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.float32, torch.uint8), - constant_alpha=True, -): - for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes): - yield make_image_loader(**params, constant_alpha=constant_alpha) - - -make_images = from_loaders(make_image_loaders) - - -@dataclasses.dataclass -class BoundingBoxLoader(TensorLoader): - format: features.BoundingBoxFormat - image_size: Tuple[int, int] - - -def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): - low, high = torch.broadcast_tensors( - *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))] - ) - return torch.stack( - [ - torch.randint(low_scalar, high_scalar, (), **kwargs) - for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist()) - ] - ).reshape(low.shape) - - -def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32): - if isinstance(format, str): - format = features.BoundingBoxFormat[format] - if format not in { - features.BoundingBoxFormat.XYXY, - features.BoundingBoxFormat.XYWH, - features.BoundingBoxFormat.CXCYWH, - }: - raise pytest.UsageError(f"Can't make bounding box in format {format}") - - image_size = _parse_image_size(image_size, name="image_size") - - def fn(shape, 
dtype, device): - *extra_dims, num_coordinates = shape - if num_coordinates != 4: - raise pytest.UsageError() - - if any(dim == 0 for dim in extra_dims): - return features.BoundingBox( - torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, image_size=image_size - ) - - height, width = image_size - - if format == features.BoundingBoxFormat.XYXY: - x1 = torch.randint(0, width // 2, extra_dims) - y1 = torch.randint(0, height // 2, extra_dims) - x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 - y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 - parts = (x1, y1, x2, y2) - elif format == features.BoundingBoxFormat.XYWH: - x = torch.randint(0, width // 2, extra_dims) - y = torch.randint(0, height // 2, extra_dims) - w = randint_with_tensor_bounds(1, width - x) - h = randint_with_tensor_bounds(1, height - y) - parts = (x, y, w, h) - else: # format == features.BoundingBoxFormat.CXCYWH: - cx = torch.randint(1, width - 1, ()) - cy = torch.randint(1, height - 1, ()) - w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) - h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) - parts = (cx, cy, w, h) - - return features.BoundingBox( - torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, image_size=image_size - ) - - return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, image_size=image_size) - - -make_bounding_box = from_loader(make_bounding_box_loader) - - -def make_bounding_box_loaders( - *, - extra_dims=DEFAULT_EXTRA_DIMS, - formats=tuple(features.BoundingBoxFormat), - image_size="random", - dtypes=(torch.float32, torch.int64), -): - for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes): - yield make_bounding_box_loader(**params, image_size=image_size) - - -make_bounding_boxes = from_loaders(make_bounding_box_loaders) - - -@dataclasses.dataclass -class LabelLoader(TensorLoader): - categories: Optional[Sequence[str]] - - -def _parse_categories(categories): - if categories is None: - num_categories = int(torch.randint(1, 11, ())) - elif isinstance(categories, int): - num_categories = categories - categories = [f"category{idx}" for idx in range(num_categories)] - elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories): - categories = list(categories) - num_categories = len(categories) - else: - raise pytest.UsageError( - f"`categories` can either be `None` (default), an integer, or a sequence of strings, " - f"but got '{categories}' instead." - ) - return categories, num_categories - - -def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64): - categories, num_categories = _parse_categories(categories) - - def fn(shape, dtype, device): - # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, - # regardless of the requested dtype, e.g. 
0 or 0.0 rather than 0 or 0.123 - data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) - return features.Label(data, categories=categories) - - return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) - - -make_label = from_loader(make_label_loader) - - -@dataclasses.dataclass -class OneHotLabelLoader(TensorLoader): - categories: Optional[Sequence[str]] - - -def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int64): - categories, num_categories = _parse_categories(categories) - - def fn(shape, dtype, device): - if num_categories == 0: - data = torch.empty(shape, dtype=dtype, device=device) - else: - # The idiom `make_label_loader(..., dtype=torch.int64); ...; one_hot(...).to(dtype)` is intentional - # since `one_hot` only supports int64 - label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) - data = one_hot(label, num_classes=num_categories).to(dtype) - return features.OneHotLabel(data, categories=categories) - - return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) - - -def make_one_hot_label_loaders( - *, - categories=(1, 0, None), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.int64, torch.float32), -): - for params in combinations_grid(categories=categories, extra_dims=extra_dims, dtype=dtypes): - yield make_one_hot_label_loader(**params) - - -make_one_hot_labels = from_loaders(make_one_hot_label_loaders) - - -class MaskLoader(TensorLoader): - pass - - -def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8): - # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects - size = _parse_image_size(size) - num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects - - def fn(shape, dtype, device): - data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device) - return features.Mask(data) - - return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype) - - -make_detection_mask = from_loader(make_detection_mask_loader) - - -def make_detection_mask_loaders( - sizes=DEFAULT_IMAGE_SIZES, - num_objects=(1, 0, "random"), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.uint8,), -): - for params in combinations_grid(size=sizes, num_objects=num_objects, extra_dims=extra_dims, dtype=dtypes): - yield make_detection_mask_loader(**params) - - -make_detection_masks = from_loaders(make_detection_mask_loaders) - - -def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8): - # This produces "segmentation" masks, i.e. 
`(*, H, W)`, where the category is encoded in the values - size = _parse_image_size(size) - num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories - - def fn(shape, dtype, device): - data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device) - return features.Mask(data) - - return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype) - - -make_segmentation_mask = from_loader(make_segmentation_mask_loader) - - -def make_segmentation_mask_loaders( - *, - sizes=DEFAULT_IMAGE_SIZES, - num_categories=(1, 2, "random"), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.uint8,), -): - for params in combinations_grid(size=sizes, num_categories=num_categories, extra_dims=extra_dims, dtype=dtypes): - yield make_segmentation_mask_loader(**params) - - -make_segmentation_masks = from_loaders(make_segmentation_mask_loaders) - - -def make_mask_loaders( - *, - sizes=DEFAULT_IMAGE_SIZES, - num_objects=(1, 0, "random"), - num_categories=(1, 2, "random"), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.uint8,), -): - yield from make_detection_mask_loaders(sizes=sizes, num_objects=num_objects, extra_dims=extra_dims, dtypes=dtypes) - yield from make_segmentation_mask_loaders( - sizes=sizes, num_categories=num_categories, extra_dims=extra_dims, dtypes=dtypes - ) - - -make_masks = from_loaders(make_mask_loaders) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py deleted file mode 100644 index 99a9066be0a..00000000000 --- a/test/prototype_transforms_dispatcher_infos.py +++ /dev/null @@ -1,259 +0,0 @@ -import dataclasses -from collections import defaultdict -from typing import Callable, Dict, List, Sequence, Type - -import pytest -import torchvision.prototype.transforms.functional as F -from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip -from torchvision.prototype import features - -__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] - -KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS} - - -def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"): - return Skip( - "test_scripted_smoke", - condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)), - reason=reason, - ) - - -def skip_integer_size_jit(name="size"): - return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.") - - -@dataclasses.dataclass -class DispatcherInfo: - dispatcher: Callable - kernels: Dict[Type, Callable] - skips: Sequence[Skip] = dataclasses.field(default_factory=list) - _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False) - - def __post_init__(self): - skips_map = defaultdict(list) - for skip in self.skips: - skips_map[skip.test_name].append(skip) - self._skips_map = dict(skips_map) - - def sample_inputs(self, *types): - for type in types or self.kernels.keys(): - if type not in self.kernels: - raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}") - - yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]() - - def maybe_skip(self, *, test_name, args_kwargs, device): - skips = self._skips_map.get(test_name) - if not skips: - return - - for skip in skips: - if skip.condition(args_kwargs, device): - pytest.skip(skip.reason) - - -DISPATCHER_INFOS = [ - DispatcherInfo( - F.horizontal_flip, - kernels={ - features.Image: F.horizontal_flip_image_tensor, - features.BoundingBox: 
F.horizontal_flip_bounding_box, - features.Mask: F.horizontal_flip_mask, - }, - ), - DispatcherInfo( - F.resize, - kernels={ - features.Image: F.resize_image_tensor, - features.BoundingBox: F.resize_bounding_box, - features.Mask: F.resize_mask, - }, - skips=[ - skip_integer_size_jit(), - ], - ), - DispatcherInfo( - F.affine, - kernels={ - features.Image: F.affine_image_tensor, - features.BoundingBox: F.affine_bounding_box, - features.Mask: F.affine_mask, - }, - skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")], - ), - DispatcherInfo( - F.vertical_flip, - kernels={ - features.Image: F.vertical_flip_image_tensor, - features.BoundingBox: F.vertical_flip_bounding_box, - features.Mask: F.vertical_flip_mask, - }, - ), - DispatcherInfo( - F.rotate, - kernels={ - features.Image: F.rotate_image_tensor, - features.BoundingBox: F.rotate_bounding_box, - features.Mask: F.rotate_mask, - }, - ), - DispatcherInfo( - F.crop, - kernels={ - features.Image: F.crop_image_tensor, - features.BoundingBox: F.crop_bounding_box, - features.Mask: F.crop_mask, - }, - ), - DispatcherInfo( - F.resized_crop, - kernels={ - features.Image: F.resized_crop_image_tensor, - features.BoundingBox: F.resized_crop_bounding_box, - features.Mask: F.resized_crop_mask, - }, - ), - DispatcherInfo( - F.pad, - kernels={ - features.Image: F.pad_image_tensor, - features.BoundingBox: F.pad_bounding_box, - features.Mask: F.pad_mask, - }, - ), - DispatcherInfo( - F.perspective, - kernels={ - features.Image: F.perspective_image_tensor, - features.BoundingBox: F.perspective_bounding_box, - features.Mask: F.perspective_mask, - }, - ), - DispatcherInfo( - F.elastic, - kernels={ - features.Image: F.elastic_image_tensor, - features.BoundingBox: F.elastic_bounding_box, - features.Mask: F.elastic_mask, - }, - ), - DispatcherInfo( - F.center_crop, - kernels={ - features.Image: F.center_crop_image_tensor, - features.BoundingBox: F.center_crop_bounding_box, - features.Mask: F.center_crop_mask, - }, - skips=[ - skip_integer_size_jit("output_size"), - ], - ), - DispatcherInfo( - F.gaussian_blur, - kernels={ - features.Image: F.gaussian_blur_image_tensor, - }, - skips=[ - skip_python_scalar_arg_jit("kernel_size"), - skip_python_scalar_arg_jit("sigma"), - ], - ), - DispatcherInfo( - F.equalize, - kernels={ - features.Image: F.equalize_image_tensor, - }, - ), - DispatcherInfo( - F.invert, - kernels={ - features.Image: F.invert_image_tensor, - }, - ), - DispatcherInfo( - F.posterize, - kernels={ - features.Image: F.posterize_image_tensor, - }, - ), - DispatcherInfo( - F.solarize, - kernels={ - features.Image: F.solarize_image_tensor, - }, - ), - DispatcherInfo( - F.autocontrast, - kernels={ - features.Image: F.autocontrast_image_tensor, - }, - ), - DispatcherInfo( - F.adjust_sharpness, - kernels={ - features.Image: F.adjust_sharpness_image_tensor, - }, - ), - DispatcherInfo( - F.erase, - kernels={ - features.Image: F.erase_image_tensor, - }, - ), - DispatcherInfo( - F.adjust_brightness, - kernels={ - features.Image: F.adjust_brightness_image_tensor, - }, - ), - DispatcherInfo( - F.adjust_contrast, - kernels={ - features.Image: F.adjust_contrast_image_tensor, - }, - ), - DispatcherInfo( - F.adjust_gamma, - kernels={ - features.Image: F.adjust_gamma_image_tensor, - }, - ), - DispatcherInfo( - F.adjust_hue, - kernels={ - features.Image: F.adjust_hue_image_tensor, - }, - ), - DispatcherInfo( - F.adjust_saturation, - kernels={ - features.Image: F.adjust_saturation_image_tensor, - }, - ), - DispatcherInfo( - F.five_crop, - 
kernels={ - features.Image: F.five_crop_image_tensor, - }, - skips=[ - skip_integer_size_jit(), - ], - ), - DispatcherInfo( - F.ten_crop, - kernels={ - features.Image: F.ten_crop_image_tensor, - }, - skips=[ - skip_integer_size_jit(), - ], - ), - DispatcherInfo( - F.normalize, - kernels={ - features.Image: F.normalize_image_tensor, - }, - ), -] diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py deleted file mode 100644 index 3f050ad8f7d..00000000000 --- a/test/prototype_transforms_kernel_infos.py +++ /dev/null @@ -1,1594 +0,0 @@ -import dataclasses -import functools -import itertools -import math -from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence - -import numpy as np -import pytest -import torch.testing -import torchvision.ops -import torchvision.prototype.transforms.functional as F -from datasets_utils import combinations_grid -from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders -from torchvision.prototype import features -from torchvision.transforms.functional_tensor import _max_value as get_max_value - -__all__ = ["KernelInfo", "KERNEL_INFOS"] - - -@dataclasses.dataclass -class Skip: - test_name: str - reason: str - condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True - - -@dataclasses.dataclass -class KernelInfo: - kernel: Callable - # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should - # not include extensive parameter combinations to keep to overall test count moderate. - sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]] - # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name - # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then - kernel_name: Optional[str] = None - # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take - # tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen - # inside the function. It should return a tensor or to be more precise an object that can be compared to a - # tensor by `assert_close`. If omitted, no reference test will be performed. - reference_fn: Optional[Callable] = None - # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter - # values to be tested. If not specified, `sample_inputs_fn` will be used. - reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None - # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. 
- closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - skips: Sequence[Skip] = dataclasses.field(default_factory=list) - _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False) - - def __post_init__(self): - self.kernel_name = self.kernel_name or self.kernel.__name__ - self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn - - skips_map = defaultdict(list) - for skip in self.skips: - skips_map[skip.test_name].append(skip) - self._skips_map = dict(skips_map) - - def maybe_skip(self, *, test_name, args_kwargs, device): - skips = self._skips_map.get(test_name) - if not skips: - return - - for skip in skips: - if skip.condition(args_kwargs, device): - pytest.skip(skip.reason) - - -DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( - atol=1e-5, - rtol=0, - agg_method="mean", -) - - -def pil_reference_wrapper(pil_kernel): - @functools.wraps(pil_kernel) - def wrapper(image_tensor, *other_args, **kwargs): - if image_tensor.ndim > 3: - raise pytest.UsageError( - f"Can only test single tensor images against PIL, but input has shape {image_tensor.shape}" - ) - - # We don't need to convert back to tensor here, since `assert_close` does that automatically. - return pil_kernel(F.to_image_pil(image_tensor), *other_args, **kwargs) - - return wrapper - - -def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"): - return Skip( - "test_scripted_vs_eager", - condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)), - reason=reason, - ) - - -def skip_integer_size_jit(name="size"): - return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.") - - -KERNEL_INFOS = [] - - -def sample_inputs_horizontal_flip_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]): - yield ArgsKwargs(image_loader) - - -def reference_inputs_horizontal_flip_image_tensor(): - for image_loader in make_image_loaders(extra_dims=[()]): - yield ArgsKwargs(image_loader) - - -def sample_inputs_horizontal_flip_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders( - formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] - ): - yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size - ) - - -def sample_inputs_horizontal_flip_mask(): - for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]): - yield ArgsKwargs(image_loader) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.horizontal_flip_image_tensor, - kernel_name="horizontal_flip_image_tensor", - sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor, - reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil), - reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.horizontal_flip_bounding_box, - sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box, - ), - KernelInfo( - F.horizontal_flip_mask, - sample_inputs_fn=sample_inputs_horizontal_flip_mask, - ), - ] -) - - -def _get_resize_sizes(image_size): - height, width = image_size - length = max(image_size) - # FIXME: enable me when the kernels are fixed - # yield length - yield [length] - yield (length,) - new_height = int(height * 0.75) - new_width = int(width * 1.25) - yield [new_height, new_width] - yield height, width - - -def sample_inputs_resize_image_tensor(): - for image_loader, interpolation in 
itertools.product( - make_image_loaders(dtypes=[torch.float32]), - [ - F.InterpolationMode.NEAREST, - F.InterpolationMode.BICUBIC, - ], - ): - for size in _get_resize_sizes(image_loader.image_size): - yield ArgsKwargs(image_loader, size=size, interpolation=interpolation) - - -@pil_reference_wrapper -def reference_resize_image_tensor(*args, **kwargs): - if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in { - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - }: - raise pytest.UsageError("Anti-aliasing is always active in PIL") - return F.resize_image_pil(*args, **kwargs) - - -def reference_inputs_resize_image_tensor(): - for image_loader, interpolation in itertools.product( - make_image_loaders(extra_dims=[()]), - [ - F.InterpolationMode.NEAREST, - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - ], - ): - for size in _get_resize_sizes(image_loader.image_size): - yield ArgsKwargs( - image_loader, - size=size, - interpolation=interpolation, - antialias=interpolation - in { - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - }, - ) - - -def sample_inputs_resize_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(formats=[features.BoundingBoxFormat.XYXY]): - for size in _get_resize_sizes(bounding_box_loader.image_size): - yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.image_size) - - -def sample_inputs_resize_mask(): - for mask_loader in make_mask_loaders(dtypes=[torch.uint8]): - for size in _get_resize_sizes(mask_loader.shape[-2:]): - yield ArgsKwargs(mask_loader, size=size) - - -@pil_reference_wrapper -def reference_resize_mask(*args, **kwargs): - return F.resize_image_pil(*args, interpolation=F.InterpolationMode.NEAREST, **kwargs) - - -def reference_inputs_resize_mask(): - for mask_loader in make_mask_loaders(extra_dims=[()], num_objects=[1]): - for size in _get_resize_sizes(mask_loader.shape[-2:]): - yield ArgsKwargs(mask_loader, size=size) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.resize_image_tensor, - sample_inputs_fn=sample_inputs_resize_image_tensor, - reference_fn=reference_resize_image_tensor, - reference_inputs_fn=reference_inputs_resize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit(), - ], - ), - KernelInfo( - F.resize_bounding_box, - sample_inputs_fn=sample_inputs_resize_bounding_box, - skips=[ - skip_integer_size_jit(), - ], - ), - KernelInfo( - F.resize_mask, - sample_inputs_fn=sample_inputs_resize_mask, - reference_fn=reference_resize_mask, - reference_inputs_fn=reference_inputs_resize_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit(), - ], - ), - ] -) - - -_AFFINE_KWARGS = combinations_grid( - angle=[-87, 15, 90], - translate=[(5, 5), (-5, -5)], - scale=[0.77, 1.27], - shear=[(12, 12), (0, 0)], -) - - -def _diversify_affine_kwargs_types(affine_kwargs): - angle = affine_kwargs["angle"] - for diverse_angle in [int(angle), float(angle)]: - yield dict(affine_kwargs, angle=diverse_angle) - - shear = affine_kwargs["shear"] - for diverse_shear in [tuple(shear), list(shear), int(shear[0]), float(shear[0])]: - yield dict(affine_kwargs, shear=diverse_shear) - - -def sample_inputs_affine_image_tensor(): - for image_loader, interpolation_mode, center in itertools.product( - make_image_loaders(sizes=["random"], dtypes=[torch.float32]), - [ - F.InterpolationMode.NEAREST, - F.InterpolationMode.BILINEAR, - ], - [None, (0, 0)], - ): - for fill in [None, 
128.0, 128, [12.0], [0.5] * image_loader.num_channels]: - yield ArgsKwargs( - image_loader, - interpolation=interpolation_mode, - center=center, - fill=fill, - **_AFFINE_KWARGS[0], - ) - - for image_loader, affine_kwargs in itertools.product( - make_image_loaders(sizes=["random"], dtypes=[torch.float32]), _diversify_affine_kwargs_types(_AFFINE_KWARGS[0]) - ): - yield ArgsKwargs(image_loader, **affine_kwargs) - - -def reference_inputs_affine_image_tensor(): - for image_loader, affine_kwargs in itertools.product(make_image_loaders(extra_dims=[()]), _AFFINE_KWARGS): - yield ArgsKwargs( - image_loader, - interpolation=F.InterpolationMode.NEAREST, - **affine_kwargs, - ) - - -def sample_inputs_affine_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, - **_AFFINE_KWARGS[0], - ) - - for bounding_box_loader, affine_kwargs in itertools.product( - make_bounding_box_loaders(), _diversify_affine_kwargs_types(_AFFINE_KWARGS[0]) - ): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, - **affine_kwargs, - ) - - -def _compute_affine_matrix(angle, translate, scale, shear, center): - rot = math.radians(angle) - cx, cy = center - tx, ty = translate - sx, sy = [math.radians(sh_) for sh_ in shear] - - c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) - t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) - c_matrix_inv = np.linalg.inv(c_matrix) - rs_matrix = np.array( - [ - [scale * math.cos(rot), -scale * math.sin(rot), 0], - [scale * math.sin(rot), scale * math.cos(rot), 0], - [0, 0, 1], - ] - ) - shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) - rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) - true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) - return true_matrix - - -def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center=None): - if center is None: - center = [s * 0.5 for s in image_size[::-1]] - - def transform(bbox): - affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center) - affine_matrix = affine_matrix[:2, :] - - bbox_xyxy = F.convert_format_bounding_box(bbox, old_format=format, new_format=features.BoundingBoxFormat.XYXY) - points = np.array( - [ - [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], - ] - ) - transformed_points = np.matmul(points, affine_matrix.T) - out_bbox = torch.tensor( - [ - np.min(transformed_points[:, 0]), - np.min(transformed_points[:, 1]), - np.max(transformed_points[:, 0]), - np.max(transformed_points[:, 1]), - ], - dtype=bbox.dtype, - ) - return F.convert_format_bounding_box( - out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False - ) - - if bounding_box.ndim < 2: - bounding_box = [bounding_box] - - expected_bboxes = [transform(bbox) for bbox in bounding_box] - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - - return expected_bboxes - - -def reference_inputs_affine_bounding_box(): - for bounding_box_loader, affine_kwargs in itertools.product( - 
make_bounding_box_loaders(extra_dims=[()]), - _AFFINE_KWARGS, - ): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, - **affine_kwargs, - ) - - -def sample_inputs_affine_image_mask(): - for mask_loader, center in itertools.product( - make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]), - [None, (0, 0)], - ): - yield ArgsKwargs(mask_loader, center=center, **_AFFINE_KWARGS[0]) - - for mask_loader, affine_kwargs in itertools.product( - make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]), _diversify_affine_kwargs_types(_AFFINE_KWARGS[0]) - ): - yield ArgsKwargs(mask_loader, **affine_kwargs) - - -@pil_reference_wrapper -def reference_affine_mask(*args, **kwargs): - return F.affine_image_pil(*args, interpolation=F.InterpolationMode.NEAREST, **kwargs) - - -def reference_inputs_resize_mask(): - for mask_loader, affine_kwargs in itertools.product( - make_mask_loaders(extra_dims=[()], num_objects=[1]), _AFFINE_KWARGS - ): - yield ArgsKwargs(mask_loader, **affine_kwargs) - - -# FIXME: @datumbox, remove this as soon as you have fixed the behavior in https://github.com/pytorch/vision/pull/6636 -def skip_scalar_shears(*test_names): - for test_name in test_names: - yield Skip( - test_name, - condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["shear"], (int, float)), - reason="The kernel is broken for a scalar `shear`", - ) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.affine_image_tensor, - sample_inputs_fn=sample_inputs_affine_image_tensor, - reference_fn=pil_reference_wrapper(F.affine_image_pil), - reference_inputs_fn=reference_inputs_affine_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")], - ), - KernelInfo( - F.affine_bounding_box, - sample_inputs_fn=sample_inputs_affine_bounding_box, - reference_fn=reference_affine_bounding_box, - reference_inputs_fn=reference_inputs_affine_bounding_box, - closeness_kwargs=dict(atol=1, rtol=0), - skips=[ - skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"), - *skip_scalar_shears( - "test_batched_vs_single", - "test_no_inplace", - "test_dtype_and_device_consistency", - ), - ], - ), - KernelInfo( - F.affine_mask, - sample_inputs_fn=sample_inputs_affine_image_mask, - reference_fn=reference_affine_mask, - reference_inputs_fn=reference_inputs_resize_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")], - ), - ] -) - - -def sample_inputs_convert_format_bounding_box(): - formats = set(features.BoundingBoxFormat) - for bounding_box_loader in make_bounding_box_loaders(formats=formats): - old_format = bounding_box_loader.format - for params in combinations_grid(new_format=formats - {old_format}, copy=(True, False)): - yield ArgsKwargs(bounding_box_loader, old_format=old_format, **params) - - -def reference_convert_format_bounding_box(bounding_box, old_format, new_format, copy): - if not copy: - raise pytest.UsageError("Reference for `convert_format_bounding_box` only supports `copy=True`") - - return torchvision.ops.box_convert( - bounding_box, in_fmt=old_format.kernel_name.lower(), out_fmt=new_format.kernel_name.lower() - ) - - -def reference_inputs_convert_format_bounding_box(): - for args_kwargs in sample_inputs_convert_color_space_image_tensor(): - (image_loader, *other_args), kwargs = args_kwargs - if len(image_loader.shape) == 2 and 
kwargs.setdefault("copy", True): - yield args_kwargs - - -KERNEL_INFOS.append( - KernelInfo( - F.convert_format_bounding_box, - sample_inputs_fn=sample_inputs_convert_format_bounding_box, - reference_fn=reference_convert_format_bounding_box, - reference_inputs_fn=reference_inputs_convert_format_bounding_box, - ), -) - - -def sample_inputs_convert_color_space_image_tensor(): - color_spaces = set(features.ColorSpace) - {features.ColorSpace.OTHER} - for image_loader in make_image_loaders(sizes=["random"], color_spaces=color_spaces, constant_alpha=True): - old_color_space = image_loader.color_space - for params in combinations_grid(new_color_space=color_spaces - {old_color_space}, copy=(True, False)): - yield ArgsKwargs(image_loader, old_color_space=old_color_space, **params) - - -@pil_reference_wrapper -def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space, copy): - color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode) - if color_space_pil != old_color_space: - raise pytest.UsageError( - f"Converting the tensor image into an PIL image changed the colorspace " - f"from {old_color_space} to {color_space_pil}" - ) - - return F.convert_color_space_image_pil(image_pil, color_space=new_color_space, copy=copy) - - -def reference_inputs_convert_color_space_image_tensor(): - for args_kwargs in sample_inputs_convert_color_space_image_tensor(): - (image_loader, *other_args), kwargs = args_kwargs - if len(image_loader.shape) == 3: - yield args_kwargs - - -KERNEL_INFOS.append( - KernelInfo( - F.convert_color_space_image_tensor, - sample_inputs_fn=sample_inputs_convert_color_space_image_tensor, - reference_fn=reference_convert_color_space_image_tensor, - reference_inputs_fn=reference_inputs_convert_color_space_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), -) - - -def sample_inputs_vertical_flip_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]): - yield ArgsKwargs(image_loader) - - -def reference_inputs_vertical_flip_image_tensor(): - for image_loader in make_image_loaders(extra_dims=[()]): - yield ArgsKwargs(image_loader) - - -def sample_inputs_vertical_flip_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders( - formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] - ): - yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size - ) - - -def sample_inputs_vertical_flip_mask(): - for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]): - yield ArgsKwargs(image_loader) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.vertical_flip_image_tensor, - kernel_name="vertical_flip_image_tensor", - sample_inputs_fn=sample_inputs_vertical_flip_image_tensor, - reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil), - reference_inputs_fn=reference_inputs_vertical_flip_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.vertical_flip_bounding_box, - sample_inputs_fn=sample_inputs_vertical_flip_bounding_box, - ), - KernelInfo( - F.vertical_flip_mask, - sample_inputs_fn=sample_inputs_vertical_flip_mask, - ), - ] -) - -_ROTATE_ANGLES = [-87, 15, 90] - - -def sample_inputs_rotate_image_tensor(): - for image_loader, params in itertools.product( - make_image_loaders(sizes=["random"], dtypes=[torch.float32]), - combinations_grid( - interpolation=[F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR], - expand=[True, False], - center=[None, (0, 0)], 
- ), - ): - if params["center"] is not None and params["expand"]: - # Otherwise this will emit a warning and ignore center anyway - continue - - for fill in [None, 0.5, [0.5] * image_loader.num_channels]: - yield ArgsKwargs( - image_loader, - angle=_ROTATE_ANGLES[0], - fill=fill, - **params, - ) - - -def reference_inputs_rotate_image_tensor(): - for image_loader, angle in itertools.product(make_image_loaders(extra_dims=[()]), _ROTATE_ANGLES): - yield ArgsKwargs(image_loader, angle=angle) - - -def sample_inputs_rotate_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, - angle=_ROTATE_ANGLES[0], - ) - - -def sample_inputs_rotate_mask(): - for image_loader, params in itertools.product( - make_image_loaders(sizes=["random"], dtypes=[torch.uint8]), - combinations_grid( - expand=[True, False], - center=[None, (0, 0)], - ), - ): - if params["center"] is not None and params["expand"]: - # Otherwise this will emit a warning and ignore center anyway - continue - - yield ArgsKwargs( - image_loader, - angle=_ROTATE_ANGLES[0], - **params, - ) - - -@pil_reference_wrapper -def reference_rotate_mask(*args, **kwargs): - return F.rotate_image_pil(*args, interpolation=F.InterpolationMode.NEAREST, **kwargs) - - -def reference_inputs_rotate_mask(): - for mask_loader, angle in itertools.product(make_mask_loaders(extra_dims=[()], num_objects=[1]), _ROTATE_ANGLES): - yield ArgsKwargs(mask_loader, angle=angle) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.rotate_image_tensor, - sample_inputs_fn=sample_inputs_rotate_image_tensor, - reference_fn=pil_reference_wrapper(F.rotate_image_pil), - reference_inputs_fn=reference_inputs_rotate_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.rotate_bounding_box, - sample_inputs_fn=sample_inputs_rotate_bounding_box, - ), - KernelInfo( - F.rotate_mask, - sample_inputs_fn=sample_inputs_rotate_mask, - reference_fn=reference_rotate_mask, - reference_inputs_fn=reference_inputs_rotate_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - -_CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20], width=[12, 20]) - - -def sample_inputs_crop_image_tensor(): - for image_loader, params in itertools.product(make_image_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]): - yield ArgsKwargs(image_loader, **params) - - -def reference_inputs_crop_image_tensor(): - for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _CROP_PARAMS): - yield ArgsKwargs(image_loader, **params) - - -def sample_inputs_crop_bounding_box(): - for bounding_box_loader, params in itertools.product( - make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] - ): - yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params) - - -def sample_inputs_crop_mask(): - for mask_loader, params in itertools.product(make_mask_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]): - yield ArgsKwargs(mask_loader, **params) - - -def reference_inputs_crop_mask(): - for mask_loader, params in itertools.product(make_mask_loaders(extra_dims=[()], num_objects=[1]), _CROP_PARAMS): - yield ArgsKwargs(mask_loader, **params) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.crop_image_tensor, - kernel_name="crop_image_tensor", - sample_inputs_fn=sample_inputs_crop_image_tensor, - reference_fn=pil_reference_wrapper(F.crop_image_pil), - 
reference_inputs_fn=reference_inputs_crop_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.crop_bounding_box, - sample_inputs_fn=sample_inputs_crop_bounding_box, - ), - KernelInfo( - F.crop_mask, - sample_inputs_fn=sample_inputs_crop_mask, - reference_fn=pil_reference_wrapper(F.crop_image_pil), - reference_inputs_fn=reference_inputs_crop_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - -_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)]) - - -def sample_inputs_resized_crop_image_tensor(): - for image_loader in make_image_loaders(): - yield ArgsKwargs(image_loader, **_RESIZED_CROP_PARAMS[0]) - - -@pil_reference_wrapper -def reference_resized_crop_image_tensor(*args, **kwargs): - if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in { - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - }: - raise pytest.UsageError("Anti-aliasing is always active in PIL") - return F.resized_crop_image_pil(*args, **kwargs) - - -def reference_inputs_resized_crop_image_tensor(): - for image_loader, interpolation, params in itertools.product( - make_image_loaders(extra_dims=[()]), - [ - F.InterpolationMode.NEAREST, - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - ], - _RESIZED_CROP_PARAMS, - ): - yield ArgsKwargs( - image_loader, - interpolation=interpolation, - antialias=interpolation - in { - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - }, - **params, - ) - - -def sample_inputs_resized_crop_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0]) - - -def sample_inputs_resized_crop_mask(): - for mask_loader in make_mask_loaders(): - yield ArgsKwargs(mask_loader, **_RESIZED_CROP_PARAMS[0]) - - -def reference_inputs_resized_crop_mask(): - for mask_loader, params in itertools.product( - make_mask_loaders(extra_dims=[()], num_objects=[1]), _RESIZED_CROP_PARAMS - ): - yield ArgsKwargs(mask_loader, **params) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.resized_crop_image_tensor, - sample_inputs_fn=sample_inputs_resized_crop_image_tensor, - reference_fn=reference_resized_crop_image_tensor, - reference_inputs_fn=reference_inputs_resized_crop_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.resized_crop_bounding_box, - sample_inputs_fn=sample_inputs_resized_crop_bounding_box, - ), - KernelInfo( - F.resized_crop_mask, - sample_inputs_fn=sample_inputs_resized_crop_mask, - reference_fn=pil_reference_wrapper(F.resized_crop_image_pil), - reference_inputs_fn=reference_inputs_resized_crop_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - -_PAD_PARAMS = combinations_grid( - padding=[[1], [1, 1], [1, 1, 2, 2]], - padding_mode=["constant", "symmetric", "edge", "reflect"], -) - - -def sample_inputs_pad_image_tensor(): - for image_loader, params in itertools.product(make_image_loaders(sizes=["random"]), _PAD_PARAMS): - fills = [None, 128.0, 128, [12.0]] - if params["padding_mode"] == "constant": - fills.append([12.0 + c for c in range(image_loader.num_channels)]) - for fill in fills: - yield ArgsKwargs(image_loader, fill=fill, **params) - - -def reference_inputs_pad_image_tensor(): - for image_loader, params in itertools.product(make_image_loaders(extra_dims=[()]), _PAD_PARAMS): - # FIXME: PIL kernel doesn't support sequences of length 1 if 
the number of channels is larger. Shouldn't it? - fills = [None, 128.0, 128] - if params["padding_mode"] == "constant": - fills.append([12.0 + c for c in range(image_loader.num_channels)]) - for fill in fills: - yield ArgsKwargs(image_loader, fill=fill, **params) - - -def sample_inputs_pad_bounding_box(): - for bounding_box_loader, params in itertools.product(make_bounding_box_loaders(), _PAD_PARAMS): - if params["padding_mode"] != "constant": - continue - - yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size, **params - ) - - -def sample_inputs_pad_mask(): - for image_loader, fill, params in itertools.product(make_mask_loaders(sizes=["random"]), [None, 127], _PAD_PARAMS): - yield ArgsKwargs(image_loader, fill=fill, **params) - - -def reference_inputs_pad_mask(): - for image_loader, fill, params in itertools.product(make_image_loaders(extra_dims=[()]), [None, 127], _PAD_PARAMS): - yield ArgsKwargs(image_loader, fill=fill, **params) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.pad_image_tensor, - sample_inputs_fn=sample_inputs_pad_image_tensor, - reference_fn=pil_reference_wrapper(F.pad_image_pil), - reference_inputs_fn=reference_inputs_pad_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.pad_bounding_box, - sample_inputs_fn=sample_inputs_pad_bounding_box, - ), - KernelInfo( - F.pad_mask, - sample_inputs_fn=sample_inputs_pad_mask, - reference_fn=pil_reference_wrapper(F.pad_image_pil), - reference_inputs_fn=reference_inputs_pad_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - -_PERSPECTIVE_COEFFS = [ - [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], - [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], -] - - -def sample_inputs_perspective_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): - for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: - yield ArgsKwargs(image_loader, fill=fill, perspective_coeffs=_PERSPECTIVE_COEFFS[0]) - - -def reference_inputs_perspective_image_tensor(): - for image_loader, perspective_coeffs in itertools.product(make_image_loaders(extra_dims=[()]), _PERSPECTIVE_COEFFS): - # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it? 
- for fill in [None, 128.0, 128, [12.0 + c for c in range(image_loader.num_channels)]]: - yield ArgsKwargs(image_loader, fill=fill, perspective_coeffs=perspective_coeffs) - - -def sample_inputs_perspective_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, perspective_coeffs=_PERSPECTIVE_COEFFS[0] - ) - - -def sample_inputs_perspective_mask(): - for mask_loader in make_mask_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): - yield ArgsKwargs(mask_loader, perspective_coeffs=_PERSPECTIVE_COEFFS[0]) - - -def reference_inputs_perspective_mask(): - for mask_loader, perspective_coeffs in itertools.product( - make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS - ): - yield ArgsKwargs(mask_loader, perspective_coeffs=perspective_coeffs) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.perspective_image_tensor, - sample_inputs_fn=sample_inputs_perspective_image_tensor, - reference_fn=pil_reference_wrapper(F.perspective_image_pil), - reference_inputs_fn=reference_inputs_perspective_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.perspective_bounding_box, - sample_inputs_fn=sample_inputs_perspective_bounding_box, - ), - KernelInfo( - F.perspective_mask, - sample_inputs_fn=sample_inputs_perspective_mask, - reference_fn=pil_reference_wrapper(F.perspective_image_pil), - reference_inputs_fn=reference_inputs_perspective_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - - -def _get_elastic_displacement(image_size): - return torch.rand(1, *image_size, 2) - - -def sample_inputs_elastic_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): - displacement = _get_elastic_displacement(image_loader.image_size) - for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: - yield ArgsKwargs(image_loader, displacement=displacement, fill=fill) - - -def reference_inputs_elastic_image_tensor(): - for image_loader, interpolation in itertools.product( - make_image_loaders(extra_dims=[()]), - [ - F.InterpolationMode.NEAREST, - F.InterpolationMode.BILINEAR, - F.InterpolationMode.BICUBIC, - ], - ): - displacement = _get_elastic_displacement(image_loader.image_size) - for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: - yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill) - - -def sample_inputs_elastic_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - displacement = _get_elastic_displacement(bounding_box_loader.image_size) - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - displacement=displacement, - ) - - -def sample_inputs_elastic_mask(): - for mask_loader in make_mask_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): - displacement = _get_elastic_displacement(mask_loader.shape[-2:]) - yield ArgsKwargs(mask_loader, displacement=displacement) - - -def reference_inputs_elastic_mask(): - for mask_loader in make_mask_loaders(extra_dims=[()], num_objects=[1]): - displacement = _get_elastic_displacement(mask_loader.shape[-2:]) - yield ArgsKwargs(mask_loader, displacement=displacement) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - 
F.elastic_image_tensor, - sample_inputs_fn=sample_inputs_elastic_image_tensor, - reference_fn=pil_reference_wrapper(F.elastic_image_pil), - reference_inputs_fn=reference_inputs_elastic_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.elastic_bounding_box, - sample_inputs_fn=sample_inputs_elastic_bounding_box, - ), - KernelInfo( - F.elastic_mask, - sample_inputs_fn=sample_inputs_elastic_mask, - reference_fn=pil_reference_wrapper(F.elastic_image_pil), - reference_inputs_fn=reference_inputs_elastic_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - - -_CENTER_CROP_IMAGE_SIZES = [(16, 16), (7, 33), (31, 9)] -_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)] - - -def sample_inputs_center_crop_image_tensor(): - for image_loader, output_size in itertools.product( - make_image_loaders(sizes=_CENTER_CROP_IMAGE_SIZES), _CENTER_CROP_OUTPUT_SIZES - ): - yield ArgsKwargs(image_loader, output_size=output_size) - - -def reference_inputs_center_crop_image_tensor(): - for image_loader, output_size in itertools.product( - make_image_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES - ): - yield ArgsKwargs(image_loader, output_size=output_size) - - -def sample_inputs_center_crop_bounding_box(): - for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, - output_size=output_size, - ) - - -def sample_inputs_center_crop_mask(): - for mask_loader, output_size in itertools.product( - make_mask_loaders(sizes=_CENTER_CROP_IMAGE_SIZES), _CENTER_CROP_OUTPUT_SIZES - ): - yield ArgsKwargs(mask_loader, output_size=output_size) - - -def reference_inputs_center_crop_mask(): - for mask_loader, output_size in itertools.product( - make_mask_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES - ): - yield ArgsKwargs(mask_loader, output_size=output_size) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.center_crop_image_tensor, - sample_inputs_fn=sample_inputs_center_crop_image_tensor, - reference_fn=pil_reference_wrapper(F.center_crop_image_pil), - reference_inputs_fn=reference_inputs_center_crop_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit("output_size"), - ], - ), - KernelInfo( - F.center_crop_bounding_box, - sample_inputs_fn=sample_inputs_center_crop_bounding_box, - skips=[ - skip_integer_size_jit("output_size"), - ], - ), - KernelInfo( - F.center_crop_mask, - sample_inputs_fn=sample_inputs_center_crop_mask, - reference_fn=pil_reference_wrapper(F.center_crop_image_pil), - reference_inputs_fn=reference_inputs_center_crop_mask, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit("output_size"), - ], - ), - ] -) - - -def sample_inputs_gaussian_blur_image_tensor(): - for image_loader, params in itertools.product( - make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ), - combinations_grid( - kernel_size=[(3, 3), [3, 3], 5], - sigma=[None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)], - ), - ): - yield ArgsKwargs(image_loader, **params) - - -KERNEL_INFOS.append( - KernelInfo( - F.gaussian_blur_image_tensor, - sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - 
skip_python_scalar_arg_jit("kernel_size"), - skip_python_scalar_arg_jit("sigma"), - ], - ) -) - - -def sample_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), - dtypes=[torch.uint8], - ): - yield ArgsKwargs(image_loader) - - -def reference_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders( - extra_dims=[()], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8] - ): - yield ArgsKwargs(image_loader) - - -KERNEL_INFOS.append( - KernelInfo( - F.equalize_image_tensor, - kernel_name="equalize_image_tensor", - sample_inputs_fn=sample_inputs_equalize_image_tensor, - reference_fn=pil_reference_wrapper(F.equalize_image_pil), - reference_inputs_fn=reference_inputs_equalize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -def sample_inputs_invert_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader) - - -def reference_inputs_invert_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()] - ): - yield ArgsKwargs(image_loader) - - -KERNEL_INFOS.append( - KernelInfo( - F.invert_image_tensor, - kernel_name="invert_image_tensor", - sample_inputs_fn=sample_inputs_invert_image_tensor, - reference_fn=pil_reference_wrapper(F.invert_image_pil), - reference_inputs_fn=reference_inputs_invert_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -_POSTERIZE_BITS = [1, 4, 8] - - -def sample_inputs_posterize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8] - ): - yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) - - -def reference_inputs_posterize_image_tensor(): - for image_loader, bits in itertools.product( - make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), - _POSTERIZE_BITS, - ): - yield ArgsKwargs(image_loader, bits=bits) - - -KERNEL_INFOS.append( - KernelInfo( - F.posterize_image_tensor, - kernel_name="posterize_image_tensor", - sample_inputs_fn=sample_inputs_posterize_image_tensor, - reference_fn=pil_reference_wrapper(F.posterize_image_pil), - reference_inputs_fn=reference_inputs_posterize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -def _get_solarize_thresholds(dtype): - for factor in [0.1, 0.5]: - max_value = get_max_value(dtype) - yield (float if dtype.is_floating_point else int)(max_value * factor) - - -def sample_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) - - -def reference_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()] - ): - for threshold in _get_solarize_thresholds(image_loader.dtype): - yield ArgsKwargs(image_loader, threshold=threshold) - - -KERNEL_INFOS.append( - KernelInfo( - F.solarize_image_tensor, - kernel_name="solarize_image_tensor", - 
sample_inputs_fn=sample_inputs_solarize_image_tensor, - reference_fn=pil_reference_wrapper(F.solarize_image_pil), - reference_inputs_fn=reference_inputs_solarize_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -def sample_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader) - - -def reference_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()] - ): - yield ArgsKwargs(image_loader) - - -KERNEL_INFOS.append( - KernelInfo( - F.autocontrast_image_tensor, - kernel_name="autocontrast_image_tensor", - sample_inputs_fn=sample_inputs_autocontrast_image_tensor, - reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), - reference_inputs_fn=reference_inputs_autocontrast_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - -_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5] - - -def sample_inputs_adjust_sharpness_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random", (2, 2)], - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), - # FIXME: kernel should support arbitrary batch sizes - extra_dims=[(), (4,)], - ): - yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) - - -def reference_inputs_adjust_sharpness_image_tensor(): - for image_loader, sharpness_factor in itertools.product( - make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), - _ADJUST_SHARPNESS_FACTORS, - ): - yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor) - - -KERNEL_INFOS.append( - KernelInfo( - F.adjust_sharpness_image_tensor, - kernel_name="adjust_sharpness_image_tensor", - sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), - reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -def sample_inputs_erase_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"]): - # FIXME: make the parameters more diverse - h, w = 6, 7 - v = torch.rand(image_loader.num_channels, h, w) - yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v) - - -KERNEL_INFOS.append( - KernelInfo( - F.erase_image_tensor, - kernel_name="erase_image_tensor", - sample_inputs_fn=sample_inputs_erase_image_tensor, - ) -) - -_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5] - - -def sample_inputs_adjust_brightness_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) - - -def reference_inputs_adjust_brightness_image_tensor(): - for image_loader, brightness_factor in itertools.product( - make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), - _ADJUST_BRIGHTNESS_FACTORS, - ): - yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) - - -KERNEL_INFOS.append( - KernelInfo( - F.adjust_brightness_image_tensor, - kernel_name="adjust_brightness_image_tensor", - sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil), - reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor, - 
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -_ADJUST_CONTRAST_FACTORS = [0.1, 0.5] - - -def sample_inputs_adjust_contrast_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) - - -def reference_inputs_adjust_contrast_image_tensor(): - for image_loader, contrast_factor in itertools.product( - make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), - _ADJUST_CONTRAST_FACTORS, - ): - yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) - - -KERNEL_INFOS.append( - KernelInfo( - F.adjust_contrast_image_tensor, - kernel_name="adjust_contrast_image_tensor", - sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), - reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - -_ADJUST_GAMMA_GAMMAS_GAINS = [ - (0.5, 2.0), - (0.0, 1.0), -] - - -def sample_inputs_adjust_gamma_image_tensor(): - gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) - - -def reference_inputs_adjust_gamma_image_tensor(): - for image_loader, (gamma, gain) in itertools.product( - make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), - _ADJUST_GAMMA_GAMMAS_GAINS, - ): - yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) - - -KERNEL_INFOS.append( - KernelInfo( - F.adjust_gamma_image_tensor, - kernel_name="adjust_gamma_image_tensor", - sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), - reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -_ADJUST_HUE_FACTORS = [-0.1, 0.5] - - -def sample_inputs_adjust_hue_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) - - -def reference_inputs_adjust_hue_image_tensor(): - for image_loader, hue_factor in itertools.product( - make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), - _ADJUST_HUE_FACTORS, - ): - yield ArgsKwargs(image_loader, hue_factor=hue_factor) - - -KERNEL_INFOS.append( - KernelInfo( - F.adjust_hue_image_tensor, - kernel_name="adjust_hue_image_tensor", - sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), - reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - -_ADJUST_SATURATION_FACTORS = [0.1, 0.5] - - -def sample_inputs_adjust_saturation_image_tensor(): - for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) - ): - yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) - - -def reference_inputs_adjust_saturation_image_tensor(): - for image_loader, saturation_factor in itertools.product( - make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), 
extra_dims=[()]), - _ADJUST_SATURATION_FACTORS, - ): - yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) - - -KERNEL_INFOS.append( - KernelInfo( - F.adjust_saturation_image_tensor, - kernel_name="adjust_saturation_image_tensor", - sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), - reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ) -) - - -def sample_inputs_clamp_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size - ) - - -KERNEL_INFOS.append( - KernelInfo( - F.clamp_bounding_box, - sample_inputs_fn=sample_inputs_clamp_bounding_box, - ) -) - -_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]] - - -def _get_five_ten_crop_image_size(size): - if isinstance(size, int): - crop_height = crop_width = size - elif len(size) == 1: - crop_height = crop_width = size[0] - else: - crop_height, crop_width = size - return 2 * crop_height, 2 * crop_width - - -def sample_inputs_five_crop_image_tensor(): - for size in _FIVE_TEN_CROP_SIZES: - for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)]): - yield ArgsKwargs(image_loader, size=size) - - -def reference_inputs_five_crop_image_tensor(): - for size in _FIVE_TEN_CROP_SIZES: - for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]): - yield ArgsKwargs(image_loader, size=size) - - -def sample_inputs_ten_crop_image_tensor(): - for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): - for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)]): - yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) - - -def reference_inputs_ten_crop_image_tensor(): - for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): - for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]): - yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.five_crop_image_tensor, - sample_inputs_fn=sample_inputs_five_crop_image_tensor, - reference_fn=pil_reference_wrapper(F.five_crop_image_pil), - reference_inputs_fn=reference_inputs_five_crop_image_tensor, - skips=[ - skip_integer_size_jit(), - Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."), - ], - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - KernelInfo( - F.ten_crop_image_tensor, - sample_inputs_fn=sample_inputs_ten_crop_image_tensor, - reference_fn=pil_reference_wrapper(F.ten_crop_image_pil), - reference_inputs_fn=reference_inputs_ten_crop_image_tensor, - skips=[ - skip_integer_size_jit(), - Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."), - ], - closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - ), - ] -) - -_NORMALIZE_MEANS_STDS = [ - ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), -] - - -def sample_inputs_normalize_image_tensor(): - for image_loader, (mean, std) in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), - _NORMALIZE_MEANS_STDS, - ): - yield ArgsKwargs(image_loader, mean=mean, std=std) - - -KERNEL_INFOS.append( - KernelInfo( - 
F.normalize_image_tensor, - kernel_name="normalize_image_tensor", - sample_inputs_fn=sample_inputs_normalize_image_tensor, - ) -) diff --git a/test/test_prototype_datasets_builtin.py b/test/test_prototype_datasets_builtin.py deleted file mode 100644 index 283a30a3d85..00000000000 --- a/test/test_prototype_datasets_builtin.py +++ /dev/null @@ -1,220 +0,0 @@ -import functools -import io -import pickle -from pathlib import Path - -import pytest -import torch -from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks -from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair -from torch.utils.data import DataLoader -from torch.utils.data.graph import traverse_dps -from torch.utils.data.graph_settings import get_all_graph_pipes -from torchdata.datapipes.iter import ShardingFilter, Shuffler -from torchvision._utils import sequence_to_str -from torchvision.prototype import datasets, transforms -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE -from torchvision.prototype.features import Image, Label - -assert_samples_equal = functools.partial( - assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True -) - - -def extract_datapipes(dp): - return get_all_graph_pipes(traverse_dps(dp)) - - -@pytest.fixture(autouse=True) -def test_home(mocker, tmp_path): - mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path)) - mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path)) - yield tmp_path - - -def test_coverage(): - untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys() - if untested_datasets: - raise AssertionError( - f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} " - f"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. " - f"Please add mock data to `test/builtin_dataset_mocks.py`." 
- ) - - -@pytest.mark.filterwarnings("error") -class TestCommon: - @pytest.mark.parametrize("name", datasets.list_datasets()) - def test_info(self, name): - try: - info = datasets.info(name) - except ValueError: - raise AssertionError("No info available.") from None - - if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())): - raise AssertionError("Info should be a dictionary with string keys.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_smoke(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - if not isinstance(dataset, datasets.utils.Dataset): - raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_sample(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - try: - sample = next(iter(dataset)) - except StopIteration: - raise AssertionError("Unable to draw any sample.") from None - except Exception as error: - raise AssertionError("Drawing a sample raised the error above.") from error - - if not isinstance(sample, dict): - raise AssertionError(f"Samples should be dictionaries, but got {type(sample)} instead.") - - if not sample: - raise AssertionError("Sample dictionary is empty.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_num_samples(self, dataset_mock, config): - dataset, mock_info = dataset_mock.load(config) - - assert len(list(dataset)) == mock_info["num_samples"] - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_no_vanilla_tensors(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} - if vanilla_tensors: - raise AssertionError( - f"The values of key(s) " - f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors." - ) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_transformable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - next(iter(dataset.map(transforms.Identity()))) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_traversable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - traverse_dps(dataset) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_serializable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - pickle.dumps(dataset) - - # This has to be a proper function, since lambda's or local functions - # cannot be pickled, but this is a requirement for the DataLoader with - # multiprocessing, i.e. num_workers > 0 - def _collate_fn(self, batch): - return batch - - @pytest.mark.parametrize("num_workers", [0, 1]) - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_data_loader(self, dataset_mock, config, num_workers): - dataset, _ = dataset_mock.load(config) - - dl = DataLoader( - dataset, - batch_size=2, - num_workers=num_workers, - collate_fn=self._collate_fn, - ) - - next(iter(dl)) - - # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also - # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680 - # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. 
- @parametrize_dataset_mocks(DATASET_MOCKS) - @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) - def test_has_annotations(self, dataset_mock, config, annotation_dp_type): - dataset, _ = dataset_mock.load(config) - - if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): - raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_save_load(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - sample = next(iter(dataset)) - - with io.BytesIO() as buffer: - torch.save(sample, buffer) - buffer.seek(0) - assert_samples_equal(torch.load(buffer), sample) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_infinite_buffer_size(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - for dp in extract_datapipes(dataset): - if hasattr(dp, "buffer_size"): - # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is - # resolved - assert dp.buffer_size == INFINITE_BUFFER_SIZE - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_has_length(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - assert len(dataset) > 0 - - -@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) -class TestQMNIST: - def test_extra_label(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - sample = next(iter(dataset)) - for key, type in ( - ("nist_hsf_series", int), - ("nist_writer_id", int), - ("digit_index", int), - ("nist_label", int), - ("global_digit_index", int), - ("duplicate", bool), - ("unused", bool), - ): - assert key in sample and isinstance(sample[key], type) - - -@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"]) -class TestGTSRB: - def test_label_matches_path(self, dataset_mock, config): - # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. 
- # This test makes sure that they're both the same - if config["split"] != "train": - return - - dataset, _ = dataset_mock.load(config) - - for sample in dataset: - label_from_path = int(Path(sample["path"]).parent.name) - assert sample["label"] == label_from_path - - -@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) -class TestUSPS: - def test_sample_content(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - for sample in dataset: - assert "image" in sample - assert "label" in sample - - assert isinstance(sample["image"], Image) - assert isinstance(sample["label"], Label) - - assert sample["image"].shape == (1, 16, 16) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py deleted file mode 100644 index 2098ac736ac..00000000000 --- a/test/test_prototype_datasets_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -import gzip -import pathlib -import sys - -import numpy as np -import pytest -import torch -from datasets_utils import make_fake_flo_file, make_tar -from torchdata.datapipes.iter import FileOpener, TarArchiveLoader -from torchvision.datasets._optical_flow import _read_flo as read_flo_ref -from torchvision.datasets.utils import _decompress -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import fromfile, read_flo - - -@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning") -@pytest.mark.parametrize( - ("np_dtype", "torch_dtype", "byte_order"), - [ - (">f4", torch.float32, "big"), - ("i8", torch.int64, "big"), - ("|u1", torch.uint8, sys.byteorder), - ], -) -@pytest.mark.parametrize("count", (-1, 2)) -@pytest.mark.parametrize("mode", ("rb", "r+b")) -def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode): - path = tmpdir / "data.bin" - rng = np.random.RandomState(0) - rng.randn(5 if count == -1 else count + 1).astype(np_dtype).tofile(path) - - for count_ in (-1, count // 2): - expected = torch.from_numpy(np.fromfile(path, dtype=np_dtype, count=count_).astype(np_dtype[1:])) - - with open(path, mode) as file: - actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_) - - torch.testing.assert_close(actual, expected) - - -def test_read_flo(tmpdir): - path = tmpdir / "test.flo" - make_fake_flo_file(3, 4, path) - - with open(path, "rb") as file: - actual = read_flo(file) - - expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False)) - - torch.testing.assert_close(actual, expected) - - -class TestOnlineResource: - class DummyResource(OnlineResource): - def __init__(self, download_fn=None, **kwargs): - super().__init__(**kwargs) - self._download_fn = download_fn - - def _download(self, root): - if self._download_fn is None: - raise pytest.UsageError( - "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`." 
- ) - - return self._download_fn(self, root) - - def _make_file(self, root, *, content, name="file.txt"): - file = root / name - with open(file, "w") as fh: - fh.write(content) - - return file - - def _make_folder(self, root, *, name="folder"): - folder = root / name - subfolder = folder / "subfolder" - subfolder.mkdir(parents=True) - - files = {} - for idx, root in enumerate([folder, folder, subfolder]): - content = f"sentinel{idx}" - file = self._make_file(root, name=f"file{idx}.txt", content=content) - files[str(file)] = content - - return folder, files - - def _make_tar(self, root, *, name="archive.tar", remove=True): - folder, files = self._make_folder(root, name=name.split(".")[0]) - archive = make_tar(root, name, folder, remove=remove) - files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()} - return archive, files - - def test_load_file(self, tmp_path): - content = "sentinel" - file = self._make_file(tmp_path, content=content) - - resource = self.DummyResource(file_name=file.name) - - dp = resource.load(tmp_path) - assert isinstance(dp, FileOpener) - - data = list(dp) - assert len(data) == 1 - - path, buffer = data[0] - assert path == str(file) - assert buffer.read().decode() == content - - def test_load_folder(self, tmp_path): - folder, files = self._make_folder(tmp_path) - - resource = self.DummyResource(file_name=folder.name) - - dp = resource.load(tmp_path) - assert isinstance(dp, FileOpener) - assert {path: buffer.read().decode() for path, buffer in dp} == files - - def test_load_archive(self, tmp_path): - archive, files = self._make_tar(tmp_path) - - resource = self.DummyResource(file_name=archive.name) - - dp = resource.load(tmp_path) - assert isinstance(dp, TarArchiveLoader) - assert {path: buffer.read().decode() for path, buffer in dp} == files - - def test_priority_decompressed_gt_raw(self, tmp_path): - # We don't need to actually compress here. 
Adding the suffix is sufficient - self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz") - file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt") - - resource = self.DummyResource(file_name=file.name) - - dp = resource.load(tmp_path) - path, buffer = next(iter(dp)) - - assert path == str(file) - assert buffer.read().decode() == "decompressed_sentinel" - - def test_priority_extracted_gt_decompressed(self, tmp_path): - archive, _ = self._make_tar(tmp_path, remove=False) - - resource = self.DummyResource(file_name=archive.name) - - dp = resource.load(tmp_path) - # If the archive had been selected, this would be a `TarArchiveReader` - assert isinstance(dp, FileOpener) - - def test_download(self, tmp_path): - download_fn_was_called = False - - def download_fn(resource, root): - nonlocal download_fn_was_called - download_fn_was_called = True - - return self._make_file(root, content="_", name=resource.file_name) - - resource = self.DummyResource( - file_name="file.txt", - download_fn=download_fn, - ) - - resource.load(tmp_path) - - assert download_fn_was_called, "`download_fn()` was never called" - - # This tests the `"decompress"` literal as well as a custom callable - @pytest.mark.parametrize( - "preprocess", - [ - "decompress", - lambda path: _decompress(str(path), remove_finished=True), - ], - ) - def test_preprocess_decompress(self, tmp_path, preprocess): - file_name = "file.txt.gz" - content = "sentinel" - - def download_fn(resource, root): - file = root / resource.file_name - with gzip.open(file, "wb") as fh: - fh.write(content.encode()) - return file - - resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn) - - dp = resource.load(tmp_path) - data = list(dp) - assert len(data) == 1 - - path, buffer = data[0] - assert path == str(tmp_path / file_name).replace(".gz", "") - assert buffer.read().decode() == content - - def test_preprocess_extract(self, tmp_path): - files = None - - def download_fn(resource, root): - nonlocal files - archive, files = self._make_tar(root, name=resource.file_name) - return archive - - resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn) - - dp = resource.load(tmp_path) - assert files is not None, "`download_fn()` was never called" - assert isinstance(dp, FileOpener) - - actual = {path: buffer.read().decode() for path, buffer in dp} - expected = { - path.replace(resource.file_name, resource.file_name.split(".")[0]): content - for path, content in files.items() - } - assert actual == expected - - def test_preprocess_only_after_download(self, tmp_path): - file = self._make_file(tmp_path, content="_") - - def preprocess(path): - raise AssertionError("`preprocess` was called although the file was already present.") - - resource = self.DummyResource( - file_name=file.name, - preprocess=preprocess, - ) - - resource.load(tmp_path) - - -class TestHttpResource: - def test_resolve_to_http(self, mocker): - file_name = "data.tar" - original_url = f"http://downloads.pytorch.org/{file_name}" - - redirected_url = original_url.replace("http", "https") - - sha256_sentinel = "sha256_sentinel" - - def preprocess_sentinel(path): - return path - - original_resource = HttpResource( - original_url, - sha256=sha256_sentinel, - preprocess=preprocess_sentinel, - ) - - mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url) - redirected_resource = original_resource.resolve() - - assert 
isinstance(redirected_resource, HttpResource) - assert redirected_resource.url == redirected_url - assert redirected_resource.file_name == file_name - assert redirected_resource.sha256 == sha256_sentinel - assert redirected_resource._preprocess is preprocess_sentinel - - def test_resolve_to_gdrive(self, mocker): - file_name = "data.tar" - original_url = f"http://downloads.pytorch.org/{file_name}" - - id_sentinel = "id-sentinel" - redirected_url = f"https://drive.google.com/file/d/{id_sentinel}/view" - - sha256_sentinel = "sha256_sentinel" - - def preprocess_sentinel(path): - return path - - original_resource = HttpResource( - original_url, - sha256=sha256_sentinel, - preprocess=preprocess_sentinel, - ) - - mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url) - redirected_resource = original_resource.resolve() - - assert isinstance(redirected_resource, GDriveResource) - assert redirected_resource.id == id_sentinel - assert redirected_resource.file_name == file_name - assert redirected_resource.sha256 == sha256_sentinel - assert redirected_resource._preprocess is preprocess_sentinel - - -def test_missing_dependency_error(): - class DummyDataset(Dataset): - def __init__(self): - super().__init__(root="root", dependencies=("fake_dependency",)) - - def _resources(self): - pass - - def _datapipe(self, resource_dps): - pass - - def __len__(self): - pass - - with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"): - DummyDataset() diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py deleted file mode 100644 index 2701dd66be0..00000000000 --- a/test/test_prototype_features.py +++ /dev/null @@ -1,113 +0,0 @@ -import pytest -import torch -from torchvision.prototype import features - - -def test_isinstance(): - assert isinstance( - features.Label([0, 1, 0], categories=["foo", "bar"]), - torch.Tensor, - ) - - -def test_wrapping_no_copy(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) - - assert label.data_ptr() == tensor.data_ptr() - - -def test_to_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) - - label_to = label.to(torch.int32) - - assert type(label_to) is features.Label - assert label_to.dtype is torch.int32 - assert label_to.categories is label.categories - - -def test_to_feature_reference(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) - - tensor_to = tensor.to(label) - - assert type(tensor_to) is torch.Tensor - assert tensor_to.dtype is torch.int32 - - -def test_clone_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) - - label_clone = label.clone() - - assert type(label_clone) is features.Label - assert label_clone.data_ptr() != label.data_ptr() - assert label_clone.categories is label.categories - - -def test_requires_grad__wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.float32) - label = features.Label(tensor, categories=["foo", "bar"]) - - assert not label.requires_grad - - label_requires_grad = label.requires_grad_(True) - - assert type(label_requires_grad) is features.Label - assert label.requires_grad - assert label_requires_grad.requires_grad - - -def test_other_op_no_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = 
features.Label(tensor, categories=["foo", "bar"]) - - # any operation besides .to() and .clone() will do here - output = label * 2 - - assert type(output) is torch.Tensor - - -@pytest.mark.parametrize( - "op", - [ - lambda t: t.numpy(), - lambda t: t.tolist(), - lambda t: t.max(dim=-1), - ], -) -def test_no_tensor_output_op_no_wrapping(op): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) - - output = op(label) - - assert type(output) is not features.Label - - -def test_inplace_op_no_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) - - output = label.add_(0) - - assert type(output) is torch.Tensor - assert type(label) is features.Label - - -def test_new_like(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) - - # any operation besides .to() and .clone() will do here - output = label * 2 - - label_new = features.Label.new_like(label, output) - - assert type(label_new) is features.Label - assert label_new.data_ptr() == output.data_ptr() - assert label_new.categories is label.categories diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py deleted file mode 100644 index 6d9f22c1543..00000000000 --- a/test/test_prototype_models.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest -import test_models as TM -import torch -from common_utils import cpu_and_gpu, set_rng_seed -from torchvision.prototype import models - - -@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,)) -@pytest.mark.parametrize("model_mode", ("standard", "scripted")) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -def test_raft_stereo(model_fn, model_mode, dev): - # A simple test to make sure the model can do forward pass and jit scriptable - set_rng_seed(0) - - # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output - # get the idea from test_models.test_raft - corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2) - corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2) - model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev) - - if model_mode == "scripted": - model = torch.jit.script(model) - - img1 = torch.rand(1, 3, 64, 64).to(dev) - img2 = torch.rand(1, 3, 64, 64).to(dev) - num_iters = 3 - - preds = model(img1, img2, num_iters=num_iters) - depth_pred = preds[-1] - - assert len(preds) == num_iters, "Number of predictions should be the same as model.num_iters" - - assert depth_pred.shape == torch.Size( - [1, 1, 64, 64] - ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}" - - # Test against expected file output - TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2) - - -@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,)) -@pytest.mark.parametrize("model_mode", ("standard", "scripted")) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -def test_crestereo(model_fn, model_mode, dev): - set_rng_seed(0) - - model = model_fn().eval().to(dev) - - if model_mode == "scripted": - model = torch.jit.script(model) - - img1 = torch.rand(1, 3, 64, 64).to(dev) - img2 = torch.rand(1, 3, 64, 64).to(dev) - iterations = 3 - - preds = model(img1, img2, flow_init=None, num_iters=iterations) - disparity_pred = preds[-1] - - # all the pyramid levels except the highest res make only half the number of 
iterations - expected_iterations = (iterations // 2) * (len(model.resolutions) - 1) - expected_iterations += iterations - assert ( - len(preds) == expected_iterations - ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels" - - assert disparity_pred.shape == torch.Size( - [1, 2, 64, 64] - ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}" - - assert all( - d.shape == torch.Size([1, 2, 64, 64]) for d in preds - ), "All predicted disparities are expected to have the same shape" - - # test a backward pass with a dummy loss as well - preds = torch.stack(preds, dim=0) - targets = torch.ones_like(preds, requires_grad=False) - loss = torch.nn.functional.mse_loss(preds, targets) - - try: - loss.backward() - except Exception as e: - assert False, f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}" - - TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py deleted file mode 100644 index 9734a5dc30a..00000000000 --- a/test/test_prototype_transforms.py +++ /dev/null @@ -1,1780 +0,0 @@ -import itertools - -import numpy as np - -import PIL.Image - -import pytest -import torch -from common_utils import assert_equal, cpu_and_gpu -from prototype_common_utils import ( - make_bounding_box, - make_bounding_boxes, - make_detection_mask, - make_image, - make_images, - make_label, - make_masks, - make_one_hot_labels, - make_segmentation_mask, -) -from torchvision.ops.boxes import box_iou -from torchvision.prototype import features, transforms -from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image - - -def make_vanilla_tensor_images(*args, **kwargs): - for image in make_images(*args, **kwargs): - if image.ndim > 3: - continue - yield image.data - - -def make_pil_images(*args, **kwargs): - for image in make_vanilla_tensor_images(*args, **kwargs): - yield to_pil_image(image) - - -def make_vanilla_tensor_bounding_boxes(*args, **kwargs): - for bounding_box in make_bounding_boxes(*args, **kwargs): - yield bounding_box.data - - -def parametrize(transforms_with_inputs): - return pytest.mark.parametrize( - ("transform", "input"), - [ - pytest.param( - transform, - input, - id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", - ) - for transform, inputs in transforms_with_inputs - for idx, input in enumerate(inputs) - ], - ) - - -def parametrize_from_transforms(*transforms): - transforms_with_inputs = [] - for transform in transforms: - for creation_fn in [ - make_images, - make_bounding_boxes, - make_one_hot_labels, - make_vanilla_tensor_images, - make_pil_images, - make_masks, - ]: - inputs = list(creation_fn()) - try: - output = transform(inputs[0]) - except Exception: - continue - else: - if output is inputs[0]: - continue - - transforms_with_inputs.append((transform, inputs)) - - return parametrize(transforms_with_inputs) - - -class TestSmoke: - @parametrize_from_transforms( - transforms.RandomErasing(p=1.0), - transforms.Resize([16, 16]), - transforms.CenterCrop([16, 16]), - transforms.ConvertImageDtype(), - transforms.RandomHorizontalFlip(), - transforms.Pad(5), - transforms.RandomZoomOut(), - transforms.RandomRotation(degrees=(-45, 45)), - transforms.RandomAffine(degrees=(-45, 45)), - transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True), - # TODO: 
Something wrong with input data setup. Let's fix that - # transforms.RandomEqualize(), - # transforms.RandomInvert(), - # transforms.RandomPosterize(bits=4), - # transforms.RandomSolarize(threshold=0.5), - # transforms.RandomAdjustSharpness(sharpness_factor=0.5), - ) - def test_common(self, transform, input): - transform(input) - - @parametrize( - [ - ( - transform, - [ - dict( - image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float), - one_hot_label=features.OneHotLabel.new_like( - one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float - ), - ) - for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels()) - ], - ) - for transform in [ - transforms.RandomMixup(alpha=1.0), - transforms.RandomCutmix(alpha=1.0), - ] - ] - ) - def test_mixup_cutmix(self, transform, input): - transform(input) - - # add other data that should bypass and wont raise any error - input_copy = dict(input) - input_copy["path"] = "/path/to/somewhere" - input_copy["num"] = 1234 - transform(input_copy) - - # Check if we raise an error if sample contains bbox or mask or label - err_msg = "does not support PIL images, bounding boxes, masks and plain labels" - input_copy = dict(input) - for unsup_data in [ - make_label(), - make_bounding_box(format="XYXY"), - make_detection_mask(), - make_segmentation_mask(), - ]: - input_copy["unsupported"] = unsup_data - with pytest.raises(TypeError, match=err_msg): - transform(input_copy) - - @parametrize( - [ - ( - transform, - itertools.chain.from_iterable( - fn( - color_spaces=[ - features.ColorSpace.GRAY, - features.ColorSpace.RGB, - ], - dtypes=[torch.uint8], - extra_dims=[(4,)], - ) - for fn in [ - make_images, - make_vanilla_tensor_images, - make_pil_images, - ] - ), - ) - for transform in ( - transforms.RandAugment(), - transforms.TrivialAugmentWide(), - transforms.AutoAugment(), - transforms.AugMix(), - ) - ] - ) - def test_auto_augment(self, transform, input): - transform(input) - - @parametrize( - [ - ( - transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), - itertools.chain.from_iterable( - fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]) - for fn in [ - make_images, - make_vanilla_tensor_images, - ] - ), - ), - ] - ) - def test_normalize(self, transform, input): - transform(input) - - @parametrize( - [ - ( - transforms.RandomResizedCrop([16, 16]), - itertools.chain( - make_images(extra_dims=[(4,)]), - make_vanilla_tensor_images(), - make_pil_images(), - ), - ) - ] - ) - def test_random_resized_crop(self, transform, input): - transform(input) - - @parametrize( - [ - ( - transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space), - itertools.chain.from_iterable( - [ - fn(color_spaces=[old_color_space]) - for fn in ( - make_images, - make_vanilla_tensor_images, - make_pil_images, - ) - ] - ), - ) - for old_color_space, new_color_space in itertools.product( - [ - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, - ], - repeat=2, - ) - ] - ) - def test_convert_color_space(self, transform, input): - transform(input) - - def test_convert_color_space_unsupported_types(self): - transform = transforms.ConvertColorSpace( - color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY - ) - - for inpt in [make_bounding_box(format="XYXY"), make_masks()]: - output = transform(inpt) - assert output is inpt - - -@pytest.mark.parametrize("p", [0.0, 1.0]) -class TestRandomHorizontalFlip: - def 
input_expected_image_tensor(self, p, dtype=torch.float32): - input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype) - expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype) - - return input, expected if p == 1 else input - - def test_simple_tensor(self, p): - input, expected = self.input_expected_image_tensor(p) - transform = transforms.RandomHorizontalFlip(p=p) - - actual = transform(input) - - assert_equal(expected, actual) - - def test_pil_image(self, p): - input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8) - transform = transforms.RandomHorizontalFlip(p=p) - - actual = transform(to_pil_image(input)) - - assert_equal(expected, pil_to_tensor(actual)) - - def test_features_image(self, p): - input, expected = self.input_expected_image_tensor(p) - transform = transforms.RandomHorizontalFlip(p=p) - - actual = transform(features.Image(input)) - - assert_equal(features.Image(expected), actual) - - def test_features_mask(self, p): - input, expected = self.input_expected_image_tensor(p) - transform = transforms.RandomHorizontalFlip(p=p) - - actual = transform(features.Mask(input)) - - assert_equal(features.Mask(expected), actual) - - def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) - transform = transforms.RandomHorizontalFlip(p=p) - - actual = transform(input) - - expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) - assert_equal(expected, actual) - assert actual.format == expected.format - assert actual.image_size == expected.image_size - - -@pytest.mark.parametrize("p", [0.0, 1.0]) -class TestRandomVerticalFlip: - def input_expected_image_tensor(self, p, dtype=torch.float32): - input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype) - expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype) - - return input, expected if p == 1 else input - - def test_simple_tensor(self, p): - input, expected = self.input_expected_image_tensor(p) - transform = transforms.RandomVerticalFlip(p=p) - - actual = transform(input) - - assert_equal(expected, actual) - - def test_pil_image(self, p): - input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8) - transform = transforms.RandomVerticalFlip(p=p) - - actual = transform(to_pil_image(input)) - - assert_equal(expected, pil_to_tensor(actual)) - - def test_features_image(self, p): - input, expected = self.input_expected_image_tensor(p) - transform = transforms.RandomVerticalFlip(p=p) - - actual = transform(features.Image(input)) - - assert_equal(features.Image(expected), actual) - - def test_features_mask(self, p): - input, expected = self.input_expected_image_tensor(p) - transform = transforms.RandomVerticalFlip(p=p) - - actual = transform(features.Mask(input)) - - assert_equal(features.Mask(expected), actual) - - def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) - transform = transforms.RandomVerticalFlip(p=p) - - actual = transform(input) - - expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) - assert_equal(expected, actual) - assert actual.format == expected.format - assert actual.image_size == expected.image_size - - -class TestPad: - def test_assertions(self): - 
with pytest.raises(TypeError, match="Got inappropriate padding arg"): - transforms.Pad("abc") - - with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): - transforms.Pad([-0.7, 0, 0.7]) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.Pad(12, fill="abc") - - with pytest.raises(ValueError, match="Padding mode should be either"): - transforms.Pad(12, padding_mode="abc") - - @pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]]) - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("padding_mode", ["constant", "edge"]) - def test__transform(self, padding, fill, padding_mode, mocker): - transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) - - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - inpt = mocker.MagicMock(spec=features.Image) - _ = transform(inpt) - - fill = transforms.functional._geometry._convert_fill_arg(fill) - if isinstance(padding, tuple): - padding = list(padding) - fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - - @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) - def test__transform_image_mask(self, fill, mocker): - transform = transforms.Pad(1, fill=fill, padding_mode="constant") - - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - image = features.Image(torch.rand(3, 32, 32)) - mask = features.Mask(torch.randint(0, 5, size=(32, 32))) - inpt = [image, mask] - _ = transform(inpt) - - if isinstance(fill, int): - fill = transforms.functional._geometry._convert_fill_arg(fill) - calls = [ - mocker.call(image, padding=1, fill=fill, padding_mode="constant"), - mocker.call(mask, padding=1, fill=fill, padding_mode="constant"), - ] - else: - fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)]) - fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)]) - calls = [ - mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"), - mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"), - ] - fn.assert_has_calls(calls) - - -class TestRandomZoomOut: - def test_assertions(self): - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomZoomOut(fill="abc") - - with pytest.raises(TypeError, match="should be a sequence of length"): - transforms.RandomZoomOut(0, side_range=0) - - with pytest.raises(ValueError, match="Invalid canvas side range"): - transforms.RandomZoomOut(0, side_range=[4.0, 1.0]) - - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) - def test__get_params(self, fill, side_range, mocker): - transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) - - image = mocker.MagicMock(spec=features.Image) - h, w = image.image_size = (24, 32) - - params = transform._get_params(image) - - assert len(params["padding"]) == 4 - assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w - assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h - assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w - assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h - - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) - def test__transform(self, fill, side_range, mocker): - inpt = mocker.MagicMock(spec=features.Image) - inpt.num_channels = 3 - inpt.image_size = (24, 32) - - transform = 
transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) - - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - torch.rand(1) # random apply changes random state - params = transform._get_params(inpt) - - fill = transforms.functional._geometry._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, fill=fill) - - @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) - def test__transform_image_mask(self, fill, mocker): - transform = transforms.RandomZoomOut(fill=fill, p=1.0) - - fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - image = features.Image(torch.rand(3, 32, 32)) - mask = features.Mask(torch.randint(0, 5, size=(32, 32))) - inpt = [image, mask] - - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - torch.rand(1) # random apply changes random state - params = transform._get_params(inpt) - - if isinstance(fill, int): - fill = transforms.functional._geometry._convert_fill_arg(fill) - calls = [ - mocker.call(image, **params, fill=fill), - mocker.call(mask, **params, fill=fill), - ] - else: - fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)]) - fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)]) - calls = [ - mocker.call(image, **params, fill=fill_img), - mocker.call(mask, **params, fill=fill_mask), - ] - fn.assert_has_calls(calls) - - -class TestRandomRotation: - def test_assertions(self): - with pytest.raises(ValueError, match="is a single number, it must be positive"): - transforms.RandomRotation(-0.7) - - for d in [[-0.7], [-0.7, 0, 0.7]]: - with pytest.raises(ValueError, match="degrees should be a sequence of length 2"): - transforms.RandomRotation(d) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomRotation(12, fill="abc") - - with pytest.raises(TypeError, match="center should be a sequence of length"): - transforms.RandomRotation(12, center=12) - - with pytest.raises(ValueError, match="center should be a sequence of length"): - transforms.RandomRotation(12, center=[1, 2, 3]) - - def test__get_params(self): - angle_bound = 34 - transform = transforms.RandomRotation(angle_bound) - - params = transform._get_params(None) - assert -angle_bound <= params["angle"] <= angle_bound - - angle_bounds = [12, 34] - transform = transforms.RandomRotation(angle_bounds) - - params = transform._get_params(None) - assert angle_bounds[0] <= params["angle"] <= angle_bounds[1] - - @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) - @pytest.mark.parametrize("expand", [False, True]) - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("center", [None, [2.0, 3.0]]) - def test__transform(self, degrees, expand, fill, center, mocker): - interpolation = InterpolationMode.BILINEAR - transform = transforms.RandomRotation( - degrees, interpolation=interpolation, expand=expand, fill=fill, center=center - ) - - if isinstance(degrees, (tuple, list)): - assert transform.degrees == [float(degrees[0]), float(degrees[1])] - else: - assert transform.degrees == [float(-degrees), float(degrees)] - - fn = mocker.patch("torchvision.prototype.transforms.functional.rotate") - inpt = mocker.MagicMock(spec=features.Image) - # vfdev-5, Feature Request: 
let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params(inpt) - - fill = transforms.functional._geometry._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center) - - @pytest.mark.parametrize("angle", [34, -87]) - @pytest.mark.parametrize("expand", [False, True]) - def test_boundingbox_image_size(self, angle, expand): - # Specific test for BoundingBox.rotate - bbox = features.BoundingBox( - torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 32) - ) - img = features.Image(torch.rand(1, 3, 32, 32)) - - out_img = img.rotate(angle, expand=expand) - out_bbox = bbox.rotate(angle, expand=expand) - - assert out_img.image_size == out_bbox.image_size - - -class TestRandomAffine: - def test_assertions(self): - with pytest.raises(ValueError, match="is a single number, it must be positive"): - transforms.RandomAffine(-0.7) - - for d in [[-0.7], [-0.7, 0, 0.7]]: - with pytest.raises(ValueError, match="degrees should be a sequence of length 2"): - transforms.RandomAffine(d) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomAffine(12, fill="abc") - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomAffine(12, fill="abc") - - for kwargs in [ - {"center": 12}, - {"translate": 12}, - {"scale": 12}, - ]: - with pytest.raises(TypeError, match="should be a sequence of length"): - transforms.RandomAffine(12, **kwargs) - - for kwargs in [{"center": [1, 2, 3]}, {"translate": [1, 2, 3]}, {"scale": [1, 2, 3]}]: - with pytest.raises(ValueError, match="should be a sequence of length"): - transforms.RandomAffine(12, **kwargs) - - with pytest.raises(ValueError, match="translation values should be between 0 and 1"): - transforms.RandomAffine(12, translate=[-1.0, 2.0]) - - with pytest.raises(ValueError, match="scale values should be positive"): - transforms.RandomAffine(12, scale=[-1.0, 2.0]) - - with pytest.raises(ValueError, match="is a single number, it must be positive"): - transforms.RandomAffine(12, shear=-10) - - for s in [[-0.7], [-0.7, 0, 0.7]]: - with pytest.raises(ValueError, match="shear should be a sequence of length 2"): - transforms.RandomAffine(12, shear=s) - - @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) - @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) - @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) - @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) - def test__get_params(self, degrees, translate, scale, shear, mocker): - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - h, w = image.image_size - - transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear) - params = transform._get_params(image) - - if not isinstance(degrees, (list, tuple)): - assert -degrees <= params["angle"] <= degrees - else: - assert degrees[0] <= params["angle"] <= degrees[1] - - if translate is not None: - w_max = int(round(translate[0] * w)) - h_max = int(round(translate[1] * h)) - assert -w_max <= params["translate"][0] <= w_max - assert -h_max <= params["translate"][1] <= h_max - else: - assert params["translate"] == (0, 0) - - if scale is not None: - assert scale[0] <= params["scale"] <= scale[1] - else: - assert 
params["scale"] == 1.0 - - if shear is not None: - if isinstance(shear, float): - assert -shear <= params["shear"][0] <= shear - assert params["shear"][1] == 0.0 - elif len(shear) == 2: - assert shear[0] <= params["shear"][0] <= shear[1] - assert params["shear"][1] == 0.0 - else: - assert shear[0] <= params["shear"][0] <= shear[1] - assert shear[2] <= params["shear"][1] <= shear[3] - else: - assert params["shear"] == (0, 0) - - @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) - @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) - @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) - @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("center", [None, [2.0, 3.0]]) - def test__transform(self, degrees, translate, scale, shear, fill, center, mocker): - interpolation = InterpolationMode.BILINEAR - transform = transforms.RandomAffine( - degrees, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - - if isinstance(degrees, (tuple, list)): - assert transform.degrees == [float(degrees[0]), float(degrees[1])] - else: - assert transform.degrees == [float(-degrees), float(degrees)] - - fn = mocker.patch("torchvision.prototype.transforms.functional.affine") - inpt = mocker.MagicMock(spec=features.Image) - inpt.num_channels = 3 - inpt.image_size = (24, 32) - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params(inpt) - - fill = transforms.functional._geometry._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center) - - -class TestRandomCrop: - def test_assertions(self): - with pytest.raises(ValueError, match="Please provide only two dimensions"): - transforms.RandomCrop([10, 12, 14]) - - with pytest.raises(TypeError, match="Got inappropriate padding arg"): - transforms.RandomCrop([10, 12], padding="abc") - - with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): - transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7]) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomCrop([10, 12], padding=1, fill="abc") - - with pytest.raises(ValueError, match="Padding mode should be either"): - transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") - - @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) - @pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)]) - def test__get_params(self, padding, pad_if_needed, size, mocker): - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - h, w = image.image_size - - transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed) - params = transform._get_params(image) - - if padding is not None: - if isinstance(padding, int): - pad_top = pad_bottom = pad_left = pad_right = padding - elif isinstance(padding, list) and len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - elif isinstance(padding, list) and len(padding) == 4: - pad_left, pad_top, pad_right, pad_bottom = padding - - h += pad_top + pad_bottom - w += pad_left + pad_right - else: - pad_left = pad_right = pad_top 
= pad_bottom = 0 - - if pad_if_needed: - if w < size[1]: - diff = size[1] - w - pad_left += diff - pad_right += diff - w += 2 * diff - if h < size[0]: - diff = size[0] - h - pad_top += diff - pad_bottom += diff - h += 2 * diff - - padding = [pad_left, pad_top, pad_right, pad_bottom] - - assert 0 <= params["top"] <= h - size[0] + 1 - assert 0 <= params["left"] <= w - size[1] + 1 - assert params["height"] == size[0] - assert params["width"] == size[1] - assert params["needs_pad"] is any(padding) - assert params["padding"] == padding - - @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) - @pytest.mark.parametrize("pad_if_needed", [False, True]) - @pytest.mark.parametrize("fill", [False, True]) - @pytest.mark.parametrize("padding_mode", ["constant", "edge"]) - def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker): - output_size = [10, 12] - transform = transforms.RandomCrop( - output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode - ) - - inpt = mocker.MagicMock(spec=features.Image) - inpt.num_channels = 3 - inpt.image_size = (32, 32) - - expected = mocker.MagicMock(spec=features.Image) - expected.num_channels = 3 - if isinstance(padding, int): - expected.image_size = (inpt.image_size[0] + padding, inpt.image_size[1] + padding) - elif isinstance(padding, list): - expected.image_size = ( - inpt.image_size[0] + sum(padding[0::2]), - inpt.image_size[1] + sum(padding[1::2]), - ) - else: - expected.image_size = inpt.image_size - _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected) - fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop") - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params(inpt) - if padding is None and not pad_if_needed: - fn_crop.assert_called_once_with( - inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1] - ) - elif not pad_if_needed: - fn_crop.assert_called_once_with( - expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1] - ) - elif padding is None: - # vfdev-5: I do not know how to mock and test this case - pass - else: - # vfdev-5: I do not know how to mock and test this case - pass - - -class TestGaussianBlur: - def test_assertions(self): - with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"): - transforms.GaussianBlur([10, 12, 14]) - - with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"): - transforms.GaussianBlur(4) - - with pytest.raises( - TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats." 
- ): - transforms.GaussianBlur(3, sigma=[1, 2, 3]) - - with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"): - transforms.GaussianBlur(3, sigma=-1.0) - - with pytest.raises(ValueError, match="sigma values should be positive and of the form"): - transforms.GaussianBlur(3, sigma=[2.0, 1.0]) - - @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]]) - def test__get_params(self, sigma): - transform = transforms.GaussianBlur(3, sigma=sigma) - params = transform._get_params(None) - - if isinstance(sigma, float): - assert params["sigma"][0] == params["sigma"][1] == 10 - else: - assert sigma[0] <= params["sigma"][0] <= sigma[1] - assert sigma[0] <= params["sigma"][1] <= sigma[1] - - @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)]) - @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]]) - def test__transform(self, kernel_size, sigma, mocker): - transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma) - - if isinstance(kernel_size, (tuple, list)): - assert transform.kernel_size == kernel_size - else: - kernel_size = (kernel_size, kernel_size) - assert transform.kernel_size == kernel_size - - if isinstance(sigma, (tuple, list)): - assert transform.sigma == sigma - else: - assert transform.sigma == [sigma, sigma] - - fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") - inpt = mocker.MagicMock(spec=features.Image) - inpt.num_channels = 3 - inpt.image_size = (24, 32) - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params(inpt) - - fn.assert_called_once_with(inpt, kernel_size, **params) - - -class TestRandomColorOp: - @pytest.mark.parametrize("p", [0.0, 1.0]) - @pytest.mark.parametrize( - "transform_cls, func_op_name, kwargs", - [ - (transforms.RandomEqualize, "equalize", {}), - (transforms.RandomInvert, "invert", {}), - (transforms.RandomAutocontrast, "autocontrast", {}), - (transforms.RandomPosterize, "posterize", {"bits": 4}), - (transforms.RandomSolarize, "solarize", {"threshold": 0.5}), - (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}), - ], - ) - def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): - transform = transform_cls(p=p, **kwargs) - - fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}") - inpt = mocker.MagicMock(spec=features.Image) - _ = transform(inpt) - if p > 0.0: - fn.assert_called_once_with(inpt, **kwargs) - else: - assert fn.call_count == 0 - - -class TestRandomPerspective: - def test_assertions(self): - with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"): - transforms.RandomPerspective(distortion_scale=-1.0) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomPerspective(0.5, fill="abc") - - def test__get_params(self, mocker): - dscale = 0.5 - transform = transforms.RandomPerspective(dscale) - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - - params = transform._get_params(image) - - h, w = image.image_size - assert "perspective_coeffs" in params - assert len(params["perspective_coeffs"]) == 8 - - @pytest.mark.parametrize("distortion_scale", [0.1, 0.7]) - def test__transform(self, distortion_scale, mocker): - interpolation = InterpolationMode.BILINEAR - fill = 12 
- transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) - - fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") - inpt = mocker.MagicMock(spec=features.Image) - inpt.num_channels = 3 - inpt.image_size = (24, 32) - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - torch.rand(1) # random apply changes random state - params = transform._get_params(inpt) - - fill = transforms.functional._geometry._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) - - -class TestElasticTransform: - def test_assertions(self): - - with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"): - transforms.ElasticTransform({}) - - with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"): - transforms.ElasticTransform([1.0, 2.0, 3.0]) - - with pytest.raises(ValueError, match="alpha should be a sequence of floats"): - transforms.ElasticTransform([1, 2]) - - with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"): - transforms.ElasticTransform(1.0, {}) - - with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"): - transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0]) - - with pytest.raises(ValueError, match="sigma should be a sequence of floats"): - transforms.ElasticTransform(1.0, [1, 2]) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.ElasticTransform(1.0, 2.0, fill="abc") - - def test__get_params(self, mocker): - alpha = 2.0 - sigma = 3.0 - transform = transforms.ElasticTransform(alpha, sigma) - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - - params = transform._get_params(image) - - h, w = image.image_size - displacement = params["displacement"] - assert displacement.shape == (1, h, w, 2) - assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() - assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all() - - @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]]) - @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]]) - def test__transform(self, alpha, sigma, mocker): - interpolation = InterpolationMode.BILINEAR - fill = 12 - transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation) - - if isinstance(alpha, float): - assert transform.alpha == [alpha, alpha] - else: - assert transform.alpha == alpha - - if isinstance(sigma, float): - assert transform.sigma == [sigma, sigma] - else: - assert transform.sigma == sigma - - fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") - inpt = mocker.MagicMock(spec=features.Image) - inpt.num_channels = 3 - inpt.image_size = (24, 32) - - # Let's mock transform._get_params to control the output: - transform._get_params = mocker.MagicMock() - _ = transform(inpt) - params = transform._get_params(inpt) - fill = transforms.functional._geometry._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) - - -class TestRandomErasing: - def test_assertions(self, mocker): - with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"): - 
transforms.RandomErasing(value={}) - - with pytest.raises(ValueError, match="If value is str, it should be 'random'"): - transforms.RandomErasing(value="abc") - - with pytest.raises(TypeError, match="Scale should be a sequence"): - transforms.RandomErasing(scale=123) - - with pytest.raises(TypeError, match="Ratio should be a sequence"): - transforms.RandomErasing(ratio=123) - - with pytest.raises(ValueError, match="Scale should be between 0 and 1"): - transforms.RandomErasing(scale=[-1, 2]) - - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - - transform = transforms.RandomErasing(value=[1, 2, 3, 4]) - - with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"): - transform._get_params(image) - - @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"]) - def test__get_params(self, value, mocker): - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - - transform = transforms.RandomErasing(value=value) - params = transform._get_params(image) - - v = params["v"] - h, w = params["h"], params["w"] - i, j = params["i"], params["j"] - assert isinstance(v, torch.Tensor) - if value == "random": - assert v.shape == (image.num_channels, h, w) - elif isinstance(value, (int, float)): - assert v.shape == (1, 1, 1) - elif isinstance(value, (list, tuple)): - assert v.shape == (image.num_channels, 1, 1) - - assert 0 <= i <= image.image_size[0] - h - assert 0 <= j <= image.image_size[1] - w - - @pytest.mark.parametrize("p", [0, 1]) - def test__transform(self, mocker, p): - transform = transforms.RandomErasing(p=p) - transform._transformed_types = (mocker.MagicMock,) - - i_sentinel = mocker.MagicMock() - j_sentinel = mocker.MagicMock() - h_sentinel = mocker.MagicMock() - w_sentinel = mocker.MagicMock() - v_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._augment.RandomErasing._get_params", - return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase") - output = transform(inpt_sentinel) - - if p: - mock.assert_called_once_with( - inpt_sentinel, - i=i_sentinel, - j=j_sentinel, - h=h_sentinel, - w=w_sentinel, - v=v_sentinel, - inplace=transform.inplace, - ) - else: - mock.assert_not_called() - assert output is inpt_sentinel - - -class TestTransform: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], - ) - def test_check_transformed_types(self, inpt_type, mocker): - # This test ensures that we correctly handle which types to transform and which to bypass - t = transforms.Transform() - inpt = mocker.MagicMock(spec=inpt_type) - - if inpt_type in (np.ndarray, str, int): - output = t(inpt) - assert output is inpt - else: - with pytest.raises(NotImplementedError): - t(inpt) - - -class TestToImageTensor: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch( - "torchvision.prototype.transforms.functional.to_image_tensor", - return_value=torch.rand(1, 3, 8, 8), - ) - - inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToImageTensor() - transform(inpt) - if inpt_type in (features.BoundingBox, features.Image, str, int): - assert 
fn.call_count == 0 - else: - fn.assert_called_once_with(inpt) - - -class TestToImagePIL: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") - - inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToImagePIL() - transform(inpt) - if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt, mode=transform.mode) - - -class TestToPILImage: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") - - inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToPILImage() - transform(inpt) - if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt, mode=transform.mode) - - -class TestToTensor: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.transforms.functional.to_tensor") - - inpt = mocker.MagicMock(spec=inpt_type) - with pytest.warns(UserWarning, match="deprecated and will be removed"): - transform = transforms.ToTensor() - transform(inpt) - if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt) - - -class TestContainers: - @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]) - def test_assertions(self, transform_cls): - with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"): - transform_cls(transforms.RandomCrop(28)) - - @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]) - @pytest.mark.parametrize( - "trfms", - [ - [transforms.Pad(2), transforms.RandomCrop(28)], - [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)], - ], - ) - def test_ctor(self, transform_cls, trfms): - c = transform_cls(trfms) - inpt = torch.rand(1, 3, 32, 32) - output = c(inpt) - assert isinstance(output, torch.Tensor) - assert output.ndim == 4 - - -class TestRandomChoice: - def test_assertions(self): - with pytest.warns(UserWarning, match="Argument p is deprecated and will be removed"): - transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1, 2]) - - with pytest.raises(ValueError, match="The number of probabilities doesn't match the number of transforms"): - transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], probabilities=[1]) - - -class TestRandomIoUCrop: - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) - def test__get_params(self, device, options, mocker): - image = mocker.MagicMock(spec=features.Image) - image.num_channels = 3 - image.image_size = (24, 32) - bboxes = features.BoundingBox( - torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), - format="XYXY", - image_size=image.image_size, - device=device, - ) - sample = [image, bboxes] - - transform 
= transforms.RandomIoUCrop(sampler_options=options) - - n_samples = 5 - for _ in range(n_samples): - - params = transform._get_params(sample) - - if options == [2.0]: - assert len(params) == 0 - return - - assert len(params["is_within_crop_area"]) > 0 - assert params["is_within_crop_area"].dtype == torch.bool - - orig_h = image.image_size[0] - orig_w = image.image_size[1] - assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h) - assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w) - - left, top = params["left"], params["top"] - new_h, new_w = params["height"], params["width"] - ious = box_iou( - bboxes, - torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device), - ) - assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}" - - def test__transform_empty_params(self, mocker): - transform = transforms.RandomIoUCrop(sampler_options=[2.0]) - image = features.Image(torch.rand(1, 3, 4, 4)) - bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4)) - label = features.Label(torch.tensor([1])) - sample = [image, bboxes, label] - # Let's mock transform._get_params to control the output: - transform._get_params = mocker.MagicMock(return_value={}) - output = transform(sample) - torch.testing.assert_close(output, sample) - - def test_forward_assertion(self): - transform = transforms.RandomIoUCrop() - with pytest.raises( - TypeError, - match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels", - ): - transform(torch.tensor(0)) - - def test__transform(self, mocker): - transform = transforms.RandomIoUCrop() - - image = features.Image(torch.rand(3, 32, 24)) - bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,)) - label = features.Label(torch.randint(0, 10, size=(6,))) - ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) - masks = make_detection_mask((32, 24), num_objects=6) - - sample = [image, bboxes, label, ohe_label, masks] - - fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x) - is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) - - params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) - transform._get_params = mocker.MagicMock(return_value=params) - output = transform(sample) - - assert fn.call_count == 3 - - expected_calls = [ - mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), - mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), - mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), - ] - - fn.assert_has_calls(expected_calls) - - expected_within_targets = sum(is_within_crop_area) - - # check number of bboxes vs number of labels: - output_bboxes = output[1] - assert isinstance(output_bboxes, features.BoundingBox) - assert len(output_bboxes) == expected_within_targets - - # check labels - output_label = output[2] - assert isinstance(output_label, features.Label) - assert len(output_label) == expected_within_targets - torch.testing.assert_close(output_label, label[is_within_crop_area]) - - output_ohe_label = output[3] - assert isinstance(output_ohe_label, features.OneHotLabel) - torch.testing.assert_close(output_ohe_label, 
ohe_label[is_within_crop_area]) - - output_masks = output[4] - assert isinstance(output_masks, features.Mask) - assert len(output_masks) == expected_within_targets - - -class TestScaleJitter: - def test__get_params(self, mocker): - image_size = (24, 32) - target_size = (16, 12) - scale_range = (0.5, 1.5) - - transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) - sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size) - - n_samples = 5 - for _ in range(n_samples): - - params = transform._get_params(sample) - - assert "size" in params - size = params["size"] - - assert isinstance(size, tuple) and len(size) == 2 - height, width = size - - r_min = min(target_size[1] / image_size[0], target_size[0] / image_size[1]) * scale_range[0] - r_max = min(target_size[1] / image_size[0], target_size[0] / image_size[1]) * scale_range[1] - - assert int(image_size[0] * r_min) <= height <= int(image_size[0] * r_max) - assert int(image_size[1] * r_min) <= width <= int(image_size[1] * r_max) - - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock() - antialias_sentinel = mocker.MagicMock() - - transform = transforms.ScaleJitter( - target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel) - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") - transform(inpt_sentinel) - - mock.assert_called_once_with( - inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - - -class TestRandomShortestSize: - def test__get_params(self, mocker): - image_size = (3, 10) - min_size = [5, 9] - max_size = 20 - - transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) - - sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size) - params = transform._get_params(sample) - - assert "size" in params - size = params["size"] - - assert isinstance(size, tuple) and len(size) == 2 - - longer = max(size) - assert longer <= max_size - - shorter = min(size) - if longer == max_size: - assert shorter <= max_size - else: - assert shorter in min_size - - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock() - antialias_sentinel = mocker.MagicMock() - - transform = transforms.RandomShortestSize( - min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._geometry.RandomShortestSize._get_params", - return_value=dict(size=size_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") - transform(inpt_sentinel) - - mock.assert_called_once_with( - inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - - -class TestSimpleCopyPaste: - def create_fake_image(self, mocker, image_type): - if image_type == PIL.Image.Image: - return PIL.Image.new("RGB", (32, 32), 123) - return mocker.MagicMock(spec=image_type) - - def test__extract_image_targets_assertion(self, mocker): - transform = transforms.SimpleCopyPaste() - - 
flat_sample = [ - # images, batch size = 2 - self.create_fake_image(mocker, features.Image), - # labels, bboxes, masks - mocker.MagicMock(spec=features.Label), - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), - # labels, bboxes, masks - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), - ] - - with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"): - transform._extract_image_targets(flat_sample) - - @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor]) - @pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel]) - def test__extract_image_targets(self, image_type, label_type, mocker): - transform = transforms.SimpleCopyPaste() - - flat_sample = [ - # images, batch size = 2 - self.create_fake_image(mocker, image_type), - self.create_fake_image(mocker, image_type), - # labels, bboxes, masks - mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), - # labels, bboxes, masks - mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), - ] - - images, targets = transform._extract_image_targets(flat_sample) - - assert len(images) == len(targets) == 2 - if image_type == PIL.Image.Image: - torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0])) - torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1])) - else: - assert images[0] == flat_sample[0] - assert images[1] == flat_sample[1] - - for target in targets: - for key, type_ in [ - ("boxes", features.BoundingBox), - ("masks", features.Mask), - ("labels", label_type), - ]: - assert key in target - assert isinstance(target[key], type_) - assert target[key] in flat_sample - - @pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel]) - def test__copy_paste(self, label_type): - image = 2 * torch.ones(3, 32, 32) - masks = torch.zeros(2, 32, 32) - masks[0, 3:9, 2:8] = 1 - masks[1, 20:30, 20:30] = 1 - labels = torch.tensor([1, 2]) - blending = True - resize_interpolation = InterpolationMode.BILINEAR - antialias = None - if label_type == features.OneHotLabel: - labels = torch.nn.functional.one_hot(labels, num_classes=5) - target = { - "boxes": features.BoundingBox( - torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32) - ), - "masks": features.Mask(masks), - "labels": label_type(labels), - } - - paste_image = 10 * torch.ones(3, 32, 32) - paste_masks = torch.zeros(2, 32, 32) - paste_masks[0, 13:19, 12:18] = 1 - paste_masks[1, 15:19, 1:8] = 1 - paste_labels = torch.tensor([3, 4]) - if label_type == features.OneHotLabel: - paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) - paste_target = { - "boxes": features.BoundingBox( - torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32) - ), - "masks": features.Mask(paste_masks), - "labels": label_type(paste_labels), - } - - transform = transforms.SimpleCopyPaste() - random_selection = torch.tensor([0, 1]) - output_image, output_target = transform._copy_paste( - image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias - ) - - assert output_image.unique().tolist() == [2, 10] - assert output_target["boxes"].shape == (4, 4) - torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"]) - 
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) - - expected_labels = torch.tensor([1, 2, 3, 4]) - if label_type == features.OneHotLabel: - expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5) - torch.testing.assert_close(output_target["labels"], label_type(expected_labels)) - - assert output_target["masks"].shape == (4, 32, 32) - torch.testing.assert_close(output_target["masks"][:2, :], target["masks"]) - torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"]) - - -class TestFixedSizeCrop: - def test__get_params(self, mocker): - crop_size = (7, 7) - batch_shape = (10,) - image_size = (11, 5) - - transform = transforms.FixedSizeCrop(size=crop_size) - - sample = dict( - image=make_image(size=image_size, color_space=features.ColorSpace.RGB), - bounding_boxes=make_bounding_box( - format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape - ), - ) - params = transform._get_params(sample) - - assert params["needs_crop"] - assert params["height"] <= crop_size[0] - assert params["width"] <= crop_size[1] - - assert ( - isinstance(params["is_valid"], torch.Tensor) - and params["is_valid"].dtype is torch.bool - and params["is_valid"].shape == batch_shape - ) - - assert params["needs_pad"] - assert any(pad > 0 for pad in params["padding"]) - - @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2))) - def test__transform(self, mocker, needs): - fill_sentinel = 12 - padding_mode_sentinel = mocker.MagicMock() - - transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) - transform._transformed_types = (mocker.MagicMock,) - mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - needs_crop, needs_pad = needs - top_sentinel = mocker.MagicMock() - left_sentinel = mocker.MagicMock() - height_sentinel = mocker.MagicMock() - width_sentinel = mocker.MagicMock() - is_valid = mocker.MagicMock() if needs_crop else None - padding_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=needs_crop, - top=top_sentinel, - left=left_sentinel, - height=height_sentinel, - width=width_sentinel, - is_valid=is_valid, - padding=padding_sentinel, - needs_pad=needs_pad, - ), - ) - - inpt_sentinel = mocker.MagicMock() - - mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop") - mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad") - transform(inpt_sentinel) - - if needs_crop: - mock_crop.assert_called_once_with( - inpt_sentinel, - top=top_sentinel, - left=left_sentinel, - height=height_sentinel, - width=width_sentinel, - ) - else: - mock_crop.assert_not_called() - - if needs_pad: - # If we cropped before, the input to F.pad is no longer inpt_sentinel. 
Thus, we can't use - # `MagicMock.assert_called_once_with` and have to perform the checks manually - mock_pad.assert_called_once() - args, kwargs = mock_pad.call_args - if not needs_crop: - assert args[0] is inpt_sentinel - assert args[1] is padding_sentinel - fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel) - assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) - else: - mock_pad.assert_not_called() - - def test__transform_culling(self, mocker): - batch_size = 10 - image_size = (10, 10) - - is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=True, - top=0, - left=0, - height=image_size[0], - width=image_size[1], - is_valid=is_valid, - needs_pad=False, - ), - ) - - bounding_boxes = make_bounding_box( - format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) - ) - masks = make_detection_mask(size=image_size, extra_dims=(batch_size,)) - labels = make_label(extra_dims=(batch_size,)) - - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - output = transform( - dict( - bounding_boxes=bounding_boxes, - masks=masks, - labels=labels, - ) - ) - - assert_equal(output["bounding_boxes"], bounding_boxes[is_valid]) - assert_equal(output["masks"], masks[is_valid]) - assert_equal(output["labels"], labels[is_valid]) - - def test__transform_bounding_box_clamping(self, mocker): - batch_size = 3 - image_size = (10, 10) - - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=True, - top=0, - left=0, - height=image_size[0], - width=image_size[1], - is_valid=torch.full((batch_size,), fill_value=True), - needs_pad=False, - ), - ) - - bounding_box = make_bounding_box( - format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) - ) - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") - - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - transform(bounding_box) - - mock.assert_called_once() - - -class TestLinearTransformation: - def test_assertions(self): - with pytest.raises(ValueError, match="transformation_matrix should be square"): - transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5)) - - with pytest.raises(ValueError, match="mean_vector should have the same length"): - transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5)) - - @pytest.mark.parametrize( - "inpt", - [ - 122 * torch.ones(1, 3, 8, 8), - 122.0 * torch.ones(1, 3, 8, 8), - features.Image(122 * torch.ones(1, 3, 8, 8)), - PIL.Image.new("RGB", (8, 8), (122, 122, 122)), - ], - ) - def test__transform(self, inpt): - - v = 121 * torch.ones(3 * 8 * 8) - m = torch.ones(3 * 8 * 8, 3 * 8 * 8) - transform = transforms.LinearTransformation(m, v) - - if isinstance(inpt, PIL.Image.Image): - with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"): - transform(inpt) - else: - output = transform(inpt) - assert isinstance(output, torch.Tensor) - assert output.unique() == 3 * 8 * 8 - assert output.dtype == inpt.dtype 
- - -class TestLabelToOneHot: - def test__transform(self): - categories = ["apple", "pear", "pineapple"] - labels = features.Label(torch.tensor([0, 1, 2, 1]), categories=categories) - transform = transforms.LabelToOneHot() - ohe_labels = transform(labels) - assert isinstance(ohe_labels, features.OneHotLabel) - assert ohe_labels.shape == (4, 3) - assert ohe_labels.categories == labels.categories == categories - - -class TestRandomResize: - def test__get_params(self): - min_size = 3 - max_size = 6 - - transform = transforms.RandomResize(min_size=min_size, max_size=max_size) - - for _ in range(10): - params = transform._get_params(None) - - assert isinstance(params["size"], list) and len(params["size"]) == 1 - size = params["size"][0] - - assert min_size <= size < max_size - - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock() - antialias_sentinel = mocker.MagicMock() - - transform = transforms.RandomResize( - min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._geometry.RandomResize._get_params", - return_value=dict(size=size_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock_resize = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") - transform(inpt_sentinel) - - mock_resize.assert_called_with( - inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py deleted file mode 100644 index c8debe1e293..00000000000 --- a/test/test_prototype_transforms_consistency.py +++ /dev/null @@ -1,1097 +0,0 @@ -import enum -import inspect -import random -from collections import defaultdict -from importlib.machinery import SourceFileLoader -from pathlib import Path - -import numpy as np -import PIL.Image -import pytest - -import torch -from prototype_common_utils import ( - ArgsKwargs, - assert_equal, - make_bounding_box, - make_detection_mask, - make_image, - make_images, - make_label, - make_segmentation_mask, -) -from torchvision import transforms as legacy_transforms -from torchvision._utils import sequence_to_str -from torchvision.prototype import features, transforms as prototype_transforms -from torchvision.prototype.transforms import functional as F -from torchvision.prototype.transforms._utils import query_chw -from torchvision.prototype.transforms.functional import to_image_pil - -DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) - - -class ConsistencyConfig: - def __init__( - self, - prototype_cls, - legacy_cls, - # If no args_kwargs is passed, only the signature will be checked - args_kwargs=(), - make_images_kwargs=None, - supports_pil=True, - removed_params=(), - ): - self.prototype_cls = prototype_cls - self.legacy_cls = legacy_cls - self.args_kwargs = args_kwargs - self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS - self.supports_pil = supports_pil - self.removed_params = removed_params - - -# These are here since both the prototype and legacy transform need to be constructed with the same random parameters -LINEAR_TRANSFORMATION_MEAN = torch.rand(36) -LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2) - -CONSISTENCY_CONFIGS = [ - ConsistencyConfig( - prototype_transforms.Normalize, - 
legacy_transforms.Normalize, - [ - ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ], - supports_pil=False, - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]), - ), - ConsistencyConfig( - prototype_transforms.Resize, - legacy_transforms.Resize, - [ - ArgsKwargs(32), - ArgsKwargs((32, 29)), - ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), - # FIXME: these are currently failing, since the new transform only supports the enum. The int input is - # already deprecated and scheduled to be removed in 0.15. Should we support ints on the prototype - # transform? I guess it depends if we roll out before 0.15 or not. - # ArgsKwargs((30, 27), interpolation=0), - # ArgsKwargs((35, 29), interpolation=2), - # ArgsKwargs((34, 25), interpolation=3), - ArgsKwargs(31, max_size=32), - ArgsKwargs(30, max_size=100), - ArgsKwargs((29, 32), antialias=False), - ArgsKwargs((28, 31), antialias=True), - ], - ), - ConsistencyConfig( - prototype_transforms.CenterCrop, - legacy_transforms.CenterCrop, - [ - ArgsKwargs(18), - ArgsKwargs((18, 13)), - ], - ), - ConsistencyConfig( - prototype_transforms.FiveCrop, - legacy_transforms.FiveCrop, - [ - ArgsKwargs(18), - ArgsKwargs((18, 13)), - ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]), - ), - ConsistencyConfig( - prototype_transforms.TenCrop, - legacy_transforms.TenCrop, - [ - ArgsKwargs(18), - ArgsKwargs((18, 13)), - ArgsKwargs(18, vertical_flip=True), - ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]), - ), - ConsistencyConfig( - prototype_transforms.Pad, - legacy_transforms.Pad, - [ - ArgsKwargs(3), - ArgsKwargs([3]), - ArgsKwargs([2, 3]), - ArgsKwargs([3, 2, 1, 4]), - ArgsKwargs(5, fill=1, padding_mode="constant"), - ArgsKwargs(5, padding_mode="edge"), - ArgsKwargs(5, padding_mode="reflect"), - ArgsKwargs(5, padding_mode="symmetric"), - ], - ), - ConsistencyConfig( - prototype_transforms.LinearTransformation, - legacy_transforms.LinearTransformation, - [ - ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN), - ], - # Make sure that the product of the height, width and number of channels matches the number of elements in - # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. 
- make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[features.ColorSpace.RGB] - ), - supports_pil=False, - ), - ConsistencyConfig( - prototype_transforms.Grayscale, - legacy_transforms.Grayscale, - [ - ArgsKwargs(num_output_channels=1), - ArgsKwargs(num_output_channels=3), - ], - make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY] - ), - ), - ConsistencyConfig( - prototype_transforms.ConvertImageDtype, - legacy_transforms.ConvertImageDtype, - [ - ArgsKwargs(torch.float16), - ArgsKwargs(torch.bfloat16), - ArgsKwargs(torch.float32), - ArgsKwargs(torch.float64), - ArgsKwargs(torch.uint8), - ], - supports_pil=False, - ), - ConsistencyConfig( - prototype_transforms.ToPILImage, - legacy_transforms.ToPILImage, - [ArgsKwargs()], - make_images_kwargs=dict( - color_spaces=[ - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, - ], - extra_dims=[()], - ), - supports_pil=False, - ), - ConsistencyConfig( - prototype_transforms.Lambda, - legacy_transforms.Lambda, - [ - ArgsKwargs(lambda image: image / 2), - ], - # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL - # images given that the transform does nothing but call it anyway. - supports_pil=False, - ), - ConsistencyConfig( - prototype_transforms.RandomHorizontalFlip, - legacy_transforms.RandomHorizontalFlip, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomVerticalFlip, - legacy_transforms.RandomVerticalFlip, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomEqualize, - legacy_transforms.RandomEqualize, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]), - ), - ConsistencyConfig( - prototype_transforms.RandomInvert, - legacy_transforms.RandomInvert, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomPosterize, - legacy_transforms.RandomPosterize, - [ - ArgsKwargs(p=0, bits=5), - ArgsKwargs(p=1, bits=1), - ArgsKwargs(p=1, bits=3), - ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]), - ), - ConsistencyConfig( - prototype_transforms.RandomSolarize, - legacy_transforms.RandomSolarize, - [ - ArgsKwargs(p=0, threshold=0.5), - ArgsKwargs(p=1, threshold=0.3), - ArgsKwargs(p=1, threshold=0.99), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomAutocontrast, - legacy_transforms.RandomAutocontrast, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomAdjustSharpness, - legacy_transforms.RandomAdjustSharpness, - [ - ArgsKwargs(p=0, sharpness_factor=0.5), - ArgsKwargs(p=1, sharpness_factor=0.3), - ArgsKwargs(p=1, sharpness_factor=0.99), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomGrayscale, - legacy_transforms.RandomGrayscale, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomResizedCrop, - legacy_transforms.RandomResizedCrop, - [ - ArgsKwargs(16), - ArgsKwargs(17, scale=(0.3, 0.7)), - ArgsKwargs(25, ratio=(0.5, 1.5)), - ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), - ArgsKwargs((29, 32), antialias=False), - ArgsKwargs((28, 31), 
antialias=True), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomErasing, - legacy_transforms.RandomErasing, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ArgsKwargs(p=1, scale=(0.3, 0.7)), - ArgsKwargs(p=1, ratio=(0.5, 1.5)), - ArgsKwargs(p=1, value=1), - ArgsKwargs(p=1, value=(1, 2, 3)), - ArgsKwargs(p=1, value="random"), - ], - supports_pil=False, - ), - ConsistencyConfig( - prototype_transforms.ColorJitter, - legacy_transforms.ColorJitter, - [ - ArgsKwargs(), - ArgsKwargs(brightness=0.1), - ArgsKwargs(brightness=(0.2, 0.3)), - ArgsKwargs(contrast=0.4), - ArgsKwargs(contrast=(0.5, 0.6)), - ArgsKwargs(saturation=0.7), - ArgsKwargs(saturation=(0.8, 0.9)), - ArgsKwargs(hue=0.3), - ArgsKwargs(hue=(-0.1, 0.2)), - ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3), - ], - ), - ConsistencyConfig( - prototype_transforms.ElasticTransform, - legacy_transforms.ElasticTransform, - [ - ArgsKwargs(), - ArgsKwargs(alpha=20.0), - ArgsKwargs(alpha=(15.3, 27.2)), - ArgsKwargs(sigma=3.0), - ArgsKwargs(sigma=(2.5, 3.9)), - ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC), - ArgsKwargs(fill=1), - ], - # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]), - ), - ConsistencyConfig( - prototype_transforms.GaussianBlur, - legacy_transforms.GaussianBlur, - [ - ArgsKwargs(kernel_size=3), - ArgsKwargs(kernel_size=(1, 5)), - ArgsKwargs(kernel_size=3, sigma=0.7), - ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomAffine, - legacy_transforms.RandomAffine, - [ - ArgsKwargs(degrees=30.0), - ArgsKwargs(degrees=(-20.0, 10.0)), - ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)), - ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)), - ArgsKwargs(degrees=0.0, shear=13), - ArgsKwargs(degrees=0.0, shear=(8, 17)), - ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)), - ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)), - ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs(degrees=30.0, fill=1), - ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), - ArgsKwargs(degrees=30.0, center=(0, 0)), - ], - removed_params=["fillcolor", "resample"], - ), - ConsistencyConfig( - prototype_transforms.RandomCrop, - legacy_transforms.RandomCrop, - [ - ArgsKwargs(12), - ArgsKwargs((15, 17)), - ArgsKwargs(11, padding=1), - ArgsKwargs((8, 13), padding=(2, 3)), - ArgsKwargs((14, 9), padding=(0, 2, 1, 0)), - ArgsKwargs(36, pad_if_needed=True), - ArgsKwargs((7, 8), fill=1), - ArgsKwargs(5, fill=(1, 2, 3)), - ArgsKwargs(12), - ArgsKwargs(15, padding=2, padding_mode="edge"), - ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"), - ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"), - ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]), - ), - ConsistencyConfig( - prototype_transforms.RandomPerspective, - legacy_transforms.RandomPerspective, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ArgsKwargs(p=1, distortion_scale=0.3), - ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST), - ArgsKwargs(p=1, distortion_scale=0.1, fill=1), - ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)), - ], - ), - ConsistencyConfig( - prototype_transforms.RandomRotation, - 
legacy_transforms.RandomRotation, - [ - ArgsKwargs(degrees=30.0), - ArgsKwargs(degrees=(-20.0, 10.0)), - ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR), - ArgsKwargs(degrees=30.0, expand=True), - ArgsKwargs(degrees=30.0, center=(0, 0)), - ArgsKwargs(degrees=30.0, fill=1), - ArgsKwargs(degrees=30.0, fill=(1, 2, 3)), - ], - removed_params=["resample"], - ), - ConsistencyConfig( - prototype_transforms.PILToTensor, - legacy_transforms.PILToTensor, - ), - ConsistencyConfig( - prototype_transforms.ToTensor, - legacy_transforms.ToTensor, - ), - ConsistencyConfig( - prototype_transforms.Compose, - legacy_transforms.Compose, - ), - ConsistencyConfig( - prototype_transforms.RandomApply, - legacy_transforms.RandomApply, - ), - ConsistencyConfig( - prototype_transforms.RandomChoice, - legacy_transforms.RandomChoice, - ), - ConsistencyConfig( - prototype_transforms.RandomOrder, - legacy_transforms.RandomOrder, - ), - ConsistencyConfig( - prototype_transforms.AugMix, - legacy_transforms.AugMix, - ), - ConsistencyConfig( - prototype_transforms.AutoAugment, - legacy_transforms.AutoAugment, - ), - ConsistencyConfig( - prototype_transforms.RandAugment, - legacy_transforms.RandAugment, - ), - ConsistencyConfig( - prototype_transforms.TrivialAugmentWide, - legacy_transforms.TrivialAugmentWide, - ), -] - - -def test_automatic_coverage(): - available = { - name - for name, obj in legacy_transforms.__dict__.items() - if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum) - } - - checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS} - - missing = available - checked - if missing: - raise AssertionError( - f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} " - f"are not checked for consistency although a legacy counterpart exists." - ) - - -@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__) -def test_signature_consistency(config): - legacy_params = dict(inspect.signature(config.legacy_cls).parameters) - prototype_params = dict(inspect.signature(config.prototype_cls).parameters) - - for param in config.removed_params: - legacy_params.pop(param, None) - - missing = legacy_params.keys() - prototype_params.keys() - if missing: - raise AssertionError( - f"The prototype transform does not support the parameters " - f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. " - f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on " - f"the `ConsistencyConfig`." - ) - - extra = prototype_params.keys() - legacy_params.keys() - extra_without_default = { - param - for param in extra - if prototype_params[param].default is inspect.Parameter.empty - and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} - } - if extra_without_default: - raise AssertionError( - f"The prototype transform requires the parameters " - f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does " - f"not. Please add a default value." 
- ) - - legacy_kinds = {name: param.kind for name, param in legacy_params.items()} - prototype_kinds = {name: prototype_params[name].kind for name in legacy_kinds.keys()} - assert prototype_kinds == legacy_kinds - - -def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True): - if images is None: - images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS) - - for image in images: - image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]" - - image_tensor = torch.Tensor(image) - - try: - torch.manual_seed(0) - output_legacy_tensor = legacy_transform(image_tensor) - except Exception as exc: - raise pytest.UsageError( - f"Transforming a tensor image {image_repr} failed in the legacy transform with the " - f"error above. This means that you need to specify the parameters passed to `make_images` through the " - "`make_images_kwargs` of the `ConsistencyConfig`." - ) from exc - - try: - torch.manual_seed(0) - output_prototype_tensor = prototype_transform(image_tensor) - except Exception as exc: - raise AssertionError( - f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with " - f"the error above. This means there is a consistency bug either in `_get_params` or in the " - f"`is_simple_tensor` path in `_transform`." - ) from exc - - assert_equal( - output_prototype_tensor, - output_legacy_tensor, - msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}", - ) - - try: - torch.manual_seed(0) - output_prototype_image = prototype_transform(image) - except Exception as exc: - raise AssertionError( - f"Transforming a feature image with shape {image_repr} failed in the prototype transform with " - f"the error above. This means there is a consistency bug either in `_get_params` or in the " - f"`features.Image` path in `_transform`." - ) from exc - - assert_equal( - output_prototype_image, - output_prototype_tensor, - msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}", - ) - - if image.ndim == 3 and supports_pil: - image_pil = to_image_pil(image) - - try: - torch.manual_seed(0) - output_legacy_pil = legacy_transform(image_pil) - except Exception as exc: - raise pytest.UsageError( - f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the " - f"error above. If this transform does not support PIL images, set `supports_pil=False` on the " - "`ConsistencyConfig`. " - ) from exc - - try: - torch.manual_seed(0) - output_prototype_pil = prototype_transform(image_pil) - except Exception as exc: - raise AssertionError( - f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with " - f"the error above. This means there is a consistency bug either in `_get_params` or in the " - f"`PIL.Image.Image` path in `_transform`." - ) from exc - - assert_equal( - output_prototype_pil, - output_legacy_pil, - msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}", - ) - - -@pytest.mark.parametrize( - ("config", "args_kwargs"), - [ - pytest.param(config, args_kwargs, id=f"{config.legacy_cls.__name__}({args_kwargs})") - for config in CONSISTENCY_CONFIGS - for args_kwargs in config.args_kwargs - ], -) -def test_call_consistency(config, args_kwargs): - args, kwargs = args_kwargs - - try: - legacy_transform = config.legacy_cls(*args, **kwargs) - except Exception as exc: - raise pytest.UsageError( - f"Initializing the legacy transform failed with the error above. 
" - f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`." - ) from exc - - try: - prototype_transform = config.prototype_cls(*args, **kwargs) - except Exception as exc: - raise AssertionError( - "Initializing the prototype transform failed with the error above. " - "This means there is a consistency bug in the constructor." - ) from exc - - check_call_consistency( - prototype_transform, - legacy_transform, - images=make_images(**config.make_images_kwargs), - supports_pil=config.supports_pil, - ) - - -class TestContainerTransforms: - """ - Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for - consistency automatically tests the wrapped transforms consistency. - - Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones - that were already tested for consistency above. - """ - - def test_compose(self): - prototype_transform = prototype_transforms.Compose( - [ - prototype_transforms.Resize(256), - prototype_transforms.CenterCrop(224), - ] - ) - legacy_transform = legacy_transforms.Compose( - [ - legacy_transforms.Resize(256), - legacy_transforms.CenterCrop(224), - ] - ) - - check_call_consistency(prototype_transform, legacy_transform) - - @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1]) - def test_random_apply(self, p): - prototype_transform = prototype_transforms.RandomApply( - [ - prototype_transforms.Resize(256), - legacy_transforms.CenterCrop(224), - ], - p=p, - ) - legacy_transform = legacy_transforms.RandomApply( - [ - legacy_transforms.Resize(256), - legacy_transforms.CenterCrop(224), - ], - p=p, - ) - - check_call_consistency(prototype_transform, legacy_transform) - - # We can't test other values for `p` since the random parameter generation is different - @pytest.mark.parametrize("p", [(0, 1), (1, 0)]) - def test_random_choice(self, p): - prototype_transform = prototype_transforms.RandomChoice( - [ - prototype_transforms.Resize(256), - legacy_transforms.CenterCrop(224), - ], - p=p, - ) - legacy_transform = legacy_transforms.RandomChoice( - [ - legacy_transforms.Resize(256), - legacy_transforms.CenterCrop(224), - ], - p=p, - ) - - check_call_consistency(prototype_transform, legacy_transform) - - -class TestToTensorTransforms: - def test_pil_to_tensor(self): - prototype_transform = prototype_transforms.PILToTensor() - legacy_transform = legacy_transforms.PILToTensor() - - for image in make_images(extra_dims=[()]): - image_pil = to_image_pil(image) - - assert_equal(prototype_transform(image_pil), legacy_transform(image_pil)) - - def test_to_tensor(self): - prototype_transform = prototype_transforms.ToTensor() - legacy_transform = legacy_transforms.ToTensor() - - for image in make_images(extra_dims=[()]): - image_pil = to_image_pil(image) - image_numpy = np.array(image_pil) - - assert_equal(prototype_transform(image_pil), legacy_transform(image_pil)) - assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy)) - - -class TestAATransforms: - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], - ) - def test_randaug(self, inpt, interpolation, mocker): - t_ref = 
legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) - t = prototype_transforms.RandAugment(interpolation=interpolation, num_ops=1) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - # Stable API, if signed there is another random call - if t._AUGMENTATION_SPACE[keys[i]][1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - for i in range(le): - expected_output = t_ref(inpt) - output = t(inpt) - - assert_equal(expected_output, output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], - ) - def test_trivial_aug(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) - t = prototype_transforms.TrivialAugmentWide(interpolation=interpolation) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - key = keys[i] - # Stable API, random magnitude - aug_op = t._AUGMENTATION_SPACE[key] - magnitudes = aug_op[0](2, 0, 0) - if magnitudes is not None: - randint_values.append(5) - # Stable API, if signed there is another random call - if aug_op[1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - # New API, random magnitude - if magnitudes is not None: - randint_values.append(5) - - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - for _ in range(le): - expected_output = t_ref(inpt) - output = t(inpt) - - assert_equal(expected_output, output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], - ) - def test_augmix(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) - t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1) - t = prototype_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) - t._sample_dirichlet = lambda t: t.softmax(dim=-1) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - key = keys[i] - # Stable API, random magnitude - aug_op = t._AUGMENTATION_SPACE[key] - magnitudes = aug_op[0](2, 0, 0) - if magnitudes is not None: - randint_values.append(5) - # Stable API, if signed there is another random call - if aug_op[1]: - 
randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - # New API, random magnitude - if magnitudes is not None: - randint_values.append(5) - - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - expected_output = t_ref(inpt) - output = t(inpt) - - assert_equal(expected_output, output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], - ) - def test_aa(self, inpt, interpolation): - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") - t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) - t = prototype_transforms.AutoAugment(aa_policy, interpolation=interpolation) - - torch.manual_seed(12) - expected_output = t_ref(inpt) - - torch.manual_seed(12) - output = t(inpt) - - assert_equal(expected_output, output) - - -def import_transforms_from_references(reference): - ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py" - return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module() - - -det_transforms = import_transforms_from_references("detection") - - -class TestRefDetTransforms: - def make_datapoints(self, with_mask=True): - size = (600, 800) - num_objects = 22 - - pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB)) - target = { - "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), - "labels": make_label(extra_dims=(num_objects,), categories=80), - } - if with_mask: - target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long) - - yield (pil_image, target) - - tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB)) - target = { - "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), - "labels": make_label(extra_dims=(num_objects,), categories=80), - } - if with_mask: - target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long) - - yield (tensor_image, target) - - feature_image = make_image(size=size, color_space=features.ColorSpace.RGB) - target = { - "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), - "labels": make_label(extra_dims=(num_objects,), categories=80), - } - if with_mask: - target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long) - - yield (feature_image, target) - - @pytest.mark.parametrize( - "t_ref, t, data_kwargs", - [ - (det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}), - (det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}), - (det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}), - (det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}), - ( - det_transforms.FixedSizeCrop((1024, 1024), fill=0), - prototype_transforms.FixedSizeCrop((1024, 1024), fill=0), - {}, - ), - ( - 
det_transforms.RandomShortestSize( - min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 - ), - prototype_transforms.RandomShortestSize( - min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 - ), - {}, - ), - ], - ) - def test_transform(self, t_ref, t, data_kwargs): - for dp in self.make_datapoints(**data_kwargs): - - # We should use prototype transform first as reference transform performs inplace target update - torch.manual_seed(12) - output = t(dp) - - torch.manual_seed(12) - expected_output = t_ref(*dp) - - assert_equal(expected_output, output) - - -seg_transforms = import_transforms_from_references("segmentation") - - -# We need this transform for two reasons: -# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name -# counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True` -# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size. -class PadIfSmaller(prototype_transforms.Transform): - def __init__(self, size, fill=0): - super().__init__() - self.size = size - self.fill = prototype_transforms._geometry._setup_fill_arg(fill) - - def _get_params(self, sample): - _, height, width = query_chw(sample) - padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] - needs_padding = any(padding) - return dict(padding=padding, needs_padding=needs_padding) - - def _transform(self, inpt, params): - if not params["needs_padding"]: - return inpt - - fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) - - return F.pad(inpt, padding=params["padding"], fill=fill) - - -class TestRefSegTransforms: - def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): - size = (256, 460) - num_categories = 21 - - conv_fns = [] - if supports_pil: - conv_fns.append(to_image_pil) - conv_fns.extend([torch.Tensor, lambda x: x]) - - for conv_fn in conv_fns: - feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype) - feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) - - dp = (conv_fn(feature_image), feature_mask) - dp_ref = ( - to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image), - to_image_pil(feature_mask), - ) - - yield dp, dp_ref - - def set_seed(self, seed=12): - torch.manual_seed(seed) - random.seed(seed) - - def check(self, t, t_ref, data_kwargs=None): - for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()): - - self.set_seed() - output = t(dp) - - self.set_seed() - expected_output = t_ref(*dp_ref) - - assert_equal(output, expected_output) - - @pytest.mark.parametrize( - ("t_ref", "t", "data_kwargs"), - [ - ( - seg_transforms.RandomHorizontalFlip(flip_prob=1.0), - prototype_transforms.RandomHorizontalFlip(p=1.0), - dict(), - ), - ( - seg_transforms.RandomHorizontalFlip(flip_prob=0.0), - prototype_transforms.RandomHorizontalFlip(p=0.0), - dict(), - ), - ( - seg_transforms.RandomCrop(size=480), - prototype_transforms.Compose( - [ - PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})), - prototype_transforms.RandomCrop(size=480), - ] - ), - dict(), - ), - ( - seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - dict(supports_pil=False, image_dtype=torch.float), - ), - ], - ) - def test_common(self, t_ref, t, 
data_kwargs): - self.check(t, t_ref, data_kwargs) - - def check_resize(self, mocker, t_ref, t): - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") - mock_ref = mocker.patch("torchvision.transforms.functional.resize") - - for dp, dp_ref in self.make_datapoints(): - mock.reset_mock() - mock_ref.reset_mock() - - self.set_seed() - t(dp) - assert mock.call_count == 2 - assert all( - actual is expected - for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp) - ) - - self.set_seed() - t_ref(*dp_ref) - assert mock_ref.call_count == 2 - assert all( - actual is expected - for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref) - ) - - for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list): - assert args_kwargs[0][1] == [args_kwargs_ref[0][1]] - - def test_random_resize_train(self, mocker): - base_size = 520 - min_size = base_size // 2 - max_size = base_size * 2 - - randint = torch.randint - - def patched_randint(a, b, *other_args, **kwargs): - if kwargs or len(other_args) > 1 or other_args[0] != (): - return randint(a, b, *other_args, **kwargs) - - return random.randint(a, b) - - # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported - # normally - t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) - mocker.patch( - "torchvision.prototype.transforms._geometry.torch.randint", - new=patched_randint, - ) - - t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size) - - self.check_resize(mocker, t_ref, t) - - def test_random_resize_eval(self, mocker): - torch.manual_seed(0) - base_size = 520 - - t = prototype_transforms.Resize(size=base_size, antialias=True) - - t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) - - self.check_resize(mocker, t_ref, t) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py deleted file mode 100644 index b2c830d5d5f..00000000000 --- a/test/test_prototype_transforms_functional.py +++ /dev/null @@ -1,956 +0,0 @@ -import math -import os - -import numpy as np -import PIL.Image -import pytest - -import torch -from common_utils import cache, cpu_and_gpu, needs_cuda -from prototype_common_utils import assert_close, make_bounding_boxes, make_image -from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS -from prototype_transforms_kernel_infos import KERNEL_INFOS -from torch.utils._pytree import tree_map -from torchvision.prototype import features -from torchvision.prototype.transforms import functional as F -from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding -from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box -from torchvision.transforms.functional import _get_perspective_coeffs - - -@cache -def script(fn): - try: - return torch.jit.script(fn) - except Exception as error: - raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error - - -@pytest.fixture(autouse=True) -def maybe_skip(request): - # In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist - try: - callspec = request.node.callspec - except AttributeError: - return - - try: - info = callspec.params["info"] - args_kwargs = callspec.params["args_kwargs"] - except KeyError: - return - - info.maybe_skip( - 
test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu") - ) - - -class TestKernels: - sample_inputs = pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}") - for info in KERNEL_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs_fn()) - ], - ) - - @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_scripted_vs_eager(self, info, args_kwargs, device): - kernel_eager = info.kernel - kernel_scripted = script(kernel_eager) - - args, kwargs = args_kwargs.load(device) - - actual = kernel_scripted(*args, **kwargs) - expected = kernel_eager(*args, **kwargs) - - assert_close(actual, expected, **info.closeness_kwargs) - - def _unbatch(self, batch, *, data_dims): - if isinstance(batch, torch.Tensor): - batched_tensor = batch - metadata = () - else: - batched_tensor, *metadata = batch - - if batched_tensor.ndim == data_dims: - return batch - - return [ - self._unbatch(unbatched, data_dims=data_dims) - for unbatched in ( - batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)] - ) - ] - - @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_batched_vs_single(self, info, args_kwargs, device): - (batched_input, *other_args), kwargs = args_kwargs.load(device) - - feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input) - # This dictionary contains the number of rightmost dimensions that contain the actual data. - # Everything to the left is considered a batch dimension. - data_dims = { - features.Image: 3, - features.BoundingBox: 1, - # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks - # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one - # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as - # common ground. - features.Mask: 2, - }.get(feature_type) - if data_dims is None: - raise pytest.UsageError( - f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}." 
- ) from None - elif batched_input.ndim <= data_dims: - pytest.skip("Input is not batched.") - elif not all(batched_input.shape[:-data_dims]): - pytest.skip("Input has a degenerate batch shape.") - - batched_output = info.kernel(batched_input, *other_args, **kwargs) - actual = self._unbatch(batched_output, data_dims=data_dims) - - single_inputs = self._unbatch(batched_input, data_dims=data_dims) - expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs) - - assert_close(actual, expected, **info.closeness_kwargs) - - @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_no_inplace(self, info, args_kwargs, device): - (input, *other_args), kwargs = args_kwargs.load(device) - - if input.numel() == 0: - pytest.skip("The input has a degenerate shape.") - - input_version = input._version - info.kernel(input, *other_args, **kwargs) - - assert input._version == input_version - - @sample_inputs - @needs_cuda - def test_cuda_vs_cpu(self, info, args_kwargs): - (input_cpu, *other_args), kwargs = args_kwargs.load("cpu") - input_cuda = input_cpu.to("cuda") - - output_cpu = info.kernel(input_cpu, *other_args, **kwargs) - output_cuda = info.kernel(input_cuda, *other_args, **kwargs) - - assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs) - - @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_dtype_and_device_consistency(self, info, args_kwargs, device): - (input, *other_args), kwargs = args_kwargs.load(device) - - output = info.kernel(input, *other_args, **kwargs) - # Most kernels just return a tensor, but some also return some additional metadata - if not isinstance(output, torch.Tensor): - output, *_ = output - - assert output.dtype == input.dtype - assert output.device == input.device - - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}") - for info in KERNEL_INFOS - for idx, args_kwargs in enumerate(info.reference_inputs_fn()) - if info.reference_fn is not None - ], - ) - def test_against_reference(self, info, args_kwargs): - args, kwargs = args_kwargs.load("cpu") - - actual = info.kernel(*args, **kwargs) - expected = info.reference_fn(*args, **kwargs) - - assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs) - - -class TestDispatchers: - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}") - for info in DISPATCHER_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs(features.Image)) - if features.Image in info.kernels - ], - ) - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_scripted_smoke(self, info, args_kwargs, device): - dispatcher = script(info.dispatcher) - - (image_feature, *other_args), kwargs = args_kwargs.load(device) - image_simple_tensor = torch.Tensor(image_feature) - - dispatcher(image_simple_tensor, *other_args, **kwargs) - - # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke` - # replaces this test for them. 
- @pytest.mark.parametrize( - "dispatcher", - [ - F.convert_color_space, - F.convert_image_dtype, - F.get_dimensions, - F.get_image_num_channels, - F.get_image_size, - F.get_spatial_size, - F.rgb_to_grayscale, - ], - ids=lambda dispatcher: dispatcher.__name__, - ) - def test_scriptable(self, dispatcher): - script(dispatcher) - - -@pytest.mark.parametrize( - ("alias", "target"), - [ - pytest.param(alias, target, id=alias.__name__) - for alias, target in [ - (F.hflip, F.horizontal_flip), - (F.vflip, F.vertical_flip), - (F.get_image_num_channels, F.get_num_channels), - (F.to_pil_image, F.to_image_pil), - (F.elastic_transform, F.elastic), - ] - ], -) -def test_alias(alias, target): - assert alias is target - - -# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in -# `prototype_transforms_kernel_infos.py` - - -def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): - rot = math.radians(angle_) - cx, cy = center_ - tx, ty = translate_ - sx, sy = [math.radians(sh_) for sh_ in shear_] - - c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) - t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) - c_matrix_inv = np.linalg.inv(c_matrix) - rs_matrix = np.array( - [ - [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0], - [scale_ * math.sin(rot), scale_ * math.cos(rot), 0], - [0, 0, 1], - ] - ) - shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) - rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) - true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) - return true_matrix - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -def test_correctness_affine_bounding_box_on_fixed_input(device): - # Check transformation against known expected output - image_size = (64, 64) - # xyxy format - in_boxes = [ - [20, 25, 35, 45], - [50, 5, 70, 22], - [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], - [1, 1, 5, 5], - ] - in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device - ) - # Tested parameters - angle = 63 - scale = 0.89 - dx = 0.12 - dy = 0.23 - - # Expected bboxes computed using albumentations: - # from albumentations.augmentations.geometric.functional import bbox_shift_scale_rotate - # from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox - # expected_bboxes = [] - # for in_box in in_boxes: - # n_in_box = normalize_bbox(in_box, *image_size) - # n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *image_size) - # out_box = denormalize_bbox(n_out_box, *image_size) - # expected_bboxes.append(out_box) - expected_bboxes = [ - (24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695), - (54.88288587110401, 50.08453280875634, 76.44484547743795, 72.81332520036864), - (27.709526487041554, 34.74952648704156, 51.650473512958435, 58.69047351295844), - (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221), - ] - - output_boxes = F.affine_bounding_box( - in_boxes, - in_boxes.format, - in_boxes.image_size, - angle, - (dx * image_size[1], dy * image_size[0]), - scale, - shear=(0, 0), - ) - - torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -def 
test_correctness_affine_segmentation_mask_on_fixed_input(device): - # Check transformation against known expected output and CPU/CUDA devices - - # Create a fixed input segmentation mask with 2 square masks - # in top-left, bottom-left corners - mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) - mask[0, 2:10, 2:10] = 1 - mask[0, 32 - 9 : 32 - 3, 3:9] = 2 - - # Rotate 90 degrees and scale - expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1)) - expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest") - expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long() - - out_mask = F.affine_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0]) - - torch.testing.assert_close(out_mask, expected_mask) - - -@pytest.mark.parametrize("angle", range(-90, 90, 56)) -@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) -def test_correctness_rotate_bounding_box(angle, expand, center): - def _compute_expected_bbox(bbox, angle_, expand_, center_): - affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) - affine_matrix = affine_matrix[:2, :] - - height, width = bbox.image_size - bbox_xyxy = convert_format_bounding_box( - bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY - ) - points = np.array( - [ - [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], - # image frame - [0.0, 0.0, 1.0], - [0.0, height, 1.0], - [width, height, 1.0], - [width, 0.0, 1.0], - ] - ) - transformed_points = np.matmul(points, affine_matrix.T) - out_bbox = [ - np.min(transformed_points[:4, 0]), - np.min(transformed_points[:4, 1]), - np.max(transformed_points[:4, 0]), - np.max(transformed_points[:4, 1]), - ] - if expand_: - tr_x = np.min(transformed_points[4:, 0]) - tr_y = np.min(transformed_points[4:, 1]) - out_bbox[0] -= tr_x - out_bbox[1] -= tr_y - out_bbox[2] -= tr_x - out_bbox[3] -= tr_y - - height = int(height - 2 * tr_y) - width = int(width - 2 * tr_x) - - out_bbox = features.BoundingBox( - out_bbox, - format=features.BoundingBoxFormat.XYXY, - image_size=(height, width), - dtype=bbox.dtype, - device=bbox.device, - ) - return ( - convert_format_bounding_box( - out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False - ), - (height, width), - ) - - image_size = (32, 38) - - for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)): - bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size - - output_bboxes, output_image_size = F.rotate_bounding_box( - bboxes, - bboxes_format, - image_size=bboxes_image_size, - angle=angle, - expand=expand, - center=center, - ) - - center_ = center - if center_ is None: - center_ = [s * 0.5 for s in bboxes_image_size[::-1]] - - if bboxes.ndim < 2: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_) - expected_bboxes.append(expected_bbox) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0) - torch.testing.assert_close(output_image_size, expected_image_size, atol=1, rtol=0) - - 
-@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 -def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): - # Check transformation against known expected output - image_size = (64, 64) - # xyxy format - in_boxes = [ - [1, 1, 5, 5], - [1, image_size[0] - 6, 5, image_size[0] - 2], - [image_size[1] - 6, image_size[0] - 6, image_size[1] - 2, image_size[0] - 2], - [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], - ] - in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device - ) - # Tested parameters - angle = 45 - center = None if expand else [12, 23] - - # # Expected bboxes computed using Detectron2: - # from detectron2.data.transforms import RotationTransform, AugmentationList - # from detectron2.data.transforms import AugInput - # import cv2 - # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32")) - # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ]) - # out = augs(inpt) - # print(inpt.boxes) - if expand: - expected_bboxes = [ - [1.65937957, 42.67157288, 7.31623382, 48.32842712], - [41.96446609, 82.9766594, 47.62132034, 88.63351365], - [82.26955262, 42.67157288, 87.92640687, 48.32842712], - [31.35786438, 31.35786438, 59.64213562, 59.64213562], - ] - else: - expected_bboxes = [ - [-11.33452378, 12.39339828, -5.67766953, 18.05025253], - [28.97056275, 52.69848481, 34.627417, 58.35533906], - [69.27564928, 12.39339828, 74.93250353, 18.05025253], - [18.36396103, 1.07968978, 46.64823228, 29.36396103], - ] - - output_boxes, _ = F.rotate_bounding_box( - in_boxes, - in_boxes.format, - in_boxes.image_size, - angle, - expand=expand, - center=center, - ) - - torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -def test_correctness_rotate_segmentation_mask_on_fixed_input(device): - # Check transformation against known expected output and CPU/CUDA devices - - # Create a fixed input segmentation mask with 2 square masks - # in top-left, bottom-left corners - mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) - mask[0, 2:10, 2:10] = 1 - mask[0, 32 - 9 : 32 - 3, 3:9] = 2 - - # Rotate 90 degrees - expected_mask = torch.rot90(mask, k=1, dims=(-2, -1)) - out_mask = F.rotate_mask(mask, 90, expand=False) - torch.testing.assert_close(out_mask, expected_mask) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize( - "format", - [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], -) -@pytest.mark.parametrize( - "top, left, height, width, expected_bboxes", - [ - [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]], - [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]], - ], -) -def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes): - - # Expected bboxes computed using Albumentations: - # import numpy as np - # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox - # expected_bboxes = [] - # for in_box in in_boxes: - # n_in_box = normalize_bbox(in_box, *size) - # n_out_box = crop_bbox_by_coords( - # n_in_box, (left, top, left + width, top + height), height, width, *size 
- # ) - # out_box = denormalize_bbox(n_out_box, height, width) - # expected_bboxes.append(out_box) - - size = (64, 76) - # xyxy format - in_boxes = [ - [10.0, 15.0, 25.0, 35.0], - [50.0, 5.0, 70.0, 22.0], - [45.0, 46.0, 56.0, 62.0], - ] - in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device) - if format != features.BoundingBoxFormat.XYXY: - in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) - - output_boxes, output_image_size = F.crop_bounding_box( - in_boxes, - format, - top, - left, - size[0], - size[1], - ) - - if format != features.BoundingBoxFormat.XYXY: - output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) - - torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - torch.testing.assert_close(output_image_size, size) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device): - mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) - mask[:, :, 0] = 1 - - out_mask = F.horizontal_flip_mask(mask) - - expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) - expected_mask[:, :, -1] = 1 - torch.testing.assert_close(out_mask, expected_mask) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): - mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) - mask[:, 0, :] = 1 - - out_mask = F.vertical_flip_mask(mask) - - expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) - expected_mask[:, -1, :] = 1 - torch.testing.assert_close(out_mask, expected_mask) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize( - "format", - [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], -) -@pytest.mark.parametrize( - "top, left, height, width, size", - [ - [0, 0, 30, 30, (60, 60)], - [-5, 5, 35, 45, (32, 34)], - ], -) -def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size): - def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): - # bbox should be xyxy - bbox[0] = (bbox[0] - left_) * size_[1] / width_ - bbox[1] = (bbox[1] - top_) * size_[0] / height_ - bbox[2] = (bbox[2] - left_) * size_[1] / width_ - bbox[3] = (bbox[3] - top_) * size_[0] / height_ - return bbox - - image_size = (100, 100) - # xyxy format - in_boxes = [ - [10.0, 10.0, 20.0, 20.0], - [5.0, 10.0, 15.0, 20.0], - ] - expected_bboxes = [] - for in_box in in_boxes: - expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size)) - expected_bboxes = torch.tensor(expected_bboxes, device=device) - - in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device - ) - if format != features.BoundingBoxFormat.XYXY: - in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) - - output_boxes, output_image_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size) - - if format != features.BoundingBoxFormat.XYXY: - output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) - - torch.testing.assert_close(output_boxes, expected_bboxes) - torch.testing.assert_close(output_image_size, size) - - -def _parse_padding(padding): - if isinstance(padding, int): - return [padding] * 4 - if 
isinstance(padding, list): - if len(padding) == 1: - return padding * 4 - if len(padding) == 2: - return padding * 2 # [left, up, right, down] - - return padding - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]]) -def test_correctness_pad_bounding_box(device, padding): - def _compute_expected_bbox(bbox, padding_): - pad_left, pad_up, _, _ = _parse_padding(padding_) - - bbox_format = bbox.format - bbox_dtype = bbox.dtype - bbox = convert_format_bounding_box(bbox, old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY) - - bbox[0::2] += pad_left - bbox[1::2] += pad_up - - bbox = convert_format_bounding_box( - bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False - ) - if bbox.dtype != bbox_dtype: - # Temporary cast to original dtype - # e.g. float32 -> int - bbox = bbox.to(bbox_dtype) - return bbox - - def _compute_expected_image_size(bbox, padding_): - pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_) - height, width = bbox.image_size - return height + pad_up + pad_down, width + pad_left + pad_right - - for bboxes in make_bounding_boxes(): - bboxes = bboxes.to(device) - bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size - - output_boxes, output_image_size = F.pad_bounding_box( - bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding - ) - - torch.testing.assert_close(output_image_size, _compute_expected_image_size(bboxes, padding)) - - if bboxes.ndim < 2 or bboxes.shape[0] == 0: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bboxes.append(_compute_expected_bbox(bbox, padding)) - - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -def test_correctness_pad_segmentation_mask_on_fixed_input(device): - mask = torch.ones((1, 3, 3), dtype=torch.long, device=device) - - out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1]) - - expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device) - expected_mask[:, 1:-1, 1:-1] = 1 - torch.testing.assert_close(out_mask, expected_mask) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize( - "startpoints, endpoints", - [ - [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]], - [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]], - [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]], - ], -) -def test_correctness_perspective_bounding_box(device, startpoints, endpoints): - def _compute_expected_bbox(bbox, pcoeffs_): - m1 = np.array( - [ - [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]], - [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]], - ] - ) - m2 = np.array( - [ - [pcoeffs_[6], pcoeffs_[7], 1.0], - [pcoeffs_[6], pcoeffs_[7], 1.0], - ] - ) - - bbox_xyxy = convert_format_bounding_box( - bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY - ) - points = np.array( - [ - [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], - ] - ) - numer = np.matmul(points, m1.T) - denom = np.matmul(points, m2.T) - transformed_points = 
numer / denom - out_bbox = [ - np.min(transformed_points[:, 0]), - np.min(transformed_points[:, 1]), - np.max(transformed_points[:, 0]), - np.max(transformed_points[:, 1]), - ] - out_bbox = features.BoundingBox( - np.array(out_bbox), - format=features.BoundingBoxFormat.XYXY, - image_size=bbox.image_size, - dtype=bbox.dtype, - device=bbox.device, - ) - return convert_format_bounding_box( - out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False - ) - - image_size = (32, 38) - - pcoeffs = _get_perspective_coeffs(startpoints, endpoints) - inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints) - - for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)): - bboxes = bboxes.to(device) - bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size - - output_bboxes = F.perspective_bounding_box( - bboxes, - bboxes_format, - perspective_coeffs=pcoeffs, - ) - - if bboxes.ndim < 2: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize( - "output_size", - [(18, 18), [18, 15], (16, 19), [12], [46, 48]], -) -def test_correctness_center_crop_bounding_box(device, output_size): - def _compute_expected_bbox(bbox, output_size_): - format_ = bbox.format - image_size_ = bbox.image_size - bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH) - - if len(output_size_) == 1: - output_size_.append(output_size_[-1]) - - cy = int(round((image_size_[0] - output_size_[0]) * 0.5)) - cx = int(round((image_size_[1] - output_size_[1]) * 0.5)) - out_bbox = [ - bbox[0].item() - cx, - bbox[1].item() - cy, - bbox[2].item(), - bbox[3].item(), - ] - out_bbox = features.BoundingBox( - out_bbox, - format=features.BoundingBoxFormat.XYWH, - image_size=output_size_, - dtype=bbox.dtype, - device=bbox.device, - ) - return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False) - - for bboxes in make_bounding_boxes(extra_dims=((4,),)): - bboxes = bboxes.to(device) - bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size - - output_boxes, output_image_size = F.center_crop_bounding_box( - bboxes, bboxes_format, bboxes_image_size, output_size - ) - - if bboxes.ndim < 2: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) - - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - torch.testing.assert_close(output_boxes, expected_bboxes) - torch.testing.assert_close(output_image_size, output_size) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]]) -def test_correctness_center_crop_mask(device, output_size): - def _compute_expected_mask(mask, output_size): - crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]] - - _, image_height, image_width = mask.shape - if crop_width > image_height or crop_height 
> image_width: - padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) - mask = F.pad_image_tensor(mask, padding, fill=0) - - left = round((image_width - crop_width) * 0.5) - top = round((image_height - crop_height) * 0.5) - - return mask[:, top : top + crop_height, left : left + crop_width] - - mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device) - actual = F.center_crop_mask(mask, output_size) - - expected = _compute_expected_mask(mask, output_size) - torch.testing.assert_close(expected, actual) - - -# Copied from test/test_functional_tensor.py -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("image_size", ("small", "large")) -@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) -@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) -def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, sigma): - fn = F.gaussian_blur_image_tensor - - # true_cv2_results = { - # # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) - # # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8) - # "3_3_0.8": ... - # # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5) - # "3_3_0.5": ... - # # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8) - # "3_5_0.8": ... - # # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5) - # "3_5_0.5": ... - # # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28)) - # # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7) - # "23_23_1.7": ... - # } - p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") - true_cv2_results = torch.load(p) - - if image_size == "small": - tensor = ( - torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) - ) - else: - tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device) - - if dt == torch.float16 and device == "cpu": - # skip float16 on CPU case - return - - if dt is not None: - tensor = tensor.to(dtype=dt) - - _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize - _sigma = sigma[0] if sigma is not None else None - shape = tensor.shape - gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}" - if gt_key not in true_cv2_results: - return - - true_out = ( - torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) - ) - - image = features.Image(tensor) - - out = fn(image, kernel_size=ksize, sigma=sigma) - torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") - - -def test_normalize_output_type(): - inpt = torch.rand(1, 3, 32, 32) - output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - assert type(output) is torch.Tensor - torch.testing.assert_close(inpt - 0.5, output) - - inpt = make_image(color_space=features.ColorSpace.RGB) - output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) - assert type(output) is torch.Tensor - torch.testing.assert_close(inpt - 0.5, output) - - -@pytest.mark.parametrize( - "inpt", - [ - 127 * np.ones((32, 32, 3), dtype="uint8"), - PIL.Image.new("RGB", (32, 32), 122), - ], -) -def test_to_image_tensor(inpt): - output = F.to_image_tensor(inpt) - assert isinstance(output, torch.Tensor) - - assert np.asarray(inpt).sum() == output.sum().item() - - if isinstance(inpt, PIL.Image.Image): - 
# we can't check this option - # as PIL -> numpy is always copying - return - - inpt[0, 0, 0] = 11 - assert output[0, 0, 0] == 11 - - -@pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8), - 127 * np.ones((32, 32, 3), dtype="uint8"), - ], -) -@pytest.mark.parametrize("mode", [None, "RGB"]) -def test_to_image_pil(inpt, mode): - output = F.to_image_pil(inpt, mode=mode) - assert isinstance(output, PIL.Image.Image) - - assert np.asarray(inpt).sum() == np.asarray(output).sum() diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py deleted file mode 100644 index 9a8ed67dde2..00000000000 --- a/test/test_prototype_transforms_utils.py +++ /dev/null @@ -1,83 +0,0 @@ -import PIL.Image -import pytest - -import torch - -from prototype_common_utils import make_bounding_box, make_detection_mask, make_image - -from torchvision.prototype import features -from torchvision.prototype.transforms._utils import has_all, has_any -from torchvision.prototype.transforms.functional import to_image_pil - - -IMAGE = make_image(color_space=features.ColorSpace.RGB) -BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size) -MASK = make_detection_mask(size=IMAGE.image_size) - - -@pytest.mark.parametrize( - ("sample", "types", "expected"), - [ - ((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), - ((MASK,), (features.Image, features.BoundingBox), False), - ((BOUNDING_BOX,), (features.Image, features.Mask), False), - ((IMAGE,), (features.BoundingBox, features.Mask), False), - ( - (IMAGE, BOUNDING_BOX, MASK), - (features.Image, features.BoundingBox, features.Mask), - True, - ), - ((), (features.Image, features.BoundingBox, features.Mask), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, features.Image),), True), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), - ((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), - ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), - ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), - ], -) -def test_has_any(sample, types, expected): - assert has_any(sample, *types) is expected - - -@pytest.mark.parametrize( - ("sample", "types", "expected"), - [ - ((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), - ( - (IMAGE, BOUNDING_BOX, MASK), - (features.Image, features.BoundingBox, features.Mask), - True, - ), - ((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), False), - ((BOUNDING_BOX, MASK), (features.Image, features.Mask), False), - ((IMAGE, MASK), (features.BoundingBox, features.Mask), False), - ( - (IMAGE, BOUNDING_BOX, MASK), - (features.Image, 
features.BoundingBox, features.Mask), - True, - ), - ((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox, features.Mask), False), - ((IMAGE, MASK), (features.Image, features.BoundingBox, features.Mask), False), - ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.Mask), False), - ( - (IMAGE, BOUNDING_BOX, MASK), - (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.Mask)),), - True, - ), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), - ], -) -def test_has_all(sample, types, expected): - assert has_all(sample, *types) is expected From a79fd8e8ac6cb33f379eda23b7a1bfedd2386c5a Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 4 Oct 2022 02:06:49 +0100 Subject: [PATCH 2/8] Remove torchvision/prototype dir --- torchvision/prototype/__init__.py | 1 - torchvision/prototype/datasets/__init__.py | 15 - torchvision/prototype/datasets/_api.py | 65 - .../prototype/datasets/_builtin/README.md | 340 ---- .../prototype/datasets/_builtin/__init__.py | 22 - .../prototype/datasets/_builtin/caltech.py | 207 --- .../datasets/_builtin/caltech101.categories | 101 -- .../datasets/_builtin/caltech256.categories | 257 --- .../prototype/datasets/_builtin/celeba.py | 195 --- .../prototype/datasets/_builtin/cifar.py | 139 -- .../datasets/_builtin/cifar10.categories | 10 - .../datasets/_builtin/cifar100.categories | 100 -- .../prototype/datasets/_builtin/clevr.py | 105 -- .../datasets/_builtin/coco.categories | 91 - .../prototype/datasets/_builtin/coco.py | 270 --- .../datasets/_builtin/country211.categories | 211 --- .../prototype/datasets/_builtin/country211.py | 81 - .../datasets/_builtin/cub200.categories | 200 --- .../prototype/datasets/_builtin/cub200.py | 258 --- .../datasets/_builtin/dtd.categories | 47 - .../prototype/datasets/_builtin/dtd.py | 139 -- .../prototype/datasets/_builtin/eurosat.py | 66 - .../prototype/datasets/_builtin/fer2013.py | 63 - .../datasets/_builtin/food101.categories | 101 -- .../prototype/datasets/_builtin/food101.py | 97 -- .../prototype/datasets/_builtin/gtsrb.py | 111 -- .../datasets/_builtin/imagenet.categories | 1000 ----------- .../prototype/datasets/_builtin/imagenet.py | 223 --- .../prototype/datasets/_builtin/mnist.py | 415 ----- .../_builtin/oxford-iiit-pet.categories | 37 - .../datasets/_builtin/oxford_iiit_pet.py | 146 -- .../prototype/datasets/_builtin/pcam.py | 126 -- .../datasets/_builtin/sbd.categories | 20 - .../prototype/datasets/_builtin/sbd.py | 153 -- .../prototype/datasets/_builtin/semeion.py | 54 - .../_builtin/stanford-cars.categories | 196 --- .../datasets/_builtin/stanford_cars.py | 116 -- .../prototype/datasets/_builtin/svhn.py | 83 - .../prototype/datasets/_builtin/usps.py | 69 - .../datasets/_builtin/voc.categories | 21 - .../prototype/datasets/_builtin/voc.py | 219 --- torchvision/prototype/datasets/_folder.py | 65 - torchvision/prototype/datasets/_home.py | 28 - torchvision/prototype/datasets/benchmark.py | 661 -------- .../datasets/generate_category_files.py | 61 - .../prototype/datasets/utils/__init__.py | 3 - .../prototype/datasets/utils/_dataset.py | 57 - .../prototype/datasets/utils/_internal.py | 196 --- .../prototype/datasets/utils/_resource.py | 236 --- torchvision/prototype/features/__init__.py | 15 - .../prototype/features/_bounding_box.py | 176 -- torchvision/prototype/features/_encoded.py | 50 - torchvision/prototype/features/_feature.py | 259 --- torchvision/prototype/features/_image.py | 295 ---- 
torchvision/prototype/features/_label.py | 74 - torchvision/prototype/features/_mask.py | 108 -- torchvision/prototype/models/__init__.py | 1 - .../prototype/models/depth/__init__.py | 1 - .../prototype/models/depth/stereo/__init__.py | 2 - .../models/depth/stereo/crestereo.py | 1460 ----------------- .../models/depth/stereo/raft_stereo.py | 750 --------- torchvision/prototype/transforms/__init__.py | 46 - torchvision/prototype/transforms/_augment.py | 381 ----- .../prototype/transforms/_auto_augment.py | 525 ------ torchvision/prototype/transforms/_color.py | 189 --- .../prototype/transforms/_container.py | 93 -- .../prototype/transforms/_deprecated.py | 88 - torchvision/prototype/transforms/_geometry.py | 889 ---------- torchvision/prototype/transforms/_meta.py | 68 - torchvision/prototype/transforms/_misc.py | 174 -- torchvision/prototype/transforms/_presets.py | 74 - .../prototype/transforms/_transform.py | 72 - .../prototype/transforms/_type_conversion.py | 71 - torchvision/prototype/transforms/_utils.py | 117 -- .../transforms/functional/__init__.py | 129 -- .../transforms/functional/_augment.py | 35 - .../prototype/transforms/functional/_color.py | 145 -- .../transforms/functional/_deprecated.py | 64 - .../transforms/functional/_geometry.py | 1314 --------------- .../prototype/transforms/functional/_meta.py | 228 --- .../prototype/transforms/functional/_misc.py | 75 - .../transforms/functional/_type_conversion.py | 45 - torchvision/prototype/utils/__init__.py | 1 - torchvision/prototype/utils/_internal.py | 126 -- 84 files changed, 15617 deletions(-) delete mode 100644 torchvision/prototype/__init__.py delete mode 100644 torchvision/prototype/datasets/__init__.py delete mode 100644 torchvision/prototype/datasets/_api.py delete mode 100644 torchvision/prototype/datasets/_builtin/README.md delete mode 100644 torchvision/prototype/datasets/_builtin/__init__.py delete mode 100644 torchvision/prototype/datasets/_builtin/caltech.py delete mode 100644 torchvision/prototype/datasets/_builtin/caltech101.categories delete mode 100644 torchvision/prototype/datasets/_builtin/caltech256.categories delete mode 100644 torchvision/prototype/datasets/_builtin/celeba.py delete mode 100644 torchvision/prototype/datasets/_builtin/cifar.py delete mode 100644 torchvision/prototype/datasets/_builtin/cifar10.categories delete mode 100644 torchvision/prototype/datasets/_builtin/cifar100.categories delete mode 100644 torchvision/prototype/datasets/_builtin/clevr.py delete mode 100644 torchvision/prototype/datasets/_builtin/coco.categories delete mode 100644 torchvision/prototype/datasets/_builtin/coco.py delete mode 100644 torchvision/prototype/datasets/_builtin/country211.categories delete mode 100644 torchvision/prototype/datasets/_builtin/country211.py delete mode 100644 torchvision/prototype/datasets/_builtin/cub200.categories delete mode 100644 torchvision/prototype/datasets/_builtin/cub200.py delete mode 100644 torchvision/prototype/datasets/_builtin/dtd.categories delete mode 100644 torchvision/prototype/datasets/_builtin/dtd.py delete mode 100644 torchvision/prototype/datasets/_builtin/eurosat.py delete mode 100644 torchvision/prototype/datasets/_builtin/fer2013.py delete mode 100644 torchvision/prototype/datasets/_builtin/food101.categories delete mode 100644 torchvision/prototype/datasets/_builtin/food101.py delete mode 100644 torchvision/prototype/datasets/_builtin/gtsrb.py delete mode 100644 torchvision/prototype/datasets/_builtin/imagenet.categories delete mode 100644 
torchvision/prototype/datasets/_builtin/imagenet.py delete mode 100644 torchvision/prototype/datasets/_builtin/mnist.py delete mode 100644 torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories delete mode 100644 torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py delete mode 100644 torchvision/prototype/datasets/_builtin/pcam.py delete mode 100644 torchvision/prototype/datasets/_builtin/sbd.categories delete mode 100644 torchvision/prototype/datasets/_builtin/sbd.py delete mode 100644 torchvision/prototype/datasets/_builtin/semeion.py delete mode 100644 torchvision/prototype/datasets/_builtin/stanford-cars.categories delete mode 100644 torchvision/prototype/datasets/_builtin/stanford_cars.py delete mode 100644 torchvision/prototype/datasets/_builtin/svhn.py delete mode 100644 torchvision/prototype/datasets/_builtin/usps.py delete mode 100644 torchvision/prototype/datasets/_builtin/voc.categories delete mode 100644 torchvision/prototype/datasets/_builtin/voc.py delete mode 100644 torchvision/prototype/datasets/_folder.py delete mode 100644 torchvision/prototype/datasets/_home.py delete mode 100644 torchvision/prototype/datasets/benchmark.py delete mode 100644 torchvision/prototype/datasets/generate_category_files.py delete mode 100644 torchvision/prototype/datasets/utils/__init__.py delete mode 100644 torchvision/prototype/datasets/utils/_dataset.py delete mode 100644 torchvision/prototype/datasets/utils/_internal.py delete mode 100644 torchvision/prototype/datasets/utils/_resource.py delete mode 100644 torchvision/prototype/features/__init__.py delete mode 100644 torchvision/prototype/features/_bounding_box.py delete mode 100644 torchvision/prototype/features/_encoded.py delete mode 100644 torchvision/prototype/features/_feature.py delete mode 100644 torchvision/prototype/features/_image.py delete mode 100644 torchvision/prototype/features/_label.py delete mode 100644 torchvision/prototype/features/_mask.py delete mode 100644 torchvision/prototype/models/__init__.py delete mode 100644 torchvision/prototype/models/depth/__init__.py delete mode 100644 torchvision/prototype/models/depth/stereo/__init__.py delete mode 100644 torchvision/prototype/models/depth/stereo/crestereo.py delete mode 100644 torchvision/prototype/models/depth/stereo/raft_stereo.py delete mode 100644 torchvision/prototype/transforms/__init__.py delete mode 100644 torchvision/prototype/transforms/_augment.py delete mode 100644 torchvision/prototype/transforms/_auto_augment.py delete mode 100644 torchvision/prototype/transforms/_color.py delete mode 100644 torchvision/prototype/transforms/_container.py delete mode 100644 torchvision/prototype/transforms/_deprecated.py delete mode 100644 torchvision/prototype/transforms/_geometry.py delete mode 100644 torchvision/prototype/transforms/_meta.py delete mode 100644 torchvision/prototype/transforms/_misc.py delete mode 100644 torchvision/prototype/transforms/_presets.py delete mode 100644 torchvision/prototype/transforms/_transform.py delete mode 100644 torchvision/prototype/transforms/_type_conversion.py delete mode 100644 torchvision/prototype/transforms/_utils.py delete mode 100644 torchvision/prototype/transforms/functional/__init__.py delete mode 100644 torchvision/prototype/transforms/functional/_augment.py delete mode 100644 torchvision/prototype/transforms/functional/_color.py delete mode 100644 torchvision/prototype/transforms/functional/_deprecated.py delete mode 100644 torchvision/prototype/transforms/functional/_geometry.py delete mode 100644 
torchvision/prototype/transforms/functional/_meta.py delete mode 100644 torchvision/prototype/transforms/functional/_misc.py delete mode 100644 torchvision/prototype/transforms/functional/_type_conversion.py delete mode 100644 torchvision/prototype/utils/__init__.py delete mode 100644 torchvision/prototype/utils/_internal.py diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py deleted file mode 100644 index bef5ecc411d..00000000000 --- a/torchvision/prototype/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import datasets, features, models, transforms, utils diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py deleted file mode 100644 index 848d9135c2f..00000000000 --- a/torchvision/prototype/datasets/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -try: - import torchdata -except ModuleNotFoundError: - raise ModuleNotFoundError( - "`torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). " - "You can install it with `pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu" - ) from None - -from . import utils -from ._home import home - -# Load this last, since some parts depend on the above being loaded first -from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip -from ._folder import from_data_folder, from_image_folder -from ._builtin import * diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py deleted file mode 100644 index f6f06c60a21..00000000000 --- a/torchvision/prototype/datasets/_api.py +++ /dev/null @@ -1,65 +0,0 @@ -import pathlib -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union - -from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset -from torchvision.prototype.utils._internal import add_suggestion - - -T = TypeVar("T") -D = TypeVar("D", bound=Type[Dataset]) - -BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} - - -def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: - def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: - BUILTIN_INFOS[name] = fn() - return fn - - return wrapper - - -BUILTIN_DATASETS = {} - - -def register_dataset(name: str) -> Callable[[D], D]: - def wrapper(dataset_cls: D) -> D: - BUILTIN_DATASETS[name] = dataset_cls - return dataset_cls - - return wrapper - - -def list_datasets() -> List[str]: - return sorted(BUILTIN_DATASETS.keys()) - - -def find(dct: Dict[str, T], name: str) -> T: - name = name.lower() - try: - return dct[name] - except KeyError as error: - raise ValueError( - add_suggestion( - f"Unknown dataset '{name}'.", - word=name, - possibilities=dct.keys(), - alternative_hint=lambda _: ( - "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." 
- ), - ) - ) from error - - -def info(name: str) -> Dict[str, Any]: - return find(BUILTIN_INFOS, name) - - -def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset: - dataset_cls = find(BUILTIN_DATASETS, name) - - if root is None: - root = pathlib.Path(home()) / name - - return dataset_cls(root, **config) diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md deleted file mode 100644 index 05d61c6870e..00000000000 --- a/torchvision/prototype/datasets/_builtin/README.md +++ /dev/null @@ -1,340 +0,0 @@ -# How to add new built-in prototype datasets - -As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means -that this document will also change a lot. - -If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented -there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out. - -Finally, `from torchvision.prototype import datasets` is implied below. - -## Implementation - -Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` -that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that -module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in -detail below: - -```python -import pathlib -from typing import Any, BinaryIO, Dict, List, Tuple, Union - -from torchdata.datapipes.iter import IterDataPipe -from torchvision.prototype.datasets.utils import Dataset, OnlineResource - -from .._api import register_dataset, register_info - -NAME = "my-dataset" - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict( - ... - ) - -@register_dataset(NAME) -class MyDataset(Dataset): - def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None: - ... - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - ... - - def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]: - ... - - def __len__(self) -> int: - ... -``` - -In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a -dictionary of static information. The most common use case is to provide human-readable categories. -[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. - -Finally, both the dataset class and the info function need to be registered on the API with the respective decorators. -With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively. - -### `__init__(self, root, *, ..., skip_integrity_check = False)` - -Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the -base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as -setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke -the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with -an underscore. 
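To make that contract concrete, here is a minimal, hypothetical sketch of such a constructor, continuing the `MyDataset` skeleton above. The `split` parameter and its allowed values are made up for this example; the validation helper `_verify_str_arg` is the one used by the existing built-in datasets:

```py
@register_dataset(NAME)
class MyDataset(Dataset):
    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",  # hypothetical keyword-only parameter with a default
        skip_integrity_check: bool = False,
    ) -> None:
        # Validate the parameter and store it as a *private* attribute before calling
        # super().__init__(), since the base class constructor already invokes the other
        # methods (such as _resources()), which may depend on it.
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))

        super().__init__(root, skip_integrity_check=skip_integrity_check)
```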
- -If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base -class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically -checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to -avoid missing dependencies at import time. - -### `_resources(self)` - -Returns a `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be -built. The download will happen automatically. - -Currently, the following `OnlineResource`'s are supported: - -- `HttpResource`: Used for files that are directly exposed through HTTP(s) and only requires the URL. -- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`. -- `ManualDownloadResource`: Used for files that are not publicly accessible and requires instructions on how to download them - manually. If the file does not exist, an error will be raised with the supplied instructions. -- `KaggleDownloadResource`: Used for files that are available on Kaggle. This inherits from `ManualDownloadResource`. - -Although optional in general, all resources used in the built-in datasets should include a -[SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the -download. You can compute the checksum with system utilities, e.g. `sha256sum`, or this snippet: - -```python -import hashlib - -def sha256sum(path, chunk_size=1024 * 1024): - checksum = hashlib.sha256() - with open(path, "rb") as f: - for chunk in iter(lambda: f.read(chunk_size), b""): - checksum.update(chunk) - print(checksum.hexdigest()) -``` - -### `_datapipe(self, resource_dps)` - -This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared -to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone -that is working with them rather than on them, `IterDataPipe`'s behave just like generators, i.e. you can't do anything -with them besides iterating. - -Of course, there are some common building blocks that should suffice in 95% of the cases. The most used are: - -- `Mapper`: Apply a callable to every item in the datapipe. -- `Filter`: Keep only items that satisfy a condition. -- `Demultiplexer`: Split a datapipe into multiple ones. -- `IterKeyZipper`: Merge two datapipes into one. - -All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable -needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated -to add one. See the MNIST or CelebA datasets for an example. - -`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return -value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain -tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one -such tuple for the file specified by the resource. - -Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and -`Grouper`. There are two issues with that: - -1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit into it completely. -2. 
This can lead to unnecessarily long warm-up times when data is buffered that is only needed later at runtime. - -Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than -trying to zip already loaded images. - -There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and -`hint_sharding`. As the name implies, they only hint at a location in the datapipe graph where shuffling and sharding -should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` -and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`. - -Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the -names (yet!). - -### `__len__` - -This returns an integer denoting the number of samples that can be drawn from the dataset. Please use -[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance -readability. For example, `1_281_167` vs. `1281167`. - -If there are only two different numbers, a simple `if` / `else` is fine: - -```py -def __len__(self): - return 12_345 if self._split == "train" else 6_789 -``` - -If there are more options, using a dictionary is usually the most readable option: - -```py -def __len__(self): - return { - "train": 3, - "val": 2, - "test": 1, - }[self._split] -``` - -If the number of samples depends on more than one parameter, you can use tuples as dictionary keys: - -```py -def __len__(self): - return { - ("train", "bar"): 4, - ("train", "baz"): 3, - ("test", "bar"): 2, - ("test", "baz"): 1, - }[(self._split, self._foo)] -``` - -The length of the datapipe is only an annotation for subsequent processing of the datapipe and is not needed during the -development process. Since it is an `@abstractmethod`, you still have to implement it from the start. The canonical way -is to define a dummy method like - -```py -def __len__(self): - return 1 -``` - -and only fill it with the correct data once the implementation is otherwise finished. -[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples. - -## Tests - -To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data. -This mock-up should resemble the original data as closely as necessary, while containing only a few examples. - -To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the -same name as you have used in `@register_info` and `@register_dataset`. This function is called the "mock data function". -Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset -will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options, -you can use the `combinations_grid()` helper function, e.g. -`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`. - -In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass -the `name` parameter to `@register_mock`: - -```py -# this is defined in torchvision/prototype/datasets/_builtin -@register_dataset("my-dataset") -class MyDataset(Dataset): - ... - -@register_mock(name="my-dataset", configs=...) -def my_dataset(root, config): - ... 
-``` - -The mock data function receives two arguments: - -- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data - needs to be placed. -- `config`: The configuration to generate the data for. This is one of the dictionaries defined in - `@register_mock(configs=...)` - -The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if -the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of -the current `config`. Although this seems odd at first, this is important. Consider the following original data setup: - -``` -root -├── test -│   ├── test_image0.jpg -│   ... -└── train - ├── train_image0.jpg - ... -``` - -For map-style datasets (like the ones currently in `torchvision.datasets`), one explicitly selects the files to -load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is -present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in -`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we only generated the data for -the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data. - -For datasets that are ported from the old API, we already have some mock data in -[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the corresponding test case there -and have a look at the `inject_fake_data` function. There are a few differences though: - -- `tmp_dir` corresponds to `root`, but is a `str` rather than a - [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like - `folder = pathlib.Path(tmp_dir)`. This is not needed. -- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the - new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files - specified in the dataset. -- As explained in the paragraph above, the generated data is often "incomplete" and only valid for the given config. - Make sure you follow the instructions above. - -The function should return an integer indicating the number of samples in the dataset for the current `config`. -Preferably, this number should be different for different `config`'s to have more confidence in the dataset -implementation. - -Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets.py -k {name}`. - -## FAQ - -### How do I start? - -Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do -`return resources_dp[0]` to get started. Then import the dataset class in -`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be -instantiable via `datasets.load("mydataset")`. In a separate script, try something like - -```py -from torchvision.prototype import datasets - -dataset = datasets.load("mydataset") -for sample in dataset: - print(sample) # this is the content of an item in the datapipe returned by _datapipe() - break -# Or you can also inspect the sample in a debugger -``` - -This will give you an idea of what the first datapipe in `resources_dp` contains. You can also do that with -`resources_dp[1]` or `resources_dp[2]` (etc.) if they exist. 
Then follow the instructions above to manipulate these -datapipes and return the appropriate dictionary format. - -### How do I handle a dataset that defines many categories? - -As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more -categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a -category. To load such a file, use the `from torchvision.prototype.datasets.utils._internal import read_categories_file` -function and pass it `$NAME`. - -In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where -each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method -should return a sequence of strings representing the category names. In the method body, you'll have to manually load -the resources, e.g. - -```py -resources = self._resources() -dp = resources[0].load(self._root) -``` - -Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes -sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that. - -To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`. - -### What if a resource file forms an I/O bottleneck? - -In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if -the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the -`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be -preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also -accepts `"decompress"` and `"extract"` to handle these common scenarios. - -### How do I compute the number of samples? - -Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way -than to iterate over the dataset and count the number of samples: - -```py -import itertools -from torchvision.prototype import datasets - - -def combinations_grid(**kwargs): - return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] - - -# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there -configs = combinations_grid(split=("train", "test"), foo=("bar", "baz")) - -for config in configs: - dataset = datasets.load("my-dataset", **config) - - num_samples = 0 - for _ in dataset: - num_samples += 1 - - print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples) -``` - -To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation -files. 
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py deleted file mode 100644 index d84e9af9fc4..00000000000 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .caltech import Caltech101, Caltech256 -from .celeba import CelebA -from .cifar import Cifar10, Cifar100 -from .clevr import CLEVR -from .coco import Coco -from .country211 import Country211 -from .cub200 import CUB200 -from .dtd import DTD -from .eurosat import EuroSAT -from .fer2013 import FER2013 -from .food101 import Food101 -from .gtsrb import GTSRB -from .imagenet import ImageNet -from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST -from .oxford_iiit_pet import OxfordIIITPet -from .pcam import PCAM -from .sbd import SBD -from .semeion import SEMEION -from .stanford_cars import StanfordCars -from .svhn import SVHN -from .usps import USPS -from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py deleted file mode 100644 index a00bf2e2cc9..00000000000 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ /dev/null @@ -1,207 +0,0 @@ -import pathlib -import re -from typing import Any, BinaryIO, Dict, List, Tuple, Union - -import numpy as np -from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - read_categories_file, - read_mat, -) -from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label - -from .._api import register_dataset, register_info - - -@register_info("caltech101") -def _caltech101_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("caltech101")) - - -@register_dataset("caltech101") -class Caltech101(Dataset): - """ - - **homepage**: https://data.caltech.edu/records/20086 - - **dependencies**: - - _ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - skip_integrity_check: bool = False, - ) -> None: - self._categories = _caltech101_info()["categories"] - - super().__init__( - root, - dependencies=("scipy",), - skip_integrity_check=skip_integrity_check, - ) - - def _resources(self) -> List[OnlineResource]: - images = GDriveResource( - "137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp", - file_name="101_ObjectCategories.tar.gz", - sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926", - preprocess="decompress", - ) - anns = GDriveResource( - "175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m", - file_name="Annotations.tar", - sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8", - ) - return [images, anns] - - _IMAGES_NAME_PATTERN = re.compile(r"image_(?P\d+)[.]jpg") - _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P\d+)[.]mat") - _ANNS_CATEGORY_MAP = { - "Faces_2": "Faces", - "Faces_3": "Faces_easy", - "Motorbikes_16": "Motorbikes", - "Airplanes_Side_2": "airplanes", - } - - def _is_not_background_image(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.parent.name != "BACKGROUND_Google" - - def _is_ann(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return bool(self._ANNS_NAME_PATTERN.match(path.name)) - - def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: - path = pathlib.Path(data[0]) - - category = path.parent.name - id = 
self._IMAGES_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr] - - return category, id - - def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: - path = pathlib.Path(data[0]) - - category = path.parent.name - if category in self._ANNS_CATEGORY_MAP: - category = self._ANNS_CATEGORY_MAP[category] - - id = self._ANNS_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr] - - return category, id - - def _prepare_sample( - self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]] - ) -> Dict[str, Any]: - key, (image_data, ann_data) = data - category, _ = key - image_path, image_buffer = image_data - ann_path, ann_buffer = ann_data - - image = EncodedImage.from_file(image_buffer) - ann = read_mat(ann_buffer) - - return dict( - label=Label.from_category(category, categories=self._categories), - image_path=image_path, - image=image, - ann_path=ann_path, - bounding_box=BoundingBox( - ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size - ), - contour=_Feature(ann["obj_contour"].T), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, anns_dp = resource_dps - - images_dp = Filter(images_dp, self._is_not_background_image) - images_dp = hint_shuffling(images_dp) - images_dp = hint_sharding(images_dp) - - anns_dp = Filter(anns_dp, self._is_ann) - - dp = IterKeyZipper( - images_dp, - anns_dp, - key_fn=self._images_key_fn, - ref_key_fn=self._anns_key_fn, - buffer_size=INFINITE_BUFFER_SIZE, - keep_key=True, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 8677 - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, self._is_not_background_image) - - return sorted({pathlib.Path(path).parent.name for path, _ in dp}) - - -@register_info("caltech256") -def _caltech256_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("caltech256")) - - -@register_dataset("caltech256") -class Caltech256(Dataset): - """ - - **homepage**: https://data.caltech.edu/records/20087 - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - skip_integrity_check: bool = False, - ) -> None: - self._categories = _caltech256_info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - GDriveResource( - "1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK", - file_name="256_ObjectCategories.tar", - sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e", - ) - ] - - def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.name != "RENAME2" - - def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: - path, buffer = data - - return dict( - path=path, - image=EncodedImage.from_file(buffer), - label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = Filter(dp, self._is_not_rogue_file) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 30607 - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dir_names = {pathlib.Path(path).parent.name for path, _ in dp} 
- - return [name.split(".")[1] for name in sorted(dir_names)] diff --git a/torchvision/prototype/datasets/_builtin/caltech101.categories b/torchvision/prototype/datasets/_builtin/caltech101.categories deleted file mode 100644 index d5c18654b4e..00000000000 --- a/torchvision/prototype/datasets/_builtin/caltech101.categories +++ /dev/null @@ -1,101 +0,0 @@ -Faces -Faces_easy -Leopards -Motorbikes -accordion -airplanes -anchor -ant -barrel -bass -beaver -binocular -bonsai -brain -brontosaurus -buddha -butterfly -camera -cannon -car_side -ceiling_fan -cellphone -chair -chandelier -cougar_body -cougar_face -crab -crayfish -crocodile -crocodile_head -cup -dalmatian -dollar_bill -dolphin -dragonfly -electric_guitar -elephant -emu -euphonium -ewer -ferry -flamingo -flamingo_head -garfield -gerenuk -gramophone -grand_piano -hawksbill -headphone -hedgehog -helicopter -ibis -inline_skate -joshua_tree -kangaroo -ketch -lamp -laptop -llama -lobster -lotus -mandolin -mayfly -menorah -metronome -minaret -nautilus -octopus -okapi -pagoda -panda -pigeon -pizza -platypus -pyramid -revolver -rhino -rooster -saxophone -schooner -scissors -scorpion -sea_horse -snoopy -soccer_ball -stapler -starfish -stegosaurus -stop_sign -strawberry -sunflower -tick -trilobite -umbrella -watch -water_lilly -wheelchair -wild_cat -windsor_chair -wrench -yin_yang diff --git a/torchvision/prototype/datasets/_builtin/caltech256.categories b/torchvision/prototype/datasets/_builtin/caltech256.categories deleted file mode 100644 index 82128efba97..00000000000 --- a/torchvision/prototype/datasets/_builtin/caltech256.categories +++ /dev/null @@ -1,257 +0,0 @@ -ak47 -american-flag -backpack -baseball-bat -baseball-glove -basketball-hoop -bat -bathtub -bear -beer-mug -billiards -binoculars -birdbath -blimp -bonsai-101 -boom-box -bowling-ball -bowling-pin -boxing-glove -brain-101 -breadmaker -buddha-101 -bulldozer -butterfly -cactus -cake -calculator -camel -cannon -canoe -car-tire -cartman -cd -centipede -cereal-box -chandelier-101 -chess-board -chimp -chopsticks -cockroach -coffee-mug -coffin -coin -comet -computer-keyboard -computer-monitor -computer-mouse -conch -cormorant -covered-wagon -cowboy-hat -crab-101 -desk-globe -diamond-ring -dice -dog -dolphin-101 -doorknob -drinking-straw -duck -dumb-bell -eiffel-tower -electric-guitar-101 -elephant-101 -elk -ewer-101 -eyeglasses -fern -fighter-jet -fire-extinguisher -fire-hydrant -fire-truck -fireworks -flashlight -floppy-disk -football-helmet -french-horn -fried-egg -frisbee -frog -frying-pan -galaxy -gas-pump -giraffe -goat -golden-gate-bridge -goldfish -golf-ball -goose -gorilla -grand-piano-101 -grapes -grasshopper -guitar-pick -hamburger -hammock -harmonica -harp -harpsichord -hawksbill-101 -head-phones -helicopter-101 -hibiscus -homer-simpson -horse -horseshoe-crab -hot-air-balloon -hot-dog -hot-tub -hourglass -house-fly -human-skeleton -hummingbird -ibis-101 -ice-cream-cone -iguana -ipod -iris -jesus-christ -joy-stick -kangaroo-101 -kayak -ketch-101 -killer-whale -knife -ladder -laptop-101 -lathe -leopards-101 -license-plate -lightbulb -light-house -lightning -llama-101 -mailbox -mandolin -mars -mattress -megaphone -menorah-101 -microscope -microwave -minaret -minotaur -motorbikes-101 -mountain-bike -mushroom -mussels -necktie -octopus -ostrich -owl -palm-pilot -palm-tree -paperclip -paper-shredder -pci-card -penguin -people -pez-dispenser -photocopier -picnic-table -playing-card -porcupine -pram -praying-mantis -pyramid -raccoon -radio-telescope -rainbow -refrigerator 
-revolver-101 -rifle -rotary-phone -roulette-wheel -saddle -saturn -school-bus -scorpion-101 -screwdriver -segway -self-propelled-lawn-mower -sextant -sheet-music -skateboard -skunk -skyscraper -smokestack -snail -snake -sneaker -snowmobile -soccer-ball -socks -soda-can -spaghetti -speed-boat -spider -spoon -stained-glass -starfish-101 -steering-wheel -stirrups -sunflower-101 -superman -sushi -swan -swiss-army-knife -sword -syringe -tambourine -teapot -teddy-bear -teepee -telephone-box -tennis-ball -tennis-court -tennis-racket -theodolite -toaster -tomato -tombstone -top-hat -touring-bike -tower-pisa -traffic-light -treadmill -triceratops -tricycle -trilobite-101 -tripod -t-shirt -tuning-fork -tweezer -umbrella-101 -unicorn -vcr -video-projector -washing-machine -watch-101 -waterfall -watermelon -welding-mask -wheelbarrow -windmill -wine-bottle -xylophone -yarmulke -yo-yo -zebra -airplanes-101 -car-side-101 -faces-easy-101 -greyhound -tennis-shoes -toad -clutter diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py deleted file mode 100644 index e42657e826e..00000000000 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ /dev/null @@ -1,195 +0,0 @@ -import csv -import pathlib -from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, -) -from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label - -from .._api import register_dataset, register_info - -csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) - - -class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): - def __init__( - self, - datapipe: IterDataPipe[Tuple[Any, BinaryIO]], - *, - fieldnames: Optional[Sequence[str]] = None, - ) -> None: - self.datapipe = datapipe - self.fieldnames = fieldnames - - def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: - for _, file in self.datapipe: - file = (line.decode() for line in file) - - if self.fieldnames: - fieldnames = self.fieldnames - else: - # The first row is skipped, because it only contains the number of samples - next(file) - - # Empty field names are filtered out, because some files have an extra white space after the header - # line, which is recognized as extra column - fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name] - # Some files do not include a label for the image ID column - if fieldnames[0] != "image_id": - fieldnames.insert(0, "image_id") - - for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"): - yield line.pop("image_id"), line - - -NAME = "celeba" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict() - - -@register_dataset(NAME) -class CelebA(Dataset): - """ - - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - splits = GDriveResource( - 
"0B7EVK8r0v71pY0NSMzRuSXJEVkk", - sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", - file_name="list_eval_partition.txt", - ) - images = GDriveResource( - "0B7EVK8r0v71pZjFTYXZWM3FlRnM", - sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74", - file_name="img_align_celeba.zip", - ) - identities = GDriveResource( - "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", - sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0", - file_name="identity_CelebA.txt", - ) - attributes = GDriveResource( - "0B7EVK8r0v71pblRyaVFSWGxPY0U", - sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0", - file_name="list_attr_celeba.txt", - ) - bounding_boxes = GDriveResource( - "0B7EVK8r0v71pbThiMVRxWXZ4dU0", - sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", - file_name="list_bbox_celeba.txt", - ) - landmarks = GDriveResource( - "0B7EVK8r0v71pd0FJY3Blby1HUTQ", - sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", - file_name="list_landmarks_align_celeba.txt", - ) - return [splits, images, identities, attributes, bounding_boxes, landmarks] - - def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool: - split_id = { - "train": "0", - "val": "1", - "test": "2", - }[self._split] - return data[1]["split_id"] == split_id - - def _prepare_sample( - self, - data: Tuple[ - Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]], - Tuple[ - Tuple[str, Dict[str, str]], - Tuple[str, Dict[str, str]], - Tuple[str, Dict[str, str]], - Tuple[str, Dict[str, str]], - ], - ], - ) -> Dict[str, Any]: - split_and_image_data, ann_data = data - _, (_, image_data) = split_and_image_data - path, buffer = image_data - - image = EncodedImage.from_file(buffer) - (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data - - return dict( - path=path, - image=image, - identity=Label(int(identity["identity"])), - attributes={attr: value == "1" for attr, value in attributes.items()}, - bounding_box=BoundingBox( - [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], - format="xywh", - image_size=image.image_size, - ), - landmarks={ - landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) - for landmark in {key[:-2] for key in landmarks.keys()} - }, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps - - splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) - splits_dp = Filter(splits_dp, self._filter_split) - splits_dp = hint_shuffling(splits_dp) - splits_dp = hint_sharding(splits_dp) - - anns_dp = Zipper( - *[ - CelebACSVParser(dp, fieldnames=fieldnames) - for dp, fieldnames in ( - (identities_dp, ("image_id", "identity")), - (attributes_dp, None), - (bounding_boxes_dp, None), - (landmarks_dp, None), - ) - ] - ) - - dp = IterKeyZipper( - splits_dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=path_accessor("name"), - buffer_size=INFINITE_BUFFER_SIZE, - keep_key=True, - ) - dp = IterKeyZipper( - dp, - anns_dp, - key_fn=getitem(0), - ref_key_fn=getitem(0, 0), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 162_770, - "val": 19_867, - "test": 19_962, - }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py deleted file mode 
100644 index 26196ded638..00000000000 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ /dev/null @@ -1,139 +0,0 @@ -import abc -import io -import pathlib -import pickle -from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union - -import numpy as np -from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - path_comparator, - read_categories_file, -) -from torchvision.prototype.features import Image, Label - -from .._api import register_dataset, register_info - - -class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): - def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: - self.datapipe = datapipe - self.labels_key = labels_key - - def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]: - for mapping in self.datapipe: - image_arrays = mapping["data"].reshape((-1, 3, 32, 32)) - category_idcs = mapping[self.labels_key] - yield from iter(zip(image_arrays, category_idcs)) - - -class _CifarBase(Dataset): - _FILE_NAME: str - _SHA256: str - _LABELS_KEY: str - _META_FILE_NAME: str - _CATEGORIES_KEY: str - _categories: List[str] - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]: - pass - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}", - sha256=self._SHA256, - ) - ] - - def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: - _, file = data - return cast(Dict[str, Any], pickle.load(file, encoding="latin1")) - - def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: - image_array, category_idx = data - return dict( - image=Image(image_array), - label=Label(category_idx, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = Filter(dp, self._is_data_file) - dp = Mapper(dp, self._unpickle) - dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 50_000 if self._split == "train" else 10_000 - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) - dp = Mapper(dp, self._unpickle) - - return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) - - -@register_info("cifar10") -def _cifar10_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("cifar10")) - - -@register_dataset("cifar10") -class Cifar10(_CifarBase): - """ - - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html - """ - - _FILE_NAME = "cifar-10-python.tar.gz" - _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" - _LABELS_KEY = "labels" - _META_FILE_NAME = "batches.meta" - _CATEGORIES_KEY = "label_names" - _categories = _cifar10_info()["categories"] - - def _is_data_file(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) 
- return path.name.startswith("data" if self._split == "train" else "test") - - -@register_info("cifar100") -def _cifar100_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("cifar100")) - - -@register_dataset("cifar100") -class Cifar100(_CifarBase): - """ - - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html - """ - - _FILE_NAME = "cifar-100-python.tar.gz" - _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" - _LABELS_KEY = "fine_labels" - _META_FILE_NAME = "meta" - _CATEGORIES_KEY = "fine_label_names" - _categories = _cifar100_info()["categories"] - - def _is_data_file(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.name == self._split diff --git a/torchvision/prototype/datasets/_builtin/cifar10.categories b/torchvision/prototype/datasets/_builtin/cifar10.categories deleted file mode 100644 index fa30c22b95d..00000000000 --- a/torchvision/prototype/datasets/_builtin/cifar10.categories +++ /dev/null @@ -1,10 +0,0 @@ -airplane -automobile -bird -cat -deer -dog -frog -horse -ship -truck diff --git a/torchvision/prototype/datasets/_builtin/cifar100.categories b/torchvision/prototype/datasets/_builtin/cifar100.categories deleted file mode 100644 index 7f7bf51d1ab..00000000000 --- a/torchvision/prototype/datasets/_builtin/cifar100.categories +++ /dev/null @@ -1,100 +0,0 @@ -apple -aquarium_fish -baby -bear -beaver -bed -bee -beetle -bicycle -bottle -bowl -boy -bridge -bus -butterfly -camel -can -castle -caterpillar -cattle -chair -chimpanzee -clock -cloud -cockroach -couch -crab -crocodile -cup -dinosaur -dolphin -elephant -flatfish -forest -fox -girl -hamster -house -kangaroo -keyboard -lamp -lawn_mower -leopard -lion -lizard -lobster -man -maple_tree -motorcycle -mountain -mouse -mushroom -oak_tree -orange -orchid -otter -palm_tree -pear -pickup_truck -pine_tree -plain -plate -poppy -porcupine -possum -rabbit -raccoon -ray -road -rocket -rose -sea -seal -shark -shrew -skunk -skyscraper -snail -snake -spider -squirrel -streetcar -sunflower -sweet_pepper -table -tank -telephone -television -tiger -tractor -train -trout -tulip -turtle -wardrobe -whale -willow_tree -wolf -woman -worm diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py deleted file mode 100644 index 4ddacdfb982..00000000000 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ /dev/null @@ -1,105 +0,0 @@ -import pathlib -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, -) -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - -NAME = "clevr" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict() - - -@register_dataset(NAME) -class CLEVR(Dataset): - """ - - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/ - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def 
_resources(self) -> List[OnlineResource]: - archive = HttpResource( - "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip", - sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1", - ) - return [archive] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.parents[1].name == "images": - return 0 - elif path.parent.name == "scenes": - return 1 - else: - return None - - def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool: - key, _ = data - return key == "scenes" - - def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]: - return data, None - - def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]: - image_data, scenes_data = data - path, buffer = image_data - - return dict( - path=path, - image=EncodedImage.from_file(buffer), - label=Label(len(scenes_data["objects"])) if scenes_data else None, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - images_dp, scenes_dp = Demultiplexer( - archive_dp, - 2, - self._classify_archive, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - images_dp = Filter(images_dp, path_comparator("parent.name", self._split)) - images_dp = hint_shuffling(images_dp) - images_dp = hint_sharding(images_dp) - - if self._split != "test": - scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json")) - scenes_dp = JsonParser(scenes_dp) - scenes_dp = Mapper(scenes_dp, getitem(1, "scenes")) - scenes_dp = UnBatcher(scenes_dp) - - dp = IterKeyZipper( - images_dp, - scenes_dp, - key_fn=path_accessor("name"), - ref_key_fn=getitem("image_filename"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - else: - dp = Mapper(images_dp, self._add_empty_anns) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 70_000 if self._split == "train" else 15_000 diff --git a/torchvision/prototype/datasets/_builtin/coco.categories b/torchvision/prototype/datasets/_builtin/coco.categories deleted file mode 100644 index 27e612f6d7d..00000000000 --- a/torchvision/prototype/datasets/_builtin/coco.categories +++ /dev/null @@ -1,91 +0,0 @@ -__background__,N/A -person,person -bicycle,vehicle -car,vehicle -motorcycle,vehicle -airplane,vehicle -bus,vehicle -train,vehicle -truck,vehicle -boat,vehicle -traffic light,outdoor -fire hydrant,outdoor -N/A,N/A -stop sign,outdoor -parking meter,outdoor -bench,outdoor -bird,animal -cat,animal -dog,animal -horse,animal -sheep,animal -cow,animal -elephant,animal -bear,animal -zebra,animal -giraffe,animal -N/A,N/A -backpack,accessory -umbrella,accessory -N/A,N/A -N/A,N/A -handbag,accessory -tie,accessory -suitcase,accessory -frisbee,sports -skis,sports -snowboard,sports -sports ball,sports -kite,sports -baseball bat,sports -baseball glove,sports -skateboard,sports -surfboard,sports -tennis racket,sports -bottle,kitchen -N/A,N/A -wine glass,kitchen -cup,kitchen -fork,kitchen -knife,kitchen -spoon,kitchen -bowl,kitchen -banana,food -apple,food -sandwich,food -orange,food -broccoli,food -carrot,food -hot dog,food -pizza,food -donut,food -cake,food -chair,furniture -couch,furniture -potted plant,furniture -bed,furniture -N/A,N/A -dining table,furniture -N/A,N/A -N/A,N/A -toilet,furniture -N/A,N/A -tv,electronic -laptop,electronic -mouse,electronic -remote,electronic -keyboard,electronic -cell phone,electronic -microwave,appliance -oven,appliance 
-toaster,appliance -sink,appliance -refrigerator,appliance -N/A,N/A -book,indoor -clock,indoor -vase,indoor -scissors,indoor -teddy bear,indoor -hair drier,indoor -toothbrush,indoor diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py deleted file mode 100644 index 16a16998bf7..00000000000 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ /dev/null @@ -1,270 +0,0 @@ -import pathlib -import re -from collections import defaultdict, OrderedDict -from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union - -import torch -from torchdata.datapipes.iter import ( - Demultiplexer, - Filter, - Grouper, - IterDataPipe, - IterKeyZipper, - JsonParser, - Mapper, - UnBatcher, -) -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - MappingIterator, - path_accessor, - read_categories_file, -) -from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label - -from .._api import register_dataset, register_info - - -NAME = "coco" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - categories, super_categories = zip(*read_categories_file(NAME)) - return dict(categories=categories, super_categories=super_categories) - - -@register_dataset(NAME) -class Coco(Dataset): - """ - - **homepage**: https://cocodataset.org/ - - **dependencies**: - - _ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - year: str = "2017", - annotations: Optional[str] = "instances", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val"}) - self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) - self._annotations = ( - self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys()) - if annotations is not None - else None - ) - - info = _info() - categories, super_categories = info["categories"], info["super_categories"] - self._categories = categories - self._category_to_super_category = dict(zip(categories, super_categories)) - - super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check) - - _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" - - _IMAGES_CHECKSUMS = { - ("2014", "train"): "ede4087e640bddba550e090eae701092534b554b42b05ac33f0300b984b31775", - ("2014", "val"): "fe9be816052049c34717e077d9e34aa60814a55679f804cd043e3cbee3b9fde0", - ("2017", "train"): "69a8bb58ea5f8f99d24875f21416de2e9ded3178e903f1f7603e283b9e06d929", - ("2017", "val"): "4f7e2ccb2866ec5041993c9cf2a952bbed69647b115d0f74da7ce8f4bef82f05", - } - - _META_URL_BASE = "http://images.cocodataset.org/annotations" - - _META_CHECKSUMS = { - "2014": "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009", - "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268", - } - - def _resources(self) -> List[OnlineResource]: - images = HttpResource( - f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip", - sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)], - ) - meta = HttpResource( - f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip", - sha256=self._META_CHECKSUMS[self._year], - ) - return [images, meta] - - def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size: Tuple[int, int]) -> torch.Tensor: - from pycocotools import mask - - if is_crowd: - 
segmentation = mask.frPyObjects(segmentation, *image_size) - else: - segmentation = mask.merge(mask.frPyObjects(segmentation, *image_size)) - - return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) - - def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: - image_size = (image_meta["height"], image_meta["width"]) - labels = [ann["category_id"] for ann in anns] - return dict( - # TODO: create a segmentation feature - segmentations=_Feature( - torch.stack( - [ - self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size) - for ann in anns - ] - ) - ), - areas=_Feature([ann["area"] for ann in anns]), - crowds=_Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool), - bounding_boxes=BoundingBox( - [ann["bbox"] for ann in anns], - format="xywh", - image_size=image_size, - ), - labels=Label(labels, categories=self._categories), - super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], - ann_ids=[ann["id"] for ann in anns], - ) - - def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: - return dict( - captions=[ann["caption"] for ann in anns], - ann_ids=[ann["id"] for ann in anns], - ) - - _ANN_DECODERS = OrderedDict( - [ - ("instances", _decode_instances_anns), - ("captions", _decode_captions_ann), - ] - ) - - _META_FILE_PATTERN = re.compile( - rf"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" - ) - - def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: - match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) - return bool( - match - and match["split"] == self._split - and match["year"] == self._year - and match["annotations"] == self._annotations - ) - - def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: - key, _ = data - if key == "images": - return 0 - elif key == "annotations": - return 1 - else: - return None - - def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: - path, buffer = data - return dict( - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _prepare_sample( - self, - data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], - ) -> Dict[str, Any]: - ann_data, image_data = data - anns, image_meta = ann_data - - sample = self._prepare_image(image_data) - # this method is only called if we have annotations - annotations = cast(str, self._annotations) - sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) - return sample - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, meta_dp = resource_dps - - if self._annotations is None: - dp = hint_shuffling(images_dp) - dp = hint_sharding(dp) - dp = hint_shuffling(dp) - return Mapper(dp, self._prepare_image) - - meta_dp = Filter(meta_dp, self._filter_meta_files) - meta_dp = JsonParser(meta_dp) - meta_dp = Mapper(meta_dp, getitem(1)) - meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp) - images_meta_dp, anns_meta_dp = Demultiplexer( - meta_dp, - 2, - self._classify_meta, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - images_meta_dp = Mapper(images_meta_dp, getitem(1)) - images_meta_dp = UnBatcher(images_meta_dp) - - anns_meta_dp = Mapper(anns_meta_dp, getitem(1)) - anns_meta_dp = UnBatcher(anns_meta_dp) - anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE) - anns_meta_dp = 
hint_shuffling(anns_meta_dp) - anns_meta_dp = hint_sharding(anns_meta_dp) - - anns_dp = IterKeyZipper( - anns_meta_dp, - images_meta_dp, - key_fn=getitem(0, "image_id"), - ref_key_fn=getitem("id"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = IterKeyZipper( - anns_dp, - images_dp, - key_fn=getitem(1, "file_name"), - ref_key_fn=path_accessor("name"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266), - ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081), - ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952), - ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137), - }[(self._split, self._year)][ - self._annotations # type: ignore[index] - ] - - def _generate_categories(self) -> Tuple[Tuple[str, str]]: - self._annotations = "instances" - resources = self._resources() - - dp = resources[1].load(self._root) - dp = Filter(dp, self._filter_meta_files) - dp = JsonParser(dp) - - _, meta = next(iter(dp)) - # List[Tuple[super_category, id, category]] - label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]] - - # COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to the - # full set. To keep the labels dense, we fill the gaps with N/A. Note that there are only 10 gaps, so the total - # number of categories is 90 rather than 91. - _, ids, _ = zip(*label_data) - missing_ids = set(range(1, max(ids) + 1)) - set(ids) - label_data.extend([("N/A", id, "N/A") for id in missing_ids]) - - # We also add a background category to be used during segmentation. - label_data.append(("N/A", 0, "__background__")) - - super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1])) - - return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories))) diff --git a/torchvision/prototype/datasets/_builtin/country211.categories b/torchvision/prototype/datasets/_builtin/country211.categories deleted file mode 100644 index 6fc3e99a185..00000000000 --- a/torchvision/prototype/datasets/_builtin/country211.categories +++ /dev/null @@ -1,211 +0,0 @@ -AD -AE -AF -AG -AI -AL -AM -AO -AQ -AR -AT -AU -AW -AX -AZ -BA -BB -BD -BE -BF -BG -BH -BJ -BM -BN -BO -BQ -BR -BS -BT -BW -BY -BZ -CA -CD -CF -CH -CI -CK -CL -CM -CN -CO -CR -CU -CV -CW -CY -CZ -DE -DK -DM -DO -DZ -EC -EE -EG -ES -ET -FI -FJ -FK -FO -FR -GA -GB -GD -GE -GF -GG -GH -GI -GL -GM -GP -GR -GS -GT -GU -GY -HK -HN -HR -HT -HU -ID -IE -IL -IM -IN -IQ -IR -IS -IT -JE -JM -JO -JP -KE -KG -KH -KN -KP -KR -KW -KY -KZ -LA -LB -LC -LI -LK -LR -LT -LU -LV -LY -MA -MC -MD -ME -MF -MG -MK -ML -MM -MN -MO -MQ -MR -MT -MU -MV -MW -MX -MY -MZ -NA -NC -NG -NI -NL -NO -NP -NZ -OM -PA -PE -PF -PG -PH -PK -PL -PR -PS -PT -PW -PY -QA -RE -RO -RS -RU -RW -SA -SB -SC -SD -SE -SG -SH -SI -SJ -SK -SL -SM -SN -SO -SS -SV -SX -SY -SZ -TG -TH -TJ -TL -TM -TN -TO -TR -TT -TW -TZ -UA -UG -US -UY -UZ -VA -VE -VG -VI -VN -VU -WS -XK -YE -ZA -ZM -ZW diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py deleted file mode 100644 index f9821ea4eb6..00000000000 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ /dev/null @@ -1,81 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Tuple, Union - -from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, 
HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - path_comparator, - read_categories_file, -) -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - -NAME = "country211" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class Country211(Dataset): - """ - - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) - self._split_folder_name = "valid" if split == "val" else split - - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - "https://openaipublic.azureedge.net/clip/data/country211.tgz", - sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c", - ) - ] - - def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: - path, buffer = data - category = pathlib.Path(path).parent.name - return dict( - label=Label.from_category(category, categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: - return pathlib.Path(data[0]).parent.parent.name == split - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name)) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 31_650, - "val": 10_550, - "test": 21_100, - }[self._split] - - def _generate_categories(self) -> List[str]: - resources = self._resources() - dp = resources[0].load(self._root) - return sorted({pathlib.Path(path).parent.name for path, _ in dp}) diff --git a/torchvision/prototype/datasets/_builtin/cub200.categories b/torchvision/prototype/datasets/_builtin/cub200.categories deleted file mode 100644 index f91754c930c..00000000000 --- a/torchvision/prototype/datasets/_builtin/cub200.categories +++ /dev/null @@ -1,200 +0,0 @@ -Black_footed_Albatross -Laysan_Albatross -Sooty_Albatross -Groove_billed_Ani -Crested_Auklet -Least_Auklet -Parakeet_Auklet -Rhinoceros_Auklet -Brewer_Blackbird -Red_winged_Blackbird -Rusty_Blackbird -Yellow_headed_Blackbird -Bobolink -Indigo_Bunting -Lazuli_Bunting -Painted_Bunting -Cardinal -Spotted_Catbird -Gray_Catbird -Yellow_breasted_Chat -Eastern_Towhee -Chuck_will_Widow -Brandt_Cormorant -Red_faced_Cormorant -Pelagic_Cormorant -Bronzed_Cowbird -Shiny_Cowbird -Brown_Creeper -American_Crow -Fish_Crow -Black_billed_Cuckoo -Mangrove_Cuckoo -Yellow_billed_Cuckoo -Gray_crowned_Rosy_Finch -Purple_Finch -Northern_Flicker -Acadian_Flycatcher -Great_Crested_Flycatcher -Least_Flycatcher -Olive_sided_Flycatcher -Scissor_tailed_Flycatcher -Vermilion_Flycatcher -Yellow_bellied_Flycatcher -Frigatebird -Northern_Fulmar -Gadwall -American_Goldfinch -European_Goldfinch -Boat_tailed_Grackle -Eared_Grebe -Horned_Grebe -Pied_billed_Grebe -Western_Grebe -Blue_Grosbeak -Evening_Grosbeak -Pine_Grosbeak -Rose_breasted_Grosbeak -Pigeon_Guillemot -California_Gull 
-Glaucous_winged_Gull -Heermann_Gull -Herring_Gull -Ivory_Gull -Ring_billed_Gull -Slaty_backed_Gull -Western_Gull -Anna_Hummingbird -Ruby_throated_Hummingbird -Rufous_Hummingbird -Green_Violetear -Long_tailed_Jaeger -Pomarine_Jaeger -Blue_Jay -Florida_Jay -Green_Jay -Dark_eyed_Junco -Tropical_Kingbird -Gray_Kingbird -Belted_Kingfisher -Green_Kingfisher -Pied_Kingfisher -Ringed_Kingfisher -White_breasted_Kingfisher -Red_legged_Kittiwake -Horned_Lark -Pacific_Loon -Mallard -Western_Meadowlark -Hooded_Merganser -Red_breasted_Merganser -Mockingbird -Nighthawk -Clark_Nutcracker -White_breasted_Nuthatch -Baltimore_Oriole -Hooded_Oriole -Orchard_Oriole -Scott_Oriole -Ovenbird -Brown_Pelican -White_Pelican -Western_Wood_Pewee -Sayornis -American_Pipit -Whip_poor_Will -Horned_Puffin -Common_Raven -White_necked_Raven -American_Redstart -Geococcyx -Loggerhead_Shrike -Great_Grey_Shrike -Baird_Sparrow -Black_throated_Sparrow -Brewer_Sparrow -Chipping_Sparrow -Clay_colored_Sparrow -House_Sparrow -Field_Sparrow -Fox_Sparrow -Grasshopper_Sparrow -Harris_Sparrow -Henslow_Sparrow -Le_Conte_Sparrow -Lincoln_Sparrow -Nelson_Sharp_tailed_Sparrow -Savannah_Sparrow -Seaside_Sparrow -Song_Sparrow -Tree_Sparrow -Vesper_Sparrow -White_crowned_Sparrow -White_throated_Sparrow -Cape_Glossy_Starling -Bank_Swallow -Barn_Swallow -Cliff_Swallow -Tree_Swallow -Scarlet_Tanager -Summer_Tanager -Artic_Tern -Black_Tern -Caspian_Tern -Common_Tern -Elegant_Tern -Forsters_Tern -Least_Tern -Green_tailed_Towhee -Brown_Thrasher -Sage_Thrasher -Black_capped_Vireo -Blue_headed_Vireo -Philadelphia_Vireo -Red_eyed_Vireo -Warbling_Vireo -White_eyed_Vireo -Yellow_throated_Vireo -Bay_breasted_Warbler -Black_and_white_Warbler -Black_throated_Blue_Warbler -Blue_winged_Warbler -Canada_Warbler -Cape_May_Warbler -Cerulean_Warbler -Chestnut_sided_Warbler -Golden_winged_Warbler -Hooded_Warbler -Kentucky_Warbler -Magnolia_Warbler -Mourning_Warbler -Myrtle_Warbler -Nashville_Warbler -Orange_crowned_Warbler -Palm_Warbler -Pine_Warbler -Prairie_Warbler -Prothonotary_Warbler -Swainson_Warbler -Tennessee_Warbler -Wilson_Warbler -Worm_eating_Warbler -Yellow_Warbler -Northern_Waterthrush -Louisiana_Waterthrush -Bohemian_Waxwing -Cedar_Waxwing -American_Three_toed_Woodpecker -Pileated_Woodpecker -Red_bellied_Woodpecker -Red_cockaded_Woodpecker -Red_headed_Woodpecker -Downy_Woodpecker -Bewick_Wren -Cactus_Wren -Carolina_Wren -House_Wren -Marsh_Wren -Rock_Wren -Winter_Wren -Common_Yellowthroat diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py deleted file mode 100644 index c07166a960c..00000000000 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ /dev/null @@ -1,258 +0,0 @@ -import csv -import functools -import pathlib -from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import ( - CSVDictParser, - CSVParser, - Demultiplexer, - Filter, - IterDataPipe, - IterKeyZipper, - LineReader, - Mapper, -) -from torchdata.datapipes.map import IterToMapConverter -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, - read_mat, -) -from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label - -from .._api import register_dataset, register_info - -csv.register_dialect("cub200", delimiter=" 
") - - -NAME = "cub200" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class CUB200(Dataset): - """ - - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - year: str = "2011", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - self._year = self._verify_str_arg(year, "year", ("2010", "2011")) - - self._categories = _info()["categories"] - - super().__init__( - root, - # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473 - # dependencies=("scipy",), - skip_integrity_check=skip_integrity_check, - ) - - def _resources(self) -> List[OnlineResource]: - if self._year == "2011": - archive = GDriveResource( - "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45", - file_name="CUB_200_2011.tgz", - sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081", - preprocess="decompress", - ) - segmentations = GDriveResource( - "1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP", - file_name="segmentations.tgz", - sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f", - preprocess="decompress", - ) - return [archive, segmentations] - else: # self._year == "2010" - split = GDriveResource( - "1vZuZPqha0JjmwkdaS_XtYryE3Jf5Q1AC", - file_name="lists.tgz", - sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428", - preprocess="decompress", - ) - images = GDriveResource( - "1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx", - file_name="images.tgz", - sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e", - preprocess="decompress", - ) - anns = GDriveResource( - "16NsbTpMs5L6hT4hUJAmpW2u7wH326WTR", - file_name="annotations.tgz", - sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1", - preprocess="decompress", - ) - return [split, images, anns] - - def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.parents[1].name == "images": - return 0 - elif path.name == "train_test_split.txt": - return 1 - elif path.name == "images.txt": - return 2 - elif path.name == "bounding_boxes.txt": - return 3 - else: - return None - - def _2011_extract_file_name(self, rel_posix_path: str) -> str: - return rel_posix_path.rsplit("/", maxsplit=1)[1] - - def _2011_filter_split(self, row: List[str]) -> bool: - _, split_id = row - return { - "0": "test", - "1": "train", - }[split_id] == self._split - - def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: - path = pathlib.Path(data[0]) - return path.with_suffix(".jpg").name - - def _2011_prepare_ann( - self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int] - ) -> Dict[str, Any]: - _, (bounding_box_data, segmentation_data) = data - segmentation_path, segmentation_buffer = segmentation_data - return dict( - bounding_box=BoundingBox( - [float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size - ), - segmentation_path=segmentation_path, - segmentation=EncodedImage.from_file(segmentation_buffer), - ) - - def _2010_split_key(self, data: str) -> str: - return data.rsplit("/", maxsplit=1)[1] - - def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]: - path = pathlib.Path(data[0]) - return path.with_suffix(".jpg").name, data - - def _2010_prepare_ann(self, data: Tuple[str, 
Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]: - _, (path, buffer) = data - content = read_mat(buffer) - return dict( - ann_path=path, - bounding_box=BoundingBox( - [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], - format="xyxy", - image_size=image_size, - ), - segmentation=_Feature(content["seg"]), - ) - - def _prepare_sample( - self, - data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any], - *, - prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]], - ) -> Dict[str, Any]: - data, anns_data = data - _, image_data = data - path, buffer = image_data - - image = EncodedImage.from_file(buffer) - - return dict( - prepare_ann_fn(anns_data, image.image_size), - image=image, - label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - prepare_ann_fn: Callable - if self._year == "2011": - archive_dp, segmentations_dp = resource_dps - images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer( - archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - image_files_dp = CSVParser(image_files_dp, dialect="cub200") - image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1) - image_files_map = IterToMapConverter(image_files_dp) - - split_dp = CSVParser(split_dp, dialect="cub200") - split_dp = Filter(split_dp, self._2011_filter_split) - split_dp = Mapper(split_dp, getitem(0)) - split_dp = Mapper(split_dp, image_files_map.__getitem__) - - bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200") - bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0) - - anns_dp = IterKeyZipper( - bounding_boxes_dp, - segmentations_dp, - key_fn=getitem(0), - ref_key_fn=self._2011_segmentation_key, - keep_key=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - prepare_ann_fn = self._2011_prepare_ann - else: # self._year == "2010" - split_dp, images_dp, anns_dp = resource_dps - - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - split_dp = LineReader(split_dp, decode=True, return_path=False) - split_dp = Mapper(split_dp, self._2010_split_key) - - anns_dp = Mapper(anns_dp, self._2010_anns_key) - - prepare_ann_fn = self._2010_prepare_ann - - split_dp = hint_shuffling(split_dp) - split_dp = hint_sharding(split_dp) - - dp = IterKeyZipper( - split_dp, - images_dp, - getitem(), - path_accessor("name"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = IterKeyZipper( - dp, - anns_dp, - getitem(0), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn)) - - def __len__(self) -> int: - return { - ("train", "2010"): 3_000, - ("test", "2010"): 3_033, - ("train", "2011"): 5_994, - ("test", "2011"): 5_794, - }[(self._split, self._year)] - - def _generate_categories(self) -> List[str]: - self._year = "2011" - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", "classes.txt")) - dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200") - - return [row["category"].split(".")[1] for row in dp] diff --git a/torchvision/prototype/datasets/_builtin/dtd.categories b/torchvision/prototype/datasets/_builtin/dtd.categories deleted file mode 100644 index 7f3df8a2b00..00000000000 --- a/torchvision/prototype/datasets/_builtin/dtd.categories +++ /dev/null @@ 
-1,47 +0,0 @@ -banded -blotchy -braided -bubbly -bumpy -chequered -cobwebbed -cracked -crosshatched -crystalline -dotted -fibrous -flecked -freckled -frilly -gauzy -grid -grooved -honeycombed -interlaced -knitted -lacelike -lined -marbled -matted -meshed -paisley -perforated -pitted -pleated -polka-dotted -porous -potholed -scaly -smeared -spiralled -sprinkled -stained -stratified -striped -studded -swirly -veined -waffled -woven -wrinkled -zigzagged diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py deleted file mode 100644 index e7ff1e79559..00000000000 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ /dev/null @@ -1,139 +0,0 @@ -import enum -import pathlib -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_comparator, - read_categories_file, -) -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - - -NAME = "dtd" - - -class DTDDemux(enum.IntEnum): - SPLIT = 0 - JOINT_CATEGORIES = 1 - IMAGES = 2 - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class DTD(Dataset): - """DTD Dataset. - homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - fold: int = 1, - skip_validation_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - - if not (1 <= fold <= 10): - raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}") - self._fold = fold - - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_validation_check) - - def _resources(self) -> List[OnlineResource]: - archive = HttpResource( - "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", - sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", - preprocess="decompress", - ) - return [archive] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.parent.name == "labels": - if path.name == "labels_joint_anno.txt": - return DTDDemux.JOINT_CATEGORIES - - return DTDDemux.SPLIT - elif path.parents[1].name == "images": - return DTDDemux.IMAGES - else: - return None - - def _image_key_fn(self, data: Tuple[str, Any]) -> str: - path = pathlib.Path(data[0]) - # The split files contain hardcoded posix paths for the images, e.g. 
banded/banded_0001.jpg - return str(path.relative_to(path.parents[1]).as_posix()) - - def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: - (_, joint_categories_data), image_data = data - _, *joint_categories = joint_categories_data - path, buffer = image_data - - category = pathlib.Path(path).parent.name - - return dict( - joint_categories={category for category in joint_categories if category}, - label=Label.from_category(category, categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - - splits_dp, joint_categories_dp, images_dp = Demultiplexer( - archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt")) - splits_dp = LineReader(splits_dp, decode=True, return_path=False) - splits_dp = hint_shuffling(splits_dp) - splits_dp = hint_sharding(splits_dp) - - joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ") - - dp = IterKeyZipper( - splits_dp, - joint_categories_dp, - key_fn=getitem(), - ref_key_fn=getitem(0), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = IterKeyZipper( - dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=self._image_key_fn, - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def _filter_images(self, data: Tuple[str, Any]) -> bool: - return self._classify_archive(data) == DTDDemux.IMAGES - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, self._filter_images) - - return sorted({pathlib.Path(path).parent.name for path, _ in dp}) - - def __len__(self) -> int: - return 1_880 # All splits have the same length diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py deleted file mode 100644 index 88863dbcb3a..00000000000 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ /dev/null @@ -1,66 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Tuple, Union - -from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - -NAME = "eurosat" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict( - categories=( - "AnnualCrop", - "Forest", - "HerbaceousVegetation", - "Highway", - "Industrial", - "Pasture", - "PermanentCrop", - "Residential", - "River", - "SeaLake", - ) - ) - - -@register_dataset(NAME) -class EuroSAT(Dataset): - """EuroSAT Dataset. 
- homepage="https://github.com/phelber/eurosat", - """ - - def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - "https://madm.dfki.de/files/sentinel/EuroSAT.zip", - sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd", - ) - ] - - def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: - path, buffer = data - category = pathlib.Path(path).parent.name - return dict( - label=Label.from_category(category, categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 27_000 diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py deleted file mode 100644 index b2693aa96c0..00000000000 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ /dev/null @@ -1,63 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Union - -import torch -from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Image, Label - -from .._api import register_dataset, register_info - -NAME = "fer2013" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral")) - - -@register_dataset(NAME) -class FER2013(Dataset): - """FER 2013 Dataset - homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _CHECKSUMS = { - "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10", - "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3", - } - - def _resources(self) -> List[OnlineResource]: - archive = KaggleDownloadResource( - "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", - file_name=f"{self._split}.csv.zip", - sha256=self._CHECKSUMS[self._split], - ) - return [archive] - - def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: - label_id = data.get("emotion") - - return dict( - image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)), - label=Label(int(label_id), categories=self._categories) if label_id is not None else None, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = CSVDictParser(dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 28_709 if self._split == "train" else 3_589 diff --git 
a/torchvision/prototype/datasets/_builtin/food101.categories b/torchvision/prototype/datasets/_builtin/food101.categories deleted file mode 100644 index 59f252ddff4..00000000000 --- a/torchvision/prototype/datasets/_builtin/food101.categories +++ /dev/null @@ -1,101 +0,0 @@ -apple_pie -baby_back_ribs -baklava -beef_carpaccio -beef_tartare -beet_salad -beignets -bibimbap -bread_pudding -breakfast_burrito -bruschetta -caesar_salad -cannoli -caprese_salad -carrot_cake -ceviche -cheesecake -cheese_plate -chicken_curry -chicken_quesadilla -chicken_wings -chocolate_cake -chocolate_mousse -churros -clam_chowder -club_sandwich -crab_cakes -creme_brulee -croque_madame -cup_cakes -deviled_eggs -donuts -dumplings -edamame -eggs_benedict -escargots -falafel -filet_mignon -fish_and_chips -foie_gras -french_fries -french_onion_soup -french_toast -fried_calamari -fried_rice -frozen_yogurt -garlic_bread -gnocchi -greek_salad -grilled_cheese_sandwich -grilled_salmon -guacamole -gyoza -hamburger -hot_and_sour_soup -hot_dog -huevos_rancheros -hummus -ice_cream -lasagna -lobster_bisque -lobster_roll_sandwich -macaroni_and_cheese -macarons -miso_soup -mussels -nachos -omelette -onion_rings -oysters -pad_thai -paella -pancakes -panna_cotta -peking_duck -pho -pizza -pork_chop -poutine -prime_rib -pulled_pork_sandwich -ramen -ravioli -red_velvet_cake -risotto -samosa -sashimi -scallops -seaweed_salad -shrimp_and_grits -spaghetti_bolognese -spaghetti_carbonara -spring_rolls -steak -strawberry_shortcake -sushi -tacos -takoyaki -tiramisu -tuna_tartare -waffles diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py deleted file mode 100644 index 3657116ae7a..00000000000 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ /dev/null @@ -1,97 +0,0 @@ -from pathlib import Path -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_comparator, - read_categories_file, -) -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - - -NAME = "food101" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class Food101(Dataset): - """Food 101 dataset - homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", - """ - - def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz", - sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4", - preprocess="decompress", - ) - ] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = Path(data[0]) - if path.parents[1].name == "images": - return 0 - elif path.parents[0].name == "meta": - return 1 - else: - return None - - def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, 
Any]: - id, (path, buffer) = data - return dict( - label=Label.from_category(id.split("/", 1)[0], categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _image_key(self, data: Tuple[str, Any]) -> str: - path = Path(data[0]) - return path.relative_to(path.parents[1]).with_suffix("").as_posix() - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - images_dp, split_dp = Demultiplexer( - archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - split_dp = LineReader(split_dp, decode=True, return_path=False) - split_dp = hint_sharding(split_dp) - split_dp = hint_shuffling(split_dp) - - dp = IterKeyZipper( - split_dp, - images_dp, - key_fn=getitem(), - ref_key_fn=self._image_key, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - return Mapper(dp, self._prepare_sample) - - def _generate_categories(self) -> List[str]: - resources = self._resources() - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", "classes.txt")) - dp = LineReader(dp, decode=True, return_path=False) - return list(dp) - - def __len__(self) -> int: - return 75_750 if self._split == "train" else 25_250 diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py deleted file mode 100644 index 8dc0a8240c8..00000000000 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ /dev/null @@ -1,111 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_comparator, -) -from torchvision.prototype.features import BoundingBox, EncodedImage, Label - -from .._api import register_dataset, register_info - -NAME = "gtsrb" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict( - categories=[f"{label:05d}" for label in range(43)], - ) - - -@register_dataset(NAME) -class GTSRB(Dataset): - """GTSRB Dataset - - homepage="https://benchmark.ini.rub.de" - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" - _URLS = { - "train": f"{_URL_ROOT}GTSRB-Training_fixed.zip", - "test": f"{_URL_ROOT}GTSRB_Final_Test_Images.zip", - "test_ground_truth": f"{_URL_ROOT}GTSRB_Final_Test_GT.zip", - } - _CHECKSUMS = { - "train": "df4144942083645bd60b594de348aa6930126c3e0e5de09e39611630abf8455a", - "test": "48ba6fab7e877eb64eaf8de99035b0aaecfbc279bee23e35deca4ac1d0a837fa", - "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d", - } - - def _resources(self) -> List[OnlineResource]: - rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])] - - if self._split == "test": - rsrcs.append( - HttpResource( - self._URLS["test_ground_truth"], - sha256=self._CHECKSUMS["test_ground_truth"], - ) - 
) - - return rsrcs - - def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.suffix == ".ppm": - return 0 - elif path.suffix == ".csv": - return 1 - else: - return None - - def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]: - (path, buffer), csv_info = data - label = int(csv_info["ClassId"]) - - bounding_box = BoundingBox( - [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")], - format="xyxy", - image_size=(int(csv_info["Height"]), int(csv_info["Width"])), - ) - - return { - "path": path, - "image": EncodedImage.from_file(buffer), - "label": Label(label, categories=self._categories), - "bounding_box": bounding_box, - } - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - if self._split == "train": - images_dp, ann_dp = Demultiplexer( - resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - else: - images_dp, ann_dp = resource_dps - images_dp = Filter(images_dp, path_comparator("suffix", ".ppm")) - - # The order of the image files in the .zip archives perfectly match the order of the entries in the - # (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper. - ann_dp = CSVDictParser(ann_dp, delimiter=";") - dp = Zipper(images_dp, ann_dp) - - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 26_640 if self._split == "train" else 12_630 diff --git a/torchvision/prototype/datasets/_builtin/imagenet.categories b/torchvision/prototype/datasets/_builtin/imagenet.categories deleted file mode 100644 index 7b6006ff57f..00000000000 --- a/torchvision/prototype/datasets/_builtin/imagenet.categories +++ /dev/null @@ -1,1000 +0,0 @@ -tench,n01440764 -goldfish,n01443537 -great white shark,n01484850 -tiger shark,n01491361 -hammerhead,n01494475 -electric ray,n01496331 -stingray,n01498041 -cock,n01514668 -hen,n01514859 -ostrich,n01518878 -brambling,n01530575 -goldfinch,n01531178 -house finch,n01532829 -junco,n01534433 -indigo bunting,n01537544 -robin,n01558993 -bulbul,n01560419 -jay,n01580077 -magpie,n01582220 -chickadee,n01592084 -water ouzel,n01601694 -kite,n01608432 -bald eagle,n01614925 -vulture,n01616318 -great grey owl,n01622779 -European fire salamander,n01629819 -common newt,n01630670 -eft,n01631663 -spotted salamander,n01632458 -axolotl,n01632777 -bullfrog,n01641577 -tree frog,n01644373 -tailed frog,n01644900 -loggerhead,n01664065 -leatherback turtle,n01665541 -mud turtle,n01667114 -terrapin,n01667778 -box turtle,n01669191 -banded gecko,n01675722 -common iguana,n01677366 -American chameleon,n01682714 -whiptail,n01685808 -agama,n01687978 -frilled lizard,n01688243 -alligator lizard,n01689811 -Gila monster,n01692333 -green lizard,n01693334 -African chameleon,n01694178 -Komodo dragon,n01695060 -African crocodile,n01697457 -American alligator,n01698640 -triceratops,n01704323 -thunder snake,n01728572 -ringneck snake,n01728920 -hognose snake,n01729322 -green snake,n01729977 -king snake,n01734418 -garter snake,n01735189 -water snake,n01737021 -vine snake,n01739381 -night snake,n01740131 -boa constrictor,n01742172 -rock python,n01744401 -Indian cobra,n01748264 -green mamba,n01749939 -sea snake,n01751748 -horned viper,n01753488 -diamondback,n01755581 -sidewinder,n01756291 -trilobite,n01768244 -harvestman,n01770081 -scorpion,n01770393 -black and gold garden spider,n01773157 -barn 
spider,n01773549 -garden spider,n01773797 -black widow,n01774384 -tarantula,n01774750 -wolf spider,n01775062 -tick,n01776313 -centipede,n01784675 -black grouse,n01795545 -ptarmigan,n01796340 -ruffed grouse,n01797886 -prairie chicken,n01798484 -peacock,n01806143 -quail,n01806567 -partridge,n01807496 -African grey,n01817953 -macaw,n01818515 -sulphur-crested cockatoo,n01819313 -lorikeet,n01820546 -coucal,n01824575 -bee eater,n01828970 -hornbill,n01829413 -hummingbird,n01833805 -jacamar,n01843065 -toucan,n01843383 -drake,n01847000 -red-breasted merganser,n01855032 -goose,n01855672 -black swan,n01860187 -tusker,n01871265 -echidna,n01872401 -platypus,n01873310 -wallaby,n01877812 -koala,n01882714 -wombat,n01883070 -jellyfish,n01910747 -sea anemone,n01914609 -brain coral,n01917289 -flatworm,n01924916 -nematode,n01930112 -conch,n01943899 -snail,n01944390 -slug,n01945685 -sea slug,n01950731 -chiton,n01955084 -chambered nautilus,n01968897 -Dungeness crab,n01978287 -rock crab,n01978455 -fiddler crab,n01980166 -king crab,n01981276 -American lobster,n01983481 -spiny lobster,n01984695 -crayfish,n01985128 -hermit crab,n01986214 -isopod,n01990800 -white stork,n02002556 -black stork,n02002724 -spoonbill,n02006656 -flamingo,n02007558 -little blue heron,n02009229 -American egret,n02009912 -bittern,n02011460 -crane,n02012849 -limpkin,n02013706 -European gallinule,n02017213 -American coot,n02018207 -bustard,n02018795 -ruddy turnstone,n02025239 -red-backed sandpiper,n02027492 -redshank,n02028035 -dowitcher,n02033041 -oystercatcher,n02037110 -pelican,n02051845 -king penguin,n02056570 -albatross,n02058221 -grey whale,n02066245 -killer whale,n02071294 -dugong,n02074367 -sea lion,n02077923 -Chihuahua,n02085620 -Japanese spaniel,n02085782 -Maltese dog,n02085936 -Pekinese,n02086079 -Shih-Tzu,n02086240 -Blenheim spaniel,n02086646 -papillon,n02086910 -toy terrier,n02087046 -Rhodesian ridgeback,n02087394 -Afghan hound,n02088094 -basset,n02088238 -beagle,n02088364 -bloodhound,n02088466 -bluetick,n02088632 -black-and-tan coonhound,n02089078 -Walker hound,n02089867 -English foxhound,n02089973 -redbone,n02090379 -borzoi,n02090622 -Irish wolfhound,n02090721 -Italian greyhound,n02091032 -whippet,n02091134 -Ibizan hound,n02091244 -Norwegian elkhound,n02091467 -otterhound,n02091635 -Saluki,n02091831 -Scottish deerhound,n02092002 -Weimaraner,n02092339 -Staffordshire bullterrier,n02093256 -American Staffordshire terrier,n02093428 -Bedlington terrier,n02093647 -Border terrier,n02093754 -Kerry blue terrier,n02093859 -Irish terrier,n02093991 -Norfolk terrier,n02094114 -Norwich terrier,n02094258 -Yorkshire terrier,n02094433 -wire-haired fox terrier,n02095314 -Lakeland terrier,n02095570 -Sealyham terrier,n02095889 -Airedale,n02096051 -cairn,n02096177 -Australian terrier,n02096294 -Dandie Dinmont,n02096437 -Boston bull,n02096585 -miniature schnauzer,n02097047 -giant schnauzer,n02097130 -standard schnauzer,n02097209 -Scotch terrier,n02097298 -Tibetan terrier,n02097474 -silky terrier,n02097658 -soft-coated wheaten terrier,n02098105 -West Highland white terrier,n02098286 -Lhasa,n02098413 -flat-coated retriever,n02099267 -curly-coated retriever,n02099429 -golden retriever,n02099601 -Labrador retriever,n02099712 -Chesapeake Bay retriever,n02099849 -German short-haired pointer,n02100236 -vizsla,n02100583 -English setter,n02100735 -Irish setter,n02100877 -Gordon setter,n02101006 -Brittany spaniel,n02101388 -clumber,n02101556 -English springer,n02102040 -Welsh springer spaniel,n02102177 -cocker spaniel,n02102318 -Sussex spaniel,n02102480 -Irish 
water spaniel,n02102973 -kuvasz,n02104029 -schipperke,n02104365 -groenendael,n02105056 -malinois,n02105162 -briard,n02105251 -kelpie,n02105412 -komondor,n02105505 -Old English sheepdog,n02105641 -Shetland sheepdog,n02105855 -collie,n02106030 -Border collie,n02106166 -Bouvier des Flandres,n02106382 -Rottweiler,n02106550 -German shepherd,n02106662 -Doberman,n02107142 -miniature pinscher,n02107312 -Greater Swiss Mountain dog,n02107574 -Bernese mountain dog,n02107683 -Appenzeller,n02107908 -EntleBucher,n02108000 -boxer,n02108089 -bull mastiff,n02108422 -Tibetan mastiff,n02108551 -French bulldog,n02108915 -Great Dane,n02109047 -Saint Bernard,n02109525 -Eskimo dog,n02109961 -malamute,n02110063 -Siberian husky,n02110185 -dalmatian,n02110341 -affenpinscher,n02110627 -basenji,n02110806 -pug,n02110958 -Leonberg,n02111129 -Newfoundland,n02111277 -Great Pyrenees,n02111500 -Samoyed,n02111889 -Pomeranian,n02112018 -chow,n02112137 -keeshond,n02112350 -Brabancon griffon,n02112706 -Pembroke,n02113023 -Cardigan,n02113186 -toy poodle,n02113624 -miniature poodle,n02113712 -standard poodle,n02113799 -Mexican hairless,n02113978 -timber wolf,n02114367 -white wolf,n02114548 -red wolf,n02114712 -coyote,n02114855 -dingo,n02115641 -dhole,n02115913 -African hunting dog,n02116738 -hyena,n02117135 -red fox,n02119022 -kit fox,n02119789 -Arctic fox,n02120079 -grey fox,n02120505 -tabby,n02123045 -tiger cat,n02123159 -Persian cat,n02123394 -Siamese cat,n02123597 -Egyptian cat,n02124075 -cougar,n02125311 -lynx,n02127052 -leopard,n02128385 -snow leopard,n02128757 -jaguar,n02128925 -lion,n02129165 -tiger,n02129604 -cheetah,n02130308 -brown bear,n02132136 -American black bear,n02133161 -ice bear,n02134084 -sloth bear,n02134418 -mongoose,n02137549 -meerkat,n02138441 -tiger beetle,n02165105 -ladybug,n02165456 -ground beetle,n02167151 -long-horned beetle,n02168699 -leaf beetle,n02169497 -dung beetle,n02172182 -rhinoceros beetle,n02174001 -weevil,n02177972 -fly,n02190166 -bee,n02206856 -ant,n02219486 -grasshopper,n02226429 -cricket,n02229544 -walking stick,n02231487 -cockroach,n02233338 -mantis,n02236044 -cicada,n02256656 -leafhopper,n02259212 -lacewing,n02264363 -dragonfly,n02268443 -damselfly,n02268853 -admiral,n02276258 -ringlet,n02277742 -monarch,n02279972 -cabbage butterfly,n02280649 -sulphur butterfly,n02281406 -lycaenid,n02281787 -starfish,n02317335 -sea urchin,n02319095 -sea cucumber,n02321529 -wood rabbit,n02325366 -hare,n02326432 -Angora,n02328150 -hamster,n02342885 -porcupine,n02346627 -fox squirrel,n02356798 -marmot,n02361337 -beaver,n02363005 -guinea pig,n02364673 -sorrel,n02389026 -zebra,n02391049 -hog,n02395406 -wild boar,n02396427 -warthog,n02397096 -hippopotamus,n02398521 -ox,n02403003 -water buffalo,n02408429 -bison,n02410509 -ram,n02412080 -bighorn,n02415577 -ibex,n02417914 -hartebeest,n02422106 -impala,n02422699 -gazelle,n02423022 -Arabian camel,n02437312 -llama,n02437616 -weasel,n02441942 -mink,n02442845 -polecat,n02443114 -black-footed ferret,n02443484 -otter,n02444819 -skunk,n02445715 -badger,n02447366 -armadillo,n02454379 -three-toed sloth,n02457408 -orangutan,n02480495 -gorilla,n02480855 -chimpanzee,n02481823 -gibbon,n02483362 -siamang,n02483708 -guenon,n02484975 -patas,n02486261 -baboon,n02486410 -macaque,n02487347 -langur,n02488291 -colobus,n02488702 -proboscis monkey,n02489166 -marmoset,n02490219 -capuchin,n02492035 -howler monkey,n02492660 -titi,n02493509 -spider monkey,n02493793 -squirrel monkey,n02494079 -Madagascar cat,n02497673 -indri,n02500267 -Indian elephant,n02504013 -African elephant,n02504458 
-lesser panda,n02509815 -giant panda,n02510455 -barracouta,n02514041 -eel,n02526121 -coho,n02536864 -rock beauty,n02606052 -anemone fish,n02607072 -sturgeon,n02640242 -gar,n02641379 -lionfish,n02643566 -puffer,n02655020 -abacus,n02666196 -abaya,n02667093 -academic gown,n02669723 -accordion,n02672831 -acoustic guitar,n02676566 -aircraft carrier,n02687172 -airliner,n02690373 -airship,n02692877 -altar,n02699494 -ambulance,n02701002 -amphibian,n02704792 -analog clock,n02708093 -apiary,n02727426 -apron,n02730930 -ashcan,n02747177 -assault rifle,n02749479 -backpack,n02769748 -bakery,n02776631 -balance beam,n02777292 -balloon,n02782093 -ballpoint,n02783161 -Band Aid,n02786058 -banjo,n02787622 -bannister,n02788148 -barbell,n02790996 -barber chair,n02791124 -barbershop,n02791270 -barn,n02793495 -barometer,n02794156 -barrel,n02795169 -barrow,n02797295 -baseball,n02799071 -basketball,n02802426 -bassinet,n02804414 -bassoon,n02804610 -bathing cap,n02807133 -bath towel,n02808304 -bathtub,n02808440 -beach wagon,n02814533 -beacon,n02814860 -beaker,n02815834 -bearskin,n02817516 -beer bottle,n02823428 -beer glass,n02823750 -bell cote,n02825657 -bib,n02834397 -bicycle-built-for-two,n02835271 -bikini,n02837789 -binder,n02840245 -binoculars,n02841315 -birdhouse,n02843684 -boathouse,n02859443 -bobsled,n02860847 -bolo tie,n02865351 -bonnet,n02869837 -bookcase,n02870880 -bookshop,n02871525 -bottlecap,n02877765 -bow,n02879718 -bow tie,n02883205 -brass,n02892201 -brassiere,n02892767 -breakwater,n02894605 -breastplate,n02895154 -broom,n02906734 -bucket,n02909870 -buckle,n02910353 -bulletproof vest,n02916936 -bullet train,n02917067 -butcher shop,n02927161 -cab,n02930766 -caldron,n02939185 -candle,n02948072 -cannon,n02950826 -canoe,n02951358 -can opener,n02951585 -cardigan,n02963159 -car mirror,n02965783 -carousel,n02966193 -carpenter's kit,n02966687 -carton,n02971356 -car wheel,n02974003 -cash machine,n02977058 -cassette,n02978881 -cassette player,n02979186 -castle,n02980441 -catamaran,n02981792 -CD player,n02988304 -cello,n02992211 -cellular telephone,n02992529 -chain,n02999410 -chainlink fence,n03000134 -chain mail,n03000247 -chain saw,n03000684 -chest,n03014705 -chiffonier,n03016953 -chime,n03017168 -china cabinet,n03018349 -Christmas stocking,n03026506 -church,n03028079 -cinema,n03032252 -cleaver,n03041632 -cliff dwelling,n03042490 -cloak,n03045698 -clog,n03047690 -cocktail shaker,n03062245 -coffee mug,n03063599 -coffeepot,n03063689 -coil,n03065424 -combination lock,n03075370 -computer keyboard,n03085013 -confectionery,n03089624 -container ship,n03095699 -convertible,n03100240 -corkscrew,n03109150 -cornet,n03110669 -cowboy boot,n03124043 -cowboy hat,n03124170 -cradle,n03125729 -construction crane,n03126707 -crash helmet,n03127747 -crate,n03127925 -crib,n03131574 -Crock Pot,n03133878 -croquet ball,n03134739 -crutch,n03141823 -cuirass,n03146219 -dam,n03160309 -desk,n03179701 -desktop computer,n03180011 -dial telephone,n03187595 -diaper,n03188531 -digital clock,n03196217 -digital watch,n03197337 -dining table,n03201208 -dishrag,n03207743 -dishwasher,n03207941 -disk brake,n03208938 -dock,n03216828 -dogsled,n03218198 -dome,n03220513 -doormat,n03223299 -drilling platform,n03240683 -drum,n03249569 -drumstick,n03250847 -dumbbell,n03255030 -Dutch oven,n03259280 -electric fan,n03271574 -electric guitar,n03272010 -electric locomotive,n03272562 -entertainment center,n03290653 -envelope,n03291819 -espresso maker,n03297495 -face powder,n03314780 -feather boa,n03325584 -file,n03337140 -fireboat,n03344393 -fire engine,n03345487 
-fire screen,n03347037 -flagpole,n03355925 -flute,n03372029 -folding chair,n03376595 -football helmet,n03379051 -forklift,n03384352 -fountain,n03388043 -fountain pen,n03388183 -four-poster,n03388549 -freight car,n03393912 -French horn,n03394916 -frying pan,n03400231 -fur coat,n03404251 -garbage truck,n03417042 -gasmask,n03424325 -gas pump,n03425413 -goblet,n03443371 -go-kart,n03444034 -golf ball,n03445777 -golfcart,n03445924 -gondola,n03447447 -gong,n03447721 -gown,n03450230 -grand piano,n03452741 -greenhouse,n03457902 -grille,n03459775 -grocery store,n03461385 -guillotine,n03467068 -hair slide,n03476684 -hair spray,n03476991 -half track,n03478589 -hammer,n03481172 -hamper,n03482405 -hand blower,n03483316 -hand-held computer,n03485407 -handkerchief,n03485794 -hard disc,n03492542 -harmonica,n03494278 -harp,n03495258 -harvester,n03496892 -hatchet,n03498962 -holster,n03527444 -home theater,n03529860 -honeycomb,n03530642 -hook,n03532672 -hoopskirt,n03534580 -horizontal bar,n03535780 -horse cart,n03538406 -hourglass,n03544143 -iPod,n03584254 -iron,n03584829 -jack-o'-lantern,n03590841 -jean,n03594734 -jeep,n03594945 -jersey,n03595614 -jigsaw puzzle,n03598930 -jinrikisha,n03599486 -joystick,n03602883 -kimono,n03617480 -knee pad,n03623198 -knot,n03627232 -lab coat,n03630383 -ladle,n03633091 -lampshade,n03637318 -laptop,n03642806 -lawn mower,n03649909 -lens cap,n03657121 -letter opener,n03658185 -library,n03661043 -lifeboat,n03662601 -lighter,n03666591 -limousine,n03670208 -liner,n03673027 -lipstick,n03676483 -Loafer,n03680355 -lotion,n03690938 -loudspeaker,n03691459 -loupe,n03692522 -lumbermill,n03697007 -magnetic compass,n03706229 -mailbag,n03709823 -mailbox,n03710193 -maillot,n03710637 -tank suit,n03710721 -manhole cover,n03717622 -maraca,n03720891 -marimba,n03721384 -mask,n03724870 -matchstick,n03729826 -maypole,n03733131 -maze,n03733281 -measuring cup,n03733805 -medicine chest,n03742115 -megalith,n03743016 -microphone,n03759954 -microwave,n03761084 -military uniform,n03763968 -milk can,n03764736 -minibus,n03769881 -miniskirt,n03770439 -minivan,n03770679 -missile,n03773504 -mitten,n03775071 -mixing bowl,n03775546 -mobile home,n03776460 -Model T,n03777568 -modem,n03777754 -monastery,n03781244 -monitor,n03782006 -moped,n03785016 -mortar,n03786901 -mortarboard,n03787032 -mosque,n03788195 -mosquito net,n03788365 -motor scooter,n03791053 -mountain bike,n03792782 -mountain tent,n03792972 -mouse,n03793489 -mousetrap,n03794056 -moving van,n03796401 -muzzle,n03803284 -nail,n03804744 -neck brace,n03814639 -necklace,n03814906 -nipple,n03825788 -notebook,n03832673 -obelisk,n03837869 -oboe,n03838899 -ocarina,n03840681 -odometer,n03841143 -oil filter,n03843555 -organ,n03854065 -oscilloscope,n03857828 -overskirt,n03866082 -oxcart,n03868242 -oxygen mask,n03868863 -packet,n03871628 -paddle,n03873416 -paddlewheel,n03874293 -padlock,n03874599 -paintbrush,n03876231 -pajama,n03877472 -palace,n03877845 -panpipe,n03884397 -paper towel,n03887697 -parachute,n03888257 -parallel bars,n03888605 -park bench,n03891251 -parking meter,n03891332 -passenger car,n03895866 -patio,n03899768 -pay-phone,n03902125 -pedestal,n03903868 -pencil box,n03908618 -pencil sharpener,n03908714 -perfume,n03916031 -Petri dish,n03920288 -photocopier,n03924679 -pick,n03929660 -pickelhaube,n03929855 -picket fence,n03930313 -pickup,n03930630 -pier,n03933933 -piggy bank,n03935335 -pill bottle,n03937543 -pillow,n03938244 -ping-pong ball,n03942813 -pinwheel,n03944341 -pirate,n03947888 -pitcher,n03950228 -plane,n03954731 -planetarium,n03956157 -plastic 
bag,n03958227 -plate rack,n03961711 -plow,n03967562 -plunger,n03970156 -Polaroid camera,n03976467 -pole,n03976657 -police van,n03977966 -poncho,n03980874 -pool table,n03982430 -pop bottle,n03983396 -pot,n03991062 -potter's wheel,n03992509 -power drill,n03995372 -prayer rug,n03998194 -printer,n04004767 -prison,n04005630 -projectile,n04008634 -projector,n04009552 -puck,n04019541 -punching bag,n04023962 -purse,n04026417 -quill,n04033901 -quilt,n04033995 -racer,n04037443 -racket,n04039381 -radiator,n04040759 -radio,n04041544 -radio telescope,n04044716 -rain barrel,n04049303 -recreational vehicle,n04065272 -reel,n04067472 -reflex camera,n04069434 -refrigerator,n04070727 -remote control,n04074963 -restaurant,n04081281 -revolver,n04086273 -rifle,n04090263 -rocking chair,n04099969 -rotisserie,n04111531 -rubber eraser,n04116512 -rugby ball,n04118538 -rule,n04118776 -running shoe,n04120489 -safe,n04125021 -safety pin,n04127249 -saltshaker,n04131690 -sandal,n04133789 -sarong,n04136333 -sax,n04141076 -scabbard,n04141327 -scale,n04141975 -school bus,n04146614 -schooner,n04147183 -scoreboard,n04149813 -screen,n04152593 -screw,n04153751 -screwdriver,n04154565 -seat belt,n04162706 -sewing machine,n04179913 -shield,n04192698 -shoe shop,n04200800 -shoji,n04201297 -shopping basket,n04204238 -shopping cart,n04204347 -shovel,n04208210 -shower cap,n04209133 -shower curtain,n04209239 -ski,n04228054 -ski mask,n04229816 -sleeping bag,n04235860 -slide rule,n04238763 -sliding door,n04239074 -slot,n04243546 -snorkel,n04251144 -snowmobile,n04252077 -snowplow,n04252225 -soap dispenser,n04254120 -soccer ball,n04254680 -sock,n04254777 -solar dish,n04258138 -sombrero,n04259630 -soup bowl,n04263257 -space bar,n04264628 -space heater,n04265275 -space shuttle,n04266014 -spatula,n04270147 -speedboat,n04273569 -spider web,n04275548 -spindle,n04277352 -sports car,n04285008 -spotlight,n04286575 -stage,n04296562 -steam locomotive,n04310018 -steel arch bridge,n04311004 -steel drum,n04311174 -stethoscope,n04317175 -stole,n04325704 -stone wall,n04326547 -stopwatch,n04328186 -stove,n04330267 -strainer,n04332243 -streetcar,n04335435 -stretcher,n04336792 -studio couch,n04344873 -stupa,n04346328 -submarine,n04347754 -suit,n04350905 -sundial,n04355338 -sunglass,n04355933 -sunglasses,n04356056 -sunscreen,n04357314 -suspension bridge,n04366367 -swab,n04367480 -sweatshirt,n04370456 -swimming trunks,n04371430 -swing,n04371774 -switch,n04372370 -syringe,n04376876 -table lamp,n04380533 -tank,n04389033 -tape player,n04392985 -teapot,n04398044 -teddy,n04399382 -television,n04404412 -tennis ball,n04409515 -thatch,n04417672 -theater curtain,n04418357 -thimble,n04423845 -thresher,n04428191 -throne,n04429376 -tile roof,n04435653 -toaster,n04442312 -tobacco shop,n04443257 -toilet seat,n04447861 -torch,n04456115 -totem pole,n04458633 -tow truck,n04461696 -toyshop,n04462240 -tractor,n04465501 -trailer truck,n04467665 -tray,n04476259 -trench coat,n04479046 -tricycle,n04482393 -trimaran,n04483307 -tripod,n04485082 -triumphal arch,n04486054 -trolleybus,n04487081 -trombone,n04487394 -tub,n04493381 -turnstile,n04501370 -typewriter keyboard,n04505470 -umbrella,n04507155 -unicycle,n04509417 -upright,n04515003 -vacuum,n04517823 -vase,n04522168 -vault,n04523525 -velvet,n04525038 -vending machine,n04525305 -vestment,n04532106 -viaduct,n04532670 -violin,n04536866 -volleyball,n04540053 -waffle iron,n04542943 -wall clock,n04548280 -wallet,n04548362 -wardrobe,n04550184 -warplane,n04552348 -washbasin,n04553703 -washer,n04554684 -water bottle,n04557648 -water 
jug,n04560804 -water tower,n04562935 -whiskey jug,n04579145 -whistle,n04579432 -wig,n04584207 -window screen,n04589890 -window shade,n04590129 -Windsor tie,n04591157 -wine bottle,n04591713 -wing,n04592741 -wok,n04596742 -wooden spoon,n04597913 -wool,n04599235 -worm fence,n04604644 -wreck,n04606251 -yawl,n04612504 -yurt,n04613696 -web site,n06359193 -comic book,n06596364 -crossword puzzle,n06785654 -street sign,n06794110 -traffic light,n06874185 -book jacket,n07248320 -menu,n07565083 -plate,n07579787 -guacamole,n07583066 -consomme,n07584110 -hot pot,n07590611 -trifle,n07613480 -ice cream,n07614500 -ice lolly,n07615774 -French loaf,n07684084 -bagel,n07693725 -pretzel,n07695742 -cheeseburger,n07697313 -hotdog,n07697537 -mashed potato,n07711569 -head cabbage,n07714571 -broccoli,n07714990 -cauliflower,n07715103 -zucchini,n07716358 -spaghetti squash,n07716906 -acorn squash,n07717410 -butternut squash,n07717556 -cucumber,n07718472 -artichoke,n07718747 -bell pepper,n07720875 -cardoon,n07730033 -mushroom,n07734744 -Granny Smith,n07742313 -strawberry,n07745940 -orange,n07747607 -lemon,n07749582 -fig,n07753113 -pineapple,n07753275 -banana,n07753592 -jackfruit,n07754684 -custard apple,n07760859 -pomegranate,n07768694 -hay,n07802026 -carbonara,n07831146 -chocolate sauce,n07836838 -dough,n07860988 -meat loaf,n07871810 -pizza,n07873807 -potpie,n07875152 -burrito,n07880968 -red wine,n07892512 -espresso,n07920052 -cup,n07930864 -eggnog,n07932039 -alp,n09193705 -bubble,n09229709 -cliff,n09246464 -coral reef,n09256479 -geyser,n09288635 -lakeside,n09332890 -promontory,n09399592 -sandbar,n09421951 -seashore,n09428293 -valley,n09468604 -volcano,n09472597 -ballplayer,n09835506 -groom,n10148035 -scuba diver,n10565667 -rapeseed,n11879895 -daisy,n11939491 -yellow lady's slipper,n12057211 -corn,n12144580 -acorn,n12267677 -hip,n12620546 -buckeye,n12768682 -coral fungus,n12985857 -agaric,n12998815 -gyromitra,n13037406 -stinkhorn,n13040303 -earthstar,n13044778 -hen-of-the-woods,n13052670 -bolete,n13054560 -ear,n13133613 -toilet tissue,n15075141 diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py deleted file mode 100644 index 3192f1f5503..00000000000 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ /dev/null @@ -1,223 +0,0 @@ -import enum -import pathlib -import re - -from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union - -from torchdata.datapipes.iter import ( - Demultiplexer, - Enumerator, - Filter, - IterDataPipe, - IterKeyZipper, - LineReader, - Mapper, - TarArchiveLoader, -) -from torchdata.datapipes.map import IterToMapConverter -from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - read_categories_file, - read_mat, -) -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - -NAME = "imagenet" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - categories, wnids = zip(*read_categories_file(NAME)) - return dict(categories=categories, wnids=wnids) - - -class ImageNetResource(ManualDownloadResource): - def __init__(self, **kwargs: Any) -> None: - super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) - - -class ImageNetDemux(enum.IntEnum): - META = 0 - LABEL = 1 - - -class 
CategoryAndWordNetIDExtractor(IterDataPipe): - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 - # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment - _WNID_MAP = { - "n03126707": "construction crane", - "n03710721": "tank suit", - } - - def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None: - self.datapipe = datapipe - - def __iter__(self) -> Iterator[Tuple[str, str]]: - for _, stream in self.datapipe: - synsets = read_mat(stream, squeeze_me=True)["synsets"] - for _, wnid, category, _, num_children, *_ in synsets: - if num_children > 0: - # we are looking at a superclass that has no direct instance - continue - - yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid - - -@register_dataset(NAME) -class ImageNet(Dataset): - """ - - **homepage**: https://www.image-net.org/ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - - info = _info() - categories, wnids = info["categories"], info["wnids"] - self._categories = categories - self._wnids = wnids - self._wnid_to_category = dict(zip(wnids, categories)) - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _IMAGES_CHECKSUMS = { - "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", - "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0", - "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4", - } - - def _resources(self) -> List[OnlineResource]: - name = "test_v10102019" if self._split == "test" else self._split - images = ImageNetResource( - file_name=f"ILSVRC2012_img_{name}.tar", - sha256=self._IMAGES_CHECKSUMS[name], - ) - resources: List[OnlineResource] = [images] - - if self._split == "val": - devkit = ImageNetResource( - file_name="ILSVRC2012_devkit_t12.tar.gz", - sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", - ) - resources.append(devkit) - - return resources - - _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG") - - def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: - path = pathlib.Path(data[0]) - wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"] - label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) - return (label, wnid), data - - def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: - return None, data - - def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: - return { - "meta.mat": ImageNetDemux.META, - "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL, - }.get(pathlib.Path(data[0]).name) - - _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") - - def _val_test_image_key(self, path: pathlib.Path) -> int: - return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index] - - def _prepare_val_data( - self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]] - ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: - label_data, image_data = data - _, wnid = label_data - label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) - return (label, wnid), image_data - - def _prepare_sample( - self, 
- data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]], - ) -> Dict[str, Any]: - label_data, (path, buffer) = data - - return dict( - dict(zip(("label", "wnid"), label_data if label_data else (None, None))), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - if self._split in {"train", "test"}: - dp = resource_dps[0] - - # the train archive is a tar of tars - if self._split == "train": - dp = TarArchiveLoader(dp) - - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data) - else: # config.split == "val": - images_dp, devkit_dp = resource_dps - - meta_dp, label_dp = Demultiplexer( - devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - # We cannot use self._wnids here, since we use a different order than the dataset - meta_dp = CategoryAndWordNetIDExtractor(meta_dp) - wnid_dp = Mapper(meta_dp, getitem(1)) - wnid_dp = Enumerator(wnid_dp, 1) - wnid_map = IterToMapConverter(wnid_dp) - - label_dp = LineReader(label_dp, decode=True, return_path=False) - label_dp = Mapper(label_dp, int) - label_dp = Mapper(label_dp, wnid_map.__getitem__) - label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) - label_dp = hint_shuffling(label_dp) - label_dp = hint_sharding(label_dp) - - dp = IterKeyZipper( - label_dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=path_accessor(self._val_test_image_key), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = Mapper(dp, self._prepare_val_data) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 1_281_167, - "val": 50_000, - "test": 100_000, - }[self._split] - - def _filter_meta(self, data: Tuple[str, Any]) -> bool: - return self._classifiy_devkit(data) == ImageNetDemux.META - - def _generate_categories(self) -> List[Tuple[str, ...]]: - self._split = "val" - resources = self._resources() - - devkit_dp = resources[1].load(self._root) - meta_dp = Filter(devkit_dp, self._filter_meta) - meta_dp = CategoryAndWordNetIDExtractor(meta_dp) - - categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp)) - categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) - return categories_and_wnids diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py deleted file mode 100644 index 7a459b2d0ea..00000000000 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ /dev/null @@ -1,415 +0,0 @@ -import abc -import functools -import operator -import pathlib -import string -from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import torch -from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE -from torchvision.prototype.features import Image, Label -from torchvision.prototype.utils._internal import fromfile - -from .._api import register_dataset, register_info - - -prod = functools.partial(functools.reduce, operator.mul) - - -class MNISTFileReader(IterDataPipe[torch.Tensor]): - _DTYPE_MAP = { - 8: torch.uint8, - 9: torch.int8, - 11: torch.int16, - 12: torch.int32, - 13: torch.float32, - 14: torch.float64, - } - - def __init__( - self, 
datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int] - ) -> None: - self.datapipe = datapipe - self.start = start - self.stop = stop - - def __iter__(self) -> Iterator[torch.Tensor]: - for _, file in self.datapipe: - read = functools.partial(fromfile, file, byte_order="big") - - magic = int(read(dtype=torch.int32, count=1)) - dtype = self._DTYPE_MAP[magic // 256] - ndim = magic % 256 - 1 - - num_samples = int(read(dtype=torch.int32, count=1)) - shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else [] - count = prod(shape) if shape else 1 - - start = self.start or 0 - stop = min(self.stop, num_samples) if self.stop else num_samples - - if start: - num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 - file.seek(num_bytes_per_value * count * start, 1) - - for _ in range(stop - start): - yield read(dtype=dtype, count=count).reshape(shape) - - -class _MNISTBase(Dataset): - _URL_BASE: Union[str, Sequence[str]] - - @abc.abstractmethod - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - pass - - def _resources(self) -> List[OnlineResource]: - (images_file, images_sha256), ( - labels_file, - labels_sha256, - ) = self._files_and_checksums() - - url_bases = self._URL_BASE - if isinstance(url_bases, str): - url_bases = (url_bases,) - - images_urls = [f"{url_base}/{images_file}" for url_base in url_bases] - images = HttpResource(images_urls[0], sha256=images_sha256, mirrors=images_urls[1:]) - - labels_urls = [f"{url_base}/{labels_file}" for url_base in url_bases] - labels = HttpResource(labels_urls[0], sha256=labels_sha256, mirrors=labels_urls[1:]) - - return [images, labels] - - def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: - return None, None - - _categories: List[str] - - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - image, label = data - return dict( - image=Image(image), - label=Label(label, dtype=torch.int64, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, labels_dp = resource_dps - start, stop = self.start_and_stop() - - images_dp = Decompressor(images_dp) - images_dp = MNISTFileReader(images_dp, start=start, stop=stop) - - labels_dp = Decompressor(labels_dp) - labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop) - - dp = Zipper(images_dp, labels_dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - -@register_info("mnist") -def _mnist_info() -> Dict[str, Any]: - return dict( - categories=[str(label) for label in range(10)], - ) - - -@register_dataset("mnist") -class MNIST(_MNISTBase): - """ - - **homepage**: http://yann.lecun.com/exdb/mnist - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_BASE: Union[str, Sequence[str]] = ( - "http://yann.lecun.com/exdb/mnist", - "https://ossci-datasets.s3.amazonaws.com/mnist", - ) - _CHECKSUMS = { - "train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609", - "train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c", - "t10k-images-idx3-ubyte.gz": 
"8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6", - "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6", - } - - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "train" if self._split == "train" else "t10k" - images_file = f"{prefix}-images-idx3-ubyte.gz" - labels_file = f"{prefix}-labels-idx1-ubyte.gz" - return (images_file, self._CHECKSUMS[images_file]), ( - labels_file, - self._CHECKSUMS[labels_file], - ) - - _categories = _mnist_info()["categories"] - - def __len__(self) -> int: - return 60_000 if self._split == "train" else 10_000 - - -@register_info("fashionmnist") -def _fashionmnist_info() -> Dict[str, Any]: - return dict( - categories=[ - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", - ], - ) - - -@register_dataset("fashionmnist") -class FashionMNIST(MNIST): - """ - - **homepage**: https://github.com/zalandoresearch/fashion-mnist - """ - - _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com" - _CHECKSUMS = { - "train-images-idx3-ubyte.gz": "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84", - "train-labels-idx1-ubyte.gz": "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845", - "t10k-images-idx3-ubyte.gz": "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073", - "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5", - } - - _categories = _fashionmnist_info()["categories"] - - -@register_info("kmnist") -def _kmnist_info() -> Dict[str, Any]: - return dict( - categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], - ) - - -@register_dataset("kmnist") -class KMNIST(MNIST): - """ - - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en - """ - - _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist" - _CHECKSUMS = { - "train-images-idx3-ubyte.gz": "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4", - "train-labels-idx1-ubyte.gz": "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17", - "t10k-images-idx3-ubyte.gz": "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5", - "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c", - } - - _categories = _kmnist_info()["categories"] - - -@register_info("emnist") -def _emnist_info() -> Dict[str, Any]: - return dict( - categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), - ) - - -@register_dataset("emnist") -class EMNIST(_MNISTBase): - """ - - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - image_set: str = "Balanced", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - self._image_set = self._verify_str_arg( - image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST") - ) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" - - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}" - images_file = f"{prefix}-images-idx3-ubyte.gz" - labels_file = f"{prefix}-labels-idx1-ubyte.gz" - # 
Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them - return (images_file, ""), (labels_file, "") - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - f"{self._URL_BASE}/emnist-gzip.zip", - sha256="909a2a39c5e86bdd7662425e9b9c4a49bb582bf8d0edad427f3c3a9d0c6f7259", - ) - ] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - (images_file, _), (labels_file, _) = self._files_and_checksums() - if path.name == images_file: - return 0 - elif path.name == labels_file: - return 1 - else: - return None - - _categories = _emnist_info()["categories"] - - _LABEL_OFFSETS = { - 38: 1, - 39: 1, - 40: 1, - 41: 1, - 42: 1, - 43: 6, - 44: 8, - 45: 8, - 46: 9, - } - - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). - # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, - # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For - # example, since there is no 'c', 'd' corresponds to - # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing), - # and at the same time corresponds to - # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) - # in self._categories. Thus, we need to add 1 to the label to correct this. - if self._image_set in ("Balanced", "By_Merge"): - image, label = data - label += self._LABEL_OFFSETS.get(int(label), 0) - data = (image, label) - return super()._prepare_sample(data) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - images_dp, labels_dp = Demultiplexer( - archive_dp, - 2, - self._classify_archive, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - return super()._datapipe([images_dp, labels_dp]) - - def __len__(self) -> int: - return { - ("train", "Balanced"): 112_800, - ("train", "By_Merge"): 697_932, - ("train", "By_Class"): 697_932, - ("train", "Letters"): 124_800, - ("train", "Digits"): 240_000, - ("train", "MNIST"): 60_000, - ("test", "Balanced"): 18_800, - ("test", "By_Merge"): 116_323, - ("test", "By_Class"): 116_323, - ("test", "Letters"): 20_800, - ("test", "Digits"): 40_000, - ("test", "MNIST"): 10_000, - }[(self._split, self._image_set)] - - -@register_info("qmnist") -def _qmnist_info() -> Dict[str, Any]: - return dict( - categories=[str(label) for label in range(10)], - ) - - -@register_dataset("qmnist") -class QMNIST(_MNISTBase): - """ - - **homepage**: https://github.com/facebookresearch/qmnist - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist")) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master" - _CHECKSUMS = { - "qmnist-train-images-idx3-ubyte.gz": "9e26a7bf1683614e065d7b76460ccd52807165b3f22561fb782bd9f38c52b51d", - "qmnist-train-labels-idx2-int.gz": "2c05dc77f6b916b38e455e97ab129a42a444f3dbef09b278a366f82904e0dd9f", - "qmnist-test-images-idx3-ubyte.gz": "43fc22bf7498b8fc98de98369d72f752d0deabc280a43a7bcc364ab19e57b375", - 
"qmnist-test-labels-idx2-int.gz": "9fbcbe594c3766fdf4f0b15c5165dc0d1e57ac604e01422608bb72c906030d06", - "xnist-images-idx3-ubyte.xz": "f075553993026d4359ded42208eff77a1941d3963c1eff49d6015814f15f0984", - "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f", - } - - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}" - suffix = "xz" if self._split == "nist" else "gz" - images_file = f"{prefix}-images-idx3-ubyte.{suffix}" - labels_file = f"{prefix}-labels-idx2-int.{suffix}" - return (images_file, self._CHECKSUMS[images_file]), ( - labels_file, - self._CHECKSUMS[labels_file], - ) - - def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: - start: Optional[int] - stop: Optional[int] - if self._split == "test10k": - start = 0 - stop = 10000 - elif self._split == "test50k": - start = 10000 - stop = None - else: - start = stop = None - - return start, stop - - _categories = _emnist_info()["categories"] - - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - image, ann = data - label, *extra_anns = ann - sample = super()._prepare_sample((image, label)) - - sample.update( - dict( - zip( - ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"), - [int(value) for value in extra_anns[:5]], - ) - ) - ) - sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))) - return sample - - def __len__(self) -> int: - return { - "train": 60_000, - "test": 60_000, - "test10k": 10_000, - "test50k": 50_000, - "nist": 402_953, - }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories b/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories deleted file mode 100644 index 36d29465b04..00000000000 --- a/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories +++ /dev/null @@ -1,37 +0,0 @@ -Abyssinian -American Bulldog -American Pit Bull Terrier -Basset Hound -Beagle -Bengal -Birman -Bombay -Boxer -British Shorthair -Chihuahua -Egyptian Mau -English Cocker Spaniel -English Setter -German Shorthaired -Great Pyrenees -Havanese -Japanese Chin -Keeshond -Leonberger -Maine Coon -Miniature Pinscher -Newfoundland -Persian -Pomeranian -Pug -Ragdoll -Russian Blue -Saint Bernard -Samoyed -Scottish Terrier -Shiba Inu -Siamese -Sphynx -Staffordshire Bull Terrier -Wheaten Terrier -Yorkshire Terrier diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py deleted file mode 100644 index 499dbd837ed..00000000000 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ /dev/null @@ -1,146 +0,0 @@ -import enum -import pathlib -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, -) -from torchvision.prototype.features import EncodedImage, Label - -from .._api import register_dataset, register_info - - -NAME = "oxford-iiit-pet" - - -class OxfordIIITPetDemux(enum.IntEnum): - SPLIT_AND_CLASSIFICATION = 0 - 
SEGMENTATIONS = 1 - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class OxfordIIITPet(Dataset): - """Oxford IIIT Pet Dataset - homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"trainval", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - images = HttpResource( - "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", - sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d", - preprocess="decompress", - ) - anns = HttpResource( - "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", - sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91", - preprocess="decompress", - ) - return [images, anns] - - def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]: - return { - "annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION, - "trimaps": OxfordIIITPetDemux.SEGMENTATIONS, - }.get(pathlib.Path(data[0]).parent.name) - - def _filter_images(self, data: Tuple[str, Any]) -> bool: - return pathlib.Path(data[0]).suffix == ".jpg" - - def _filter_segmentations(self, data: Tuple[str, Any]) -> bool: - return not pathlib.Path(data[0]).name.startswith(".") - - def _prepare_sample( - self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]] - ) -> Dict[str, Any]: - ann_data, image_data = data - classification_data, segmentation_data = ann_data - segmentation_path, segmentation_buffer = segmentation_data - image_path, image_buffer = image_data - - return dict( - label=Label(int(classification_data["label"]) - 1, categories=self._categories), - species="cat" if classification_data["species"] == "1" else "dog", - segmentation_path=segmentation_path, - segmentation=EncodedImage.from_file(segmentation_buffer), - image_path=image_path, - image=EncodedImage.from_file(image_buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, anns_dp = resource_dps - - images_dp = Filter(images_dp, self._filter_images) - - split_and_classification_dp, segmentations_dp = Demultiplexer( - anns_dp, - 2, - self._classify_anns, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt")) - split_and_classification_dp = CSVDictParser( - split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" " - ) - split_and_classification_dp = hint_shuffling(split_and_classification_dp) - split_and_classification_dp = hint_sharding(split_and_classification_dp) - - segmentations_dp = Filter(segmentations_dp, self._filter_segmentations) - - anns_dp = IterKeyZipper( - split_and_classification_dp, - segmentations_dp, - key_fn=getitem("image_id"), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - - dp = IterKeyZipper( - anns_dp, - images_dp, - key_fn=getitem(0, "image_id"), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: - return self._classify_anns(data) == 
OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[1].load(self._root) - dp = Filter(dp, self._filter_split_and_classification_anns) - dp = Filter(dp, path_comparator("name", "trainval.txt")) - dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ") - - raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp} - raw_categories, _ = zip( - *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1])) - ) - return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories] - - def __len__(self) -> int: - return 3_680 if self._split == "trainval" else 3_669 diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py deleted file mode 100644 index 162f22f1abd..00000000000 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ /dev/null @@ -1,126 +0,0 @@ -import io -import pathlib -from collections import namedtuple -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper -from torchvision.prototype import features -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Label - -from .._api import register_dataset, register_info - - -NAME = "pcam" - - -class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]): - def __init__( - self, - datapipe: IterDataPipe[Tuple[str, io.IOBase]], - key: Optional[str] = None, # Note: this key thing might be very specific to the PCAM dataset - ) -> None: - self.datapipe = datapipe - self.key = key - - def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: - import h5py - - for _, handle in self.datapipe: - with h5py.File(handle) as data: - if self.key is not None: - data = data[self.key] - yield from data - - -_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256")) - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=["0", "1"]) - - -@register_dataset(NAME) -class PCAM(Dataset): - # TODO write proper docstring - """PCAM Dataset - - homepage="https://github.com/basveeling/pcam" - """ - - def __init__( - self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",)) - - _RESOURCES = { - "train": ( - _Resource( # Images - file_name="camelyonpatch_level_2_split_train_x.h5.gz", - gdrive_id="1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2", - sha256="d619e741468a7ab35c7e4a75e6821b7e7e6c9411705d45708f2a0efc8960656c", - ), - _Resource( # Targets - file_name="camelyonpatch_level_2_split_train_y.h5.gz", - gdrive_id="1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG", - sha256="b74126d2c01b20d3661f9b46765d29cf4e4fba6faba29c8e0d09d406331ab75a", - ), - ), - "test": ( - _Resource( # Images - file_name="camelyonpatch_level_2_split_test_x.h5.gz", - gdrive_id="1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_", - sha256="79174c2201ad521602a5888be8f36ee10875f37403dd3f2086caf2182ef87245", - ), - _Resource( # Targets - file_name="camelyonpatch_level_2_split_test_y.h5.gz", - 
gdrive_id="17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP", - sha256="0a522005fccc8bbd04c5a117bfaf81d8da2676f03a29d7499f71d0a0bd6068ef", - ), - ), - "val": ( - _Resource( # Images - file_name="camelyonpatch_level_2_split_valid_x.h5.gz", - gdrive_id="1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3", - sha256="f82ee1670d027b4ec388048d9eabc2186b77c009655dae76d624c0ecb053ccb2", - ), - _Resource( # Targets - file_name="camelyonpatch_level_2_split_valid_y.h5.gz", - gdrive_id="1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO", - sha256="ce1ae30f08feb468447971cfd0472e7becd0ad96d877c64120c72571439ae48c", - ), - ), - } - - def _resources(self) -> List[OnlineResource]: - return [ # = [images resource, targets resource] - GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress") - for file_name, gdrive_id, sha256 in self._RESOURCES[self._split] - ] - - def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: - image, target = data # They're both numpy arrays at this point - - return { - "image": features.Image(image.transpose(2, 0, 1)), - "label": Label(target.item(), categories=self._categories), - } - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - - images_dp, targets_dp = resource_dps - - images_dp = PCAMH5Reader(images_dp, key="x") - targets_dp = PCAMH5Reader(targets_dp, key="y") - - dp = Zipper(images_dp, targets_dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 262_144 if self._split == "train" else 32_768 diff --git a/torchvision/prototype/datasets/_builtin/sbd.categories b/torchvision/prototype/datasets/_builtin/sbd.categories deleted file mode 100644 index 8420ab35ede..00000000000 --- a/torchvision/prototype/datasets/_builtin/sbd.categories +++ /dev/null @@ -1,20 +0,0 @@ -aeroplane -bicycle -bird -boat -bottle -bus -car -cat -chair -cow -diningtable -dog -horse -motorbike -person -pottedplant -sheep -sofa -train -tvmonitor diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py deleted file mode 100644 index c7a79c4188e..00000000000 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ /dev/null @@ -1,153 +0,0 @@ -import pathlib -import re -from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union - -import numpy as np -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, - read_mat, -) -from torchvision.prototype.features import _Feature, EncodedImage - -from .._api import register_dataset, register_info - -NAME = "sbd" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class SBD(Dataset): - """ - - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html - - **dependencies**: - - _ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval")) - - self._categories = _info()["categories"] - - super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check) - - def 
_resources(self) -> List[OnlineResource]: - archive = HttpResource( - "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz", - sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53", - ) - extra_split = HttpResource( - "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt", - sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432", - ) - return [archive, extra_split] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - parent, grandparent, *_ = path.parents - - if parent.name == "dataset": - return 0 - elif grandparent.name == "dataset": - if parent.name == "img": - return 1 - elif parent.name == "cls": - return 2 - else: - return None - else: - return None - - def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: - split_and_image_data, ann_data = data - _, image_data = split_and_image_data - image_path, image_buffer = image_data - ann_path, ann_buffer = ann_data - - anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"] - - return dict( - image_path=image_path, - image=EncodedImage.from_file(image_buffer), - ann_path=ann_path, - # the boundaries are stored in sparse CSC format, which is not supported by PyTorch - boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])), - segmentation=_Feature(anns["Segmentation"].item()), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp, extra_split_dp = resource_dps - - archive_dp = resource_dps[0] - split_dp, images_dp, anns_dp = Demultiplexer( - archive_dp, - 3, - self._classify_archive, - buffer_size=INFINITE_BUFFER_SIZE, - drop_none=True, - ) - if self._split == "train_noval": - split_dp = extra_split_dp - - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - split_dp = LineReader(split_dp, decode=True) - split_dp = hint_shuffling(split_dp) - split_dp = hint_sharding(split_dp) - - dp = split_dp - for level, data_dp in enumerate((images_dp, anns_dp)): - dp = IterKeyZipper( - dp, - data_dp, - key_fn=getitem(*[0] * level, 1), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 8_498, - "val": 2_857, - "train_noval": 5_623, - }[self._split] - - def _generate_categories(self) -> Tuple[str, ...]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", "category_names.m")) - dp = LineReader(dp) - dp = Mapper(dp, bytes.decode, input_col=1) - lines = tuple(zip(*iter(dp)))[1] - - pattern = re.compile(r"\s*'(?P\w+)';\s*%(?P