diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45b335f9481..e96df58d964 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `Dice` to classification package ([#1021](https://github.com/PyTorchLightning/metrics/pull/1021))
+- Added support for segmentation type `segm` as IOU for mean average precision ([#822](https://github.com/PyTorchLightning/metrics/pull/822))
 
 ### Changed
diff --git a/requirements/detection.txt b/requirements/detection.txt
index cbec7260775..da90ea549b7 100644
--- a/requirements/detection.txt
+++ b/requirements/detection.txt
@@ -1 +1,2 @@
 torchvision>=0.8
+pycocotools
diff --git a/requirements/detection_test.txt b/requirements/detection_test.txt
new file mode 100644
index 00000000000..88d1f3a15b5
--- /dev/null
+++ b/requirements/detection_test.txt
@@ -0,0 +1 @@
+pycocotools
diff --git a/requirements/devel.txt b/requirements/devel.txt
index 2982b7de3ce..757c79a82ae 100644
--- a/requirements/devel.txt
+++ b/requirements/devel.txt
@@ -14,3 +14,4 @@
 -r image_test.txt
 -r text_test.txt
 -r audio_test.txt
+-r detection_test.txt
diff --git a/tests/detection/__init__.py b/tests/detection/__init__.py
index e69de29bb2d..0904ebc12bc 100644
--- a/tests/detection/__init__.py
+++ b/tests/detection/__init__.py
@@ -0,0 +1,5 @@
+import os
+
+from tests import _PATH_ROOT
+
+_SAMPLE_DETECTION_SEGMENTATION = os.path.join(_PATH_ROOT, "_data", "detection", "instance_segmentation_inputs.json")
diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py
index 8cbb964aa08..b5277e1ab5d 100644
--- a/tests/detection/test_map.py
+++ b/tests/detection/test_map.py
@@ -12,18 +12,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from collections import namedtuple
+import numpy as np
 import pytest
 import torch
+from pycocotools import mask
 from torch import IntTensor, Tensor
+from tests.detection import _SAMPLE_DETECTION_SEGMENTATION
 from tests.helpers.testers import MetricTester
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
 Input = namedtuple("Input", ["preds", "target"])
+with open(_SAMPLE_DETECTION_SEGMENTATION) as fp:
+    inputs_json = json.load(fp)
+
+_mask_unsqueeze_bool = lambda m: Tensor(mask.decode(m)).unsqueeze(0).bool()
+_masks_stack_bool = lambda ms: Tensor(np.stack([mask.decode(m) for m in ms])).bool()
+
+_inputs_masks = Input(
+    preds=[
+        [
+            dict(masks=_mask_unsqueeze_bool(inputs_json["preds"][0]), scores=Tensor([0.236]), labels=IntTensor([4])),
+            dict(
+                masks=_masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]),
+                scores=Tensor([0.318, 0.726]),
+                labels=IntTensor([3, 2]),
+            ),  # 73
+        ],
+    ],
+    target=[
+        [
+            dict(masks=_mask_unsqueeze_bool(inputs_json["targets"][0]), labels=IntTensor([4])),  # 42
+            dict(
+                masks=_masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]),
+                labels=IntTensor([2, 2]),
+            ),  # 73
+        ],
+    ],
+)
+
+
 _inputs = Input(
     preds=[
         [
@@ -139,15 +172,15 @@
 _inputs3 = Input(
     preds=[
         [
-            dict(boxes=torch.tensor([]), scores=torch.tensor([]), labels=torch.tensor([])),
+            dict(boxes=Tensor([]), scores=Tensor([]), labels=Tensor([])),
         ],
     ],
     target=[
         [
             dict(
-                boxes=torch.tensor([[1.0, 2.0, 3.0, 4.0]]),
-                scores=torch.tensor([0.8]),
-                labels=torch.tensor([1]),
+                boxes=Tensor([[1.0, 2.0, 3.0, 4.0]]),
+                scores=Tensor([0.8]),
+                labels=Tensor([1]),
             ),
         ],
     ],
@@ -214,6 +247,41 @@ def _compare_fn(preds, target) -> dict:
     }
+def _compare_fn_segm(preds, target) -> dict:
+    """Comparison function for map implementation for instance segmentation.
+
+    Official pycocotools results calculated from a subset of https://github.com/cocodataset/cocoapi/tree/master/results
+        Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.352
+        Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.752
+        Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.252
+        Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
+        Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
+        Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.352
+        Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.350
+        Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.350
+        Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350
+        Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
+        Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
+        Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.350
+    """
+    return {
+        "map": Tensor([0.352]),
+        "map_50": Tensor([0.752]),
+        "map_75": Tensor([0.252]),
+        "map_small": Tensor([-1]),
+        "map_medium": Tensor([-1]),
+        "map_large": Tensor([0.352]),
+        "mar_1": Tensor([0.35]),
+        "mar_10": Tensor([0.35]),
+        "mar_100": Tensor([0.35]),
+        "mar_small": Tensor([-1]),
+        "mar_medium": Tensor([-1]),
+        "mar_large": Tensor([0.35]),
+        "map_per_class": Tensor([0.4039604, -1.0, 0.3]),
+        "mar_100_per_class": Tensor([0.4, -1.0, 0.3]),
+    }
+
+
 _pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)
@@ -230,7 +298,8 @@ class TestMAP(MetricTester):
     atol = 1e-1
     @pytest.mark.parametrize("ddp", [False, True])
-    def test_map(self, compute_on_cpu, ddp):
+    def test_map_bbox(self, compute_on_cpu, ddp):
+        """Test modular implementation for correctness."""
         self.run_class_metric_test(
             ddp=ddp,
@@ -243,6 +312,21 @@ def test_map(self, compute_on_cpu, ddp):
             metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu},
         )
+    @pytest.mark.parametrize("ddp", [False])
+    def test_map_segm(self, compute_on_cpu, ddp):
+        """Test modular implementation for correctness."""
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=_inputs_masks.preds,
+            target=_inputs_masks.target,
+            metric_class=MeanAveragePrecision,
+            sk_metric=_compare_fn_segm,
+            dist_sync_on_step=False,
+            check_batch=False,
+            metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu, "iou_type": "segm"},
+        )
+
 # noinspection PyTypeChecker
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
@@ -377,6 +461,27 @@ def test_missing_gt():
     assert result["map"] < 1, "MAP cannot be 1, as there is an image with no ground truth, but some predictions."
+@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
+def test_segm_iou_empty_mask():
+    """Test empty ground truths."""
+    metric = MeanAveragePrecision(iou_type="segm")
+
+    metric.update(
+        [
+            dict(
+                masks=torch.randint(0, 1, (1, 10, 10)).bool(),
+                scores=Tensor([0.5]),
+                labels=IntTensor([4]),
+            ),
+        ],
+        [
+            dict(masks=Tensor([]), labels=IntTensor([])),
+        ],
+    )
+
+    metric.compute()
+
+
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
 def test_error_on_wrong_input():
     """Test class input validation."""
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 7bc0c3cb2e5..936b12be957 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -212,6 +212,7 @@ def _class_test(
             _assert_allclose(batch_result, sk_batch_result, atol=atol)
     # check that metrics are hashable
+    assert hash(metric)
     # assert that state dict is empty
diff --git a/torchmetrics/detection/mean_ap.py b/torchmetrics/detection/mean_ap.py
index 7f3fef838b8..0fdf7ab5b39 100644
--- a/torchmetrics/detection/mean_ap.py
+++ b/torchmetrics/detection/mean_ap.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+import numpy as np
 import torch
 from torch import IntTensor, Tensor
 from torchmetrics.metric import Metric
-from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8
+from torchmetrics.utilities.imports import _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
 if _TORCHVISION_GREATER_EQUAL_0_8:
     from torchvision.ops import box_area, box_convert, box_iou
@@ -26,9 +27,52 @@
     box_convert = box_iou = box_area = None
     __doctest_skip__ = ["MeanAveragePrecision"]
+if _PYCOCOTOOLS_AVAILABLE:
+    import pycocotools.mask as mask_utils
+else:
+    mask_utils = None
+    __doctest_skip__ = ["MeanAveragePrecision"]
+
+
 log = logging.getLogger(__name__)
+def compute_area(input: List[Any], iou_type: str = "bbox") -> Tensor:
+    """Compute area of input depending on the specified iou_type.
+
+    Default output for empty input is torch.Tensor([])
+    """
+    if len(input) == 0:
+
+        return torch.Tensor([])
+
+    if iou_type == "bbox":
+        return box_area(torch.stack(input))
+    elif iou_type == "segm":
+
+        input = [{"size": i[0], "counts": i[1]} for i in input]
+        area = torch.tensor(mask_utils.area(input).astype("float"))
+
+        return area
+    else:
+        raise Exception(f"IOU type {iou_type} is not supported")
+
+
+def compute_iou(
+    det: List[Any],
+    gt: List[Any],
+    iou_type: str = "bbox",
+) -> Tensor:
+    """Compute IOU between detections and ground-truth using the specified iou_type."""
+
+    if iou_type == "bbox":
+        return box_iou(torch.stack(det), torch.stack(gt))
+    elif iou_type == "segm":
+        return _segm_iou(det, gt)
+    else:
+        raise Exception(f"IOU type {iou_type} is not supported")
+
+
 class BaseMetricResults(dict):
     """Base metric class, that allows fields for pre-defined metrics."""
@@ -80,7 +124,27 @@ class COCOMetricResults(BaseMetricResults):
     )
-def _input_validator(preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]]) -> None:
+def _segm_iou(det: List[Tuple[np.ndarray, np.ndarray]], gt: List[Tuple[np.ndarray, np.ndarray]]) -> torch.Tensor:
+    """
+    Compute IOU between detections and ground-truths using mask-IOU. Based on the pycocotools `mask_utils` toolkit.
+
+    Args:
+        det: A list of detection masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is the (height, width)
+            dimension of the input and ``RLE_COUNTS`` is its RLE representation;
+
+        gt: A list of ground-truth masks as ``[(RLE_SIZE, RLE_COUNTS)]``, where ``RLE_SIZE`` is the (height, width)
+            dimension of the input and ``RLE_COUNTS`` is its RLE representation;
+
+    """
+
+    det_coco_format = [{"size": i[0], "counts": i[1]} for i in det]
+    gt_coco_format = [{"size": i[0], "counts": i[1]} for i in gt]
+
+    return torch.tensor(mask_utils.iou(det_coco_format, gt_coco_format, [False for _ in gt]))
+
+
+def _input_validator(
+    preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]], iou_type: str = "bbox"
+) -> None:
     """Ensure the correct input format of `preds` and `targets`"""
     if not isinstance(preds, Sequence):
         raise ValueError("Expected argument `preds` to be of type Sequence")
@@ -88,43 +152,45 @@ def _input_validator(preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[
         raise ValueError("Expected argument `target` to be of type Sequence")
     if len(preds) != len(targets):
         raise ValueError("Expected argument `preds` and `target` to have the same length")
+    iou_attribute = "boxes" if iou_type == "bbox" else "masks"
-    for k in ["boxes", "scores", "labels"]:
+    for k in [iou_attribute, "scores", "labels"]:
         if any(k not in p for p in preds):
             raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")
-    for k in ["boxes", "labels"]:
+    for k in [iou_attribute, "labels"]:
         if any(k not in p for p in targets):
             raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key")
-    if any(type(pred["boxes"]) is not Tensor for pred in preds):
-        raise ValueError("Expected all boxes in `preds` to be of type Tensor")
+    if any(type(pred[iou_attribute]) is not Tensor for pred in preds):
+        raise ValueError(f"Expected all {iou_attribute} in `preds` to be of type Tensor")
     if any(type(pred["scores"]) is not Tensor for pred in preds):
         raise ValueError("Expected all scores in `preds` to be of type Tensor")
     if any(type(pred["labels"]) is not Tensor for pred in preds):
         raise ValueError("Expected all labels in `preds` to be of type Tensor")
-    if any(type(target["boxes"]) is not Tensor for target in targets):
-        raise ValueError("Expected all boxes in `target` to be of type Tensor")
+    if any(type(target[iou_attribute]) is not Tensor for target in targets):
+        raise ValueError(f"Expected all {iou_attribute} in `target` to be of type Tensor")
     if any(type(target["labels"]) is not Tensor for target in targets):
         raise ValueError("Expected all labels in `target` to be of type Tensor")
     for i, item in enumerate(targets):
-        if item["boxes"].size(0) != item["labels"].size(0):
+        if item[iou_attribute].size(0) != item["labels"].size(0):
             raise ValueError(
-                f"Input boxes and labels of sample {i} in targets have a"
-                f" different length (expected {item['boxes'].size(0)} labels, got {item['labels'].size(0)})"
+                f"Input {iou_attribute} and labels of sample {i} in targets have a"
+                f" different length (expected {item[iou_attribute].size(0)} labels, got {item['labels'].size(0)})"
             )
     for i, item in enumerate(preds):
-        if not (item["boxes"].size(0) == item["labels"].size(0) == item["scores"].size(0)):
+        if not (item[iou_attribute].size(0) == item["labels"].size(0) == item["scores"].size(0)):
             raise ValueError(
-                f"Input boxes, labels and scores of sample {i} in predictions have a"
-                f" different length (expected {item['boxes'].size(0)} labels and scores,"
+                f"Input {iou_attribute}, labels and scores of sample {i} in predictions have a"
+                f" different length (expected {item[iou_attribute].size(0)} labels and scores,"
                 f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})"
             )
 def _fix_empty_tensors(boxes: Tensor) -> Tensor:
     """Empty tensors can cause problems in DDP mode, this methods corrects them."""
+
     if boxes.numel() == 0 and boxes.ndim == 1:
         return boxes.unsqueeze(0)
     return boxes
@@ -150,13 +216,17 @@ class MeanAveragePrecision(Metric):
     a standard implementation for the mAP metric for object detection.
     .. note::
-        This metric requires you to have `torchvision` version 0.8.0 or newer installed (with corresponding
-        version 1.7.0 of torch or newer). Please install with ``pip install torchvision`` or
+        This metric requires you to have `torchvision` version 0.8.0 or newer installed
+        (with corresponding version 1.7.0 of torch or newer). This metric also requires `pycocotools`
+        to be installed when `iou_type` is `segm`. Please install with ``pip install torchvision`` or
         ``pip install torchmetrics[detection]``.
     Args:
         box_format:
             Input format of given boxes. Supported formats are ``[`xyxy`, `xywh`, `cxcywh`]``.
+        iou_type:
+            Type of input (either masks or bounding-boxes) used for computing IOU.
+            Supported IOU types are ``[`bbox`, `segm`]``.
         iou_thresholds:
             IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]``
             with step ``0.05``. Else provide a list of floats.
@@ -208,6 +278,8 @@ class MeanAveragePrecision(Metric):
     Raises:
         ModuleNotFoundError:
            If ``torchvision`` is not installed or version installed is lower than 0.8.0
+        ModuleNotFoundError:
+            If ``iou_type`` is equal to ``segm`` and ``pycocotools`` is not installed
         ValueError:
             If ``class_metrics`` is not a boolean
     """
@@ -215,15 +287,16 @@ class MeanAveragePrecision(Metric):
     higher_is_better: Optional[bool] = None
     full_state_update: bool = True
-    detection_boxes: List[Tensor]
+    detections: List[Tensor]
     detection_scores: List[Tensor]
     detection_labels: List[Tensor]
-    groundtruth_boxes: List[Tensor]
+    groundtruths: List[Tensor]
     groundtruth_labels: List[Tensor]
     def __init__(
         self,
         box_format: str = "xyxy",
+        iou_type: str = "bbox",
         iou_thresholds: Optional[List[float]] = None,
         rec_thresholds: Optional[List[float]] = None,
         max_detection_thresholds: Optional[List[int]] = None,
@@ -239,6 +312,7 @@ def __init__(
         )
         allowed_box_formats = ("xyxy", "xywh", "cxcywh")
+        allowed_iou_types = ("segm", "bbox")
         if box_format not in allowed_box_formats:
             raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}")
         self.box_format = box_format
@@ -246,6 +320,11 @@ def __init__(
         self.rec_thresholds = rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1).tolist()
         max_det_thr, _ = torch.sort(IntTensor(max_detection_thresholds or [1, 10, 100]))
         self.max_detection_thresholds = max_det_thr.tolist()
+        if iou_type not in allowed_iou_types:
+            raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}")
+        if iou_type == "segm" and not _PYCOCOTOOLS_AVAILABLE:
+            raise ModuleNotFoundError("When `iou_type` is set to 'segm', pycocotools needs to be installed")
+        self.iou_type = iou_type
         self.bbox_area_ranges = {
             "all": (0**2, int(1e5**2)),
             "small": (0**2, 32**2),
@@ -257,10 +336,10 @@ def __init__(
             raise ValueError("Expected argument `class_metrics` to be a boolean")
         self.class_metrics = class_metrics
-        self.add_state("detection_boxes", default=[], dist_reduce_fx=None)
self.add_state("detections", default=[], dist_reduce_fx=None) self.add_state("detection_scores", default=[], dist_reduce_fx=None) self.add_state("detection_labels", default=[], dist_reduce_fx=None) - self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None) + self.add_state("groundtruths", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore @@ -304,20 +383,51 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] ValueError: If any score is not type float and of length 1 """ - _input_validator(preds, target) + _input_validator(preds, target, iou_type=self.iou_type) for item in preds: - boxes = _fix_empty_tensors(item["boxes"]) - boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") - self.detection_boxes.append(boxes) + + detections = self._get_safe_item_values(item) + + self.detections.append(detections) self.detection_labels.append(item["labels"]) self.detection_scores.append(item["scores"]) for item in target: + groundtruths = self._get_safe_item_values(item) + self.groundtruths.append(groundtruths) + self.groundtruth_labels.append(item["labels"]) + + def _move_list_states_to_cpu(self) -> None: + """Move list states to cpu to save GPU memory.""" + + for key in self._defaults.keys(): + current_val = getattr(self, key) + current_to_cpu = [] + if isinstance(current_val, Sequence): + for cur_v in current_val: + # Cannot handle RLE as torch.Tensor + if not isinstance(cur_v, tuple): + cur_v = cur_v.to("cpu") + current_to_cpu.append(cur_v) + setattr(self, key, current_to_cpu) + + def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + + if self.iou_type == "bbox": boxes = _fix_empty_tensors(item["boxes"]) boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy") - self.groundtruth_boxes.append(boxes) - self.groundtruth_labels.append(item["labels"]) + return boxes + elif self.iou_type == "segm": + masks = [] + + for i in item["masks"].cpu().numpy(): + rle = mask_utils.encode(np.asfortranarray(i)) + masks.append((tuple(rle["size"]), rle["counts"])) + + return tuple(masks) + else: + raise Exception(f"IOU type {self.iou_type} is not supported") def _get_classes(self) -> List: """Returns a list of unique classes found in ground truth and detection data.""" @@ -337,14 +447,20 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: max_det: Maximum number of evaluated detection bounding boxes """ - gt = self.groundtruth_boxes[idx] - det = self.detection_boxes[idx] - gt_label_mask = self.groundtruth_labels[idx] == class_id - det_label_mask = self.detection_labels[idx] == class_id + + # if self.iou_type == "bbox": + gt = self.groundtruths[idx] + det = self.detections[idx] + + gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1) + det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1) + if len(gt_label_mask) == 0 or len(det_label_mask) == 0: return Tensor([]) - gt = gt[gt_label_mask] - det = det[det_label_mask] + + gt = [gt[i] for i in gt_label_mask] + det = [det[i] for i in det_label_mask] + if len(gt) == 0 or len(det) == 0: return Tensor([]) @@ -352,12 +468,13 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: scores = self.detection_scores[idx] scores_filtered = scores[self.detection_labels[idx] == class_id] inds = torch.argsort(scores_filtered, descending=True) - 
-        det = det[inds]
+
+        # TODO: fix (only necessary for masks)
+        det = [det[i] for i in inds]
         if len(det) > max_det:
             det = det[:max_det]
-        # generalized_box_iou
-        ious = box_iou(det, gt)
+        ious = compute_iou(det, gt, self.iou_type).to(self.device)
         return ious
     def __evaluate_image_gt_no_preds(
@@ -390,18 +507,21 @@ def __evaluate_image_preds_no_gt(
         """Some predictions but no GT."""
         # GTs
         nb_gt = 0
+        gt_ignore = torch.zeros(nb_gt, dtype=torch.bool, device=self.device)
         # Detections
-        det = det[det_label_mask]
+
+        det = [det[i] for i in det_label_mask]
         scores = self.detection_scores[idx]
         scores_filtered = scores[det_label_mask]
         scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
-        det = det[dtind]
+
+        det = [det[i] for i in dtind]
         if len(det) > max_det:
             det = det[:max_det]
         nb_det = len(det)
-        det_areas = box_area(det).to(self.device)
+        det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
         det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
         ar = det_ignore_area.reshape((1, nb_det))
         det_ignore = torch.repeat_interleave(ar, nb_iou_thrs, 0)
@@ -409,9 +529,9 @@ def __evaluate_image_preds_no_gt(
         return {
             "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device),
             "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device),
-            "dtScores": scores_sorted,
-            "gtIgnore": gt_ignore,
-            "dtIgnore": det_ignore,
+            "dtScores": scores_sorted.to(self.device),
+            "gtIgnore": gt_ignore.to(self.device),
+            "dtIgnore": det_ignore.to(self.device),
         }
     def _evaluate_image(
@@ -431,10 +551,11 @@ def _evaluate_image(
             ious: IoU results for image and class.
         """
-        gt = self.groundtruth_boxes[idx]
-        det = self.detection_boxes[idx]
-        gt_label_mask = self.groundtruth_labels[idx] == class_id
-        det_label_mask = self.detection_labels[idx] == class_id
+
+        gt = self.groundtruths[idx]
+        det = self.detections[idx]
+        gt_label_mask = (self.groundtruth_labels[idx] == class_id).nonzero().squeeze(1)
+        det_label_mask = (self.detection_labels[idx] == class_id).nonzero().squeeze(1)
         # No Gt and No predictions --> ignore image
         if len(gt_label_mask) == 0 and len(det_label_mask) == 0:
@@ -450,23 +571,30 @@ def _evaluate_image(
         if len(gt_label_mask) == 0 and len(det_label_mask) >= 0:
             return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, nb_iou_thrs)
-        gt = gt[gt_label_mask]
-        det = det[det_label_mask]
-        if gt.numel() == 0 and det.numel() == 0:
+        gt = [gt[i] for i in gt_label_mask]
+        det = [det[i] for i in det_label_mask]
+        if len(gt) == 0 and len(det) == 0:
             return None
+        if isinstance(det, dict):
+            det = [det]
+        if isinstance(gt, dict):
+            gt = [gt]
-        areas = box_area(gt)
-        ignore_area = (areas < area_range[0]) | (areas > area_range[1])
+        areas = compute_area(gt, iou_type=self.iou_type).to(self.device)
+
+        ignore_area = torch.logical_or(areas < area_range[0], areas > area_range[1])
         # sort dt highest score first, sort gt ignore last
         ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8))
         # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA"
-        ignore_area_sorted = ignore_area_sorted.to(torch.bool)
-        gt = gt[gtind]
+
+        ignore_area_sorted = ignore_area_sorted.to(torch.bool).to(self.device)
+
+        gt = [gt[i] for i in gtind]
         scores = self.detection_scores[idx]
         scores_filtered = scores[det_label_mask]
         scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
-        det = det[dtind]
+        det = [det[i] for i in dtind]
         if len(det) > max_det:
             det = det[:max_det]
         # load computed ious
@@ -475,10 +603,10 @@ def _evaluate_image(
         nb_iou_thrs = len(self.iou_thresholds)
         nb_gt = len(gt)
         nb_det = len(det)
-        gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=gt.device)
-        det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=gt.device)
+        gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device)
+        det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
         gt_ignore = ignore_area_sorted
-        det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=gt.device)
+        det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
         if torch.numel(ious) > 0:
             for idx_iou, t in enumerate(self.iou_thresholds):
@@ -491,12 +619,13 @@ def _evaluate_image(
                         gt_matches[idx_iou, m] = 1
         # set unmatched detections outside of area range to ignore
-        det_areas = box_area(det)
+        det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
         det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
         ar = det_ignore_area.reshape((1, nb_det))
         det_ignore = torch.logical_or(
             det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0))
         )
+
         return {
             "dtMatches": det_matches.to(self.device),
             "gtMatches": gt_matches.to(self.device),
@@ -586,7 +715,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
             class_ids: List of label class Ids.
         """
-        img_ids = range(len(self.groundtruth_boxes))
+        img_ids = range(len(self.groundtruths))
         max_detections = self.max_detection_thresholds[-1]
         area_ranges = self.bbox_area_ranges.values()
@@ -692,6 +821,7 @@ def __calculate_recall_precision_scores(
         img_eval_cls_bbox = [e for e in img_eval_cls_bbox if e is not None]
         if not img_eval_cls_bbox:
             return recall, precision, scores
+
         det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox])
         # different sorting method generates slightly different results.
@@ -726,7 +856,8 @@ def __calculate_recall_precision_scores(
                 diff_zero = torch.zeros((1,), device=pr.device)
                 diff = torch.ones((1,), device=pr.device)
                 while not torch.all(diff == 0):
-                    diff = torch.clamp(torch.cat((pr[1:] - pr[:-1], diff_zero), 0), min=0)
+
+                    diff = torch.clamp(torch.cat(((pr[1:] - pr[:-1]), diff_zero), 0), min=0)
                     pr += diff
                 inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
@@ -766,6 +897,7 @@ def compute(self) -> dict:
             - map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
             - mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
         """
+
         classes = self._get_classes()
         precisions, recalls = self._calculate(classes)
         map_val, mar_val = self._summarize_results(precisions, recalls)
@@ -792,4 +924,5 @@ def compute(self) -> dict:
         metrics.update(mar_val)
         metrics.map_per_class = map_per_class_values
         metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values
+
         return metrics
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index ec857708e97..67e53d4130e 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -360,6 +360,7 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
         for attr, reduction_fn in self._reductions.items():
             # pre-processing ops (stack or flatten for inputs)
+
            if isinstance(output_dict[attr][0], Tensor):
                output_dict[attr] = torch.stack(output_dict[attr])
            elif isinstance(output_dict[attr][0], list):
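
For context, the snippet below sketches how the new ``segm`` IoU type is used from the public API once this patch is applied. It is not part of the diff; the masks, scores and labels are made-up values chosen only to illustrate the expected input format (boolean instance masks of shape ``(num_instances, height, width)``).

import torch
from torch import IntTensor, Tensor

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Illustrative boolean instance masks, one per detection / ground truth.
pred_mask = torch.zeros(1, 10, 10, dtype=torch.bool)
pred_mask[:, 2:8, 2:8] = True
target_mask = torch.zeros(1, 10, 10, dtype=torch.bool)
target_mask[:, 3:9, 3:9] = True

preds = [dict(masks=pred_mask, scores=Tensor([0.9]), labels=IntTensor([1]))]
target = [dict(masks=target_mask, labels=IntTensor([1]))]

metric = MeanAveragePrecision(iou_type="segm")  # requires pycocotools
metric.update(preds, target)
print(metric.compute())  # dict with map, map_50, ..., mar_100 values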
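A second, standalone sketch shows the RLE round-trip that the metric performs internally (mirroring ``_get_safe_item_values`` and ``_segm_iou`` above). Again this is illustrative only: the variable names are ours, and it assumes ``pycocotools`` is installed.

import numpy as np
import torch
from pycocotools import mask as mask_utils

# A 10x10 binary mask with a 6x6 square of foreground pixels.
binary_mask = np.zeros((10, 10), dtype=np.uint8)
binary_mask[2:8, 2:8] = 1

# Encode the H x W mask as COCO run-length encoding; rle["size"] is [height, width].
rle = mask_utils.encode(np.asfortranarray(binary_mask))
stored = (tuple(rle["size"]), rle["counts"])  # the (RLE_SIZE, RLE_COUNTS) tuple kept as metric state

# Rebuild the dicts that pycocotools expects and compute mask IoU
# (a mask compared with itself gives IoU = 1.0; iscrowd flags are all False).
det = [{"size": stored[0], "counts": stored[1]}]
gt = [{"size": stored[0], "counts": stored[1]}]
print(torch.tensor(mask_utils.iou(det, gt, [False])))  # tensor([[1.]], dtype=torch.float64)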