def segm_iou(inputs: Tensor, targets: Tensor, smooth: float = 1) -> Tensor:
    """Compute the pairwise (smoothed) IoU between two stacks of segmentation masks.

    Args:
        inputs: Stack of masks whose first dimension indexes the masks; every
            remaining dimension is flattened. Assumed to hold 0/1 (or boolean)
            values -- TODO confirm against callers.
        targets: Stack of masks laid out like ``inputs`` and with the same
            number of elements per mask.
        smooth: Constant added to both numerator and denominator, which keeps
            the ratio defined when a union is empty.

    Returns:
        Tensor of shape ``(n_inputs, n_targets)`` where entry ``[i, j]`` is
        ``(|inputs[i] & targets[j]| + smooth) / (|inputs[i] | targets[j]| + smooth)``.
    """
    n_inputs = inputs.shape[0]
    n_targets = targets.shape[0]

    # Nothing to compare: return an empty IoU matrix instead of failing on the
    # ambiguous ``-1`` reshape of a zero-element tensor.
    if n_inputs == 0 or n_targets == 0:
        return inputs.new_zeros((n_inputs, n_targets), dtype=torch.float32)

    # Work in floating point so boolean / integer masks can go through matmul
    # and so ``+`` really means arithmetic addition (for bool tensors it is a
    # logical OR, which made the result depend on the mask dtype).
    dtype = torch.promote_types(inputs.dtype, torch.float32)
    flat_inputs = inputs.reshape(n_inputs, -1).to(dtype)
    flat_targets = targets.reshape(n_targets, -1).to(dtype)

    # A single matrix product yields every pairwise overlap count without
    # materialising n_inputs * n_targets copies of the masks.
    intersections = flat_inputs @ flat_targets.T

    # |A u B| = |A| + |B| - |A n B|. Summing ``inputs + targets`` alone counts
    # the overlap twice and reports ~0.5 for identical masks.
    input_areas = flat_inputs.sum(1, keepdim=True)
    target_areas = flat_targets.sum(1).unsqueeze(0)
    unions = input_areas + target_areas - intersections

    return (intersections + smooth) / (unions + smooth)
def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
    """Compute the IoU matrix for one image and class using the configured ``iou_type``.

    Selects the stored ground truths / detections and the pairwise IoU function
    that match ``self.iou_type`` and delegates to ``_compute_iou_impl``.

    Args:
        id: Index of the image in the stored state lists.
        class_id: Class for which the IoU is computed.
        max_det: Maximum number of evaluated detections.

    Returns:
        Whatever ``_compute_iou_impl`` returns for the selected inputs.

    Raises:
        ValueError: If ``self.iou_type`` is neither ``"segm"`` nor ``"bbox"``.
    """
    # NOTE: the former ``breakpoint()`` guard was removed. It compared ``id``
    # against the number of stored masks, so in "bbox" mode (where no masks are
    # stored) it dropped into the debugger for every image after the first.
    if self.iou_type == "segm":
        return self._compute_iou_impl(id, self.groundtruth_masks, self.detection_masks, class_id, max_det, segm_iou)
    if self.iou_type == "bbox":
        return self._compute_iou_impl(id, self.groundtruth_boxes, self.detection_boxes, class_id, max_det, box_iou)
    # ValueError matches the validation in ``__init__`` and, being an
    # ``Exception`` subclass, is still caught by existing handlers.
    raise ValueError(f"IOU type {self.iou_type} is not supported")