Skip to content

Commit

Permalink
implemented IoU with segmentation masks and mAP for instance segment…
Browse files Browse the repository at this point in the history
…ation
  • Loading branch information
gianscarpe committed Jan 31, 2022
1 parent 1663629 commit 214e175
Showing 1 changed file with 48 additions and 4 deletions.
52 changes: 48 additions & 4 deletions torchmetrics/detection/map.py
Expand Up @@ -80,6 +80,27 @@ class COCOMetricResults(BaseMetricResults):
)


def segm_iou(inputs: Tensor, targets: Tensor, smooth: float = 1) -> Tensor:
    """Compute the pairwise Intersection over Union (IoU) between two sets of masks.

    Every mask in ``inputs`` is compared against every mask in ``targets``. All dimensions after the
    first one are flattened, so both arguments must hold the same number of elements per mask.

    Args:
        inputs: Masks to evaluate, one mask per entry along the first dimension. Values are expected
            to be binary (``0``/``1`` or ``bool``).
        targets: Reference masks, laid out like ``inputs``.
        smooth: Constant added to both the intersection and the union. With the default of ``1`` two
            empty masks yield an IoU of ``1``; with ``smooth=0`` that case evaluates to ``nan`` (0 / 0).

    Returns:
        Floating point tensor of shape ``(n_inputs, n_targets)`` whose entry ``[i, j]`` is the smoothed
        IoU between ``inputs[i]`` and ``targets[j]``.
    """
    n_inputs = inputs.shape[0]
    n_targets = targets.shape[0]

    # Work in a floating point dtype so that bool / integer masks can go through a matrix product.
    dtype = torch.promote_types(inputs.dtype, targets.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float32

    # Flatten each mask to a row. The explicit column count (instead of ``-1``) keeps the reshape
    # well defined even when there are zero masks.
    flat_inputs = inputs.reshape(n_inputs, inputs.shape[1:].numel()).to(dtype)
    flat_targets = targets.reshape(n_targets, targets.shape[1:].numel()).to(dtype)

    # For binary masks the dot product of two rows counts the pixels set in both of them, i.e. the
    # intersection. This gives all pairs at once without materialising n_inputs * n_targets masks.
    intersections = flat_inputs @ flat_targets.T

    # |A ∪ B| = |A| + |B| - |A ∩ B|. Summing ``inputs + targets`` alone would count the
    # overlapping pixels twice.
    input_areas = flat_inputs.sum(1, keepdim=True)  # (n_inputs, 1)
    target_areas = flat_targets.sum(1)  # (n_targets,), broadcasts across rows
    unions = input_areas + target_areas - intersections

    return (intersections + smooth) / (unions + smooth)


def _input_validator(preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]]) -> None:
"""Ensure the correct input format of `preds` and `targets`"""
if not isinstance(preds, Sequence):
Expand Down Expand Up @@ -225,6 +246,7 @@ class MeanAveragePrecision(Metric):
def __init__(
self,
box_format: str = "xyxy",
iou_type: str = "bbox",
iou_thresholds: Optional[List[float]] = None,
rec_thresholds: Optional[List[float]] = None,
max_detection_thresholds: Optional[List[int]] = None,
Expand All @@ -248,9 +270,13 @@ def __init__(
)

allowed_box_formats = ("xyxy", "xywh", "cxcywh")
allowed_iou_types = ("segm", "bbox")
if box_format not in allowed_box_formats:
raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}")
self.box_format = box_format
if iou_type not in allowed_iou_types:
raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}")
self.iou_type = iou_type
self.iou_thresholds = Tensor(iou_thresholds or torch.linspace(0.5, 0.95, round((0.95 - 0.5) / 0.05) + 1))
self.rec_thresholds = Tensor(rec_thresholds or torch.linspace(0.0, 1.00, round(1.00 / 0.01) + 1))
self.max_detection_thresholds = IntTensor(max_detection_thresholds or [1, 10, 100])
Expand All @@ -269,8 +295,10 @@ def __init__(
self.add_state("detection_boxes", default=[], dist_reduce_fx=None)
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
self.add_state("detection_labels", default=[], dist_reduce_fx=None)
self.add_state("detection_masks", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
self.add_state("groundtruth_masks", default=[], dist_reduce_fx=None)

def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore
"""Add detections and ground truth to the metric.
Expand Down Expand Up @@ -325,6 +353,8 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
)
self.detection_labels.append(item["labels"])
self.detection_scores.append(item["scores"])
if "masks" in item:
self.detection_masks.append(item["masks"])

for item in target:
self.groundtruth_boxes.append(
Expand All @@ -333,6 +363,8 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
else _fix_empty_tensors(item["boxes"])
)
self.groundtruth_labels.append(item["labels"])
if "masks" in item:
self.groundtruth_masks.append(item["masks"])

def _get_classes(self) -> List:
"""Returns a list of unique classes found in ground truth and detection data."""
Expand All @@ -341,6 +373,18 @@ def _get_classes(self) -> List:
return []

def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
if id > len(self.groundtruth_masks):
breakpoint()
if self.iou_type == "segm":
return self._compute_iou_impl(id, self.groundtruth_masks, self.detection_masks, class_id, max_det, segm_iou)
elif self.iou_type == "bbox":
return self._compute_iou_impl(id, self.groundtruth_boxes, self.detection_boxes, class_id, max_det, box_iou)
else:
raise Exception(f"IOU type {self.iou_type} is not supported")

def _compute_iou_impl(
self, id: int, ground_truths, detections, class_id: int, max_det: int, compute_iou: Callable
) -> Tensor:
"""Computes the Intersection over Union (IoU) for ground truth and detection bounding boxes for the given
image and class.
Expand All @@ -352,8 +396,9 @@ def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
max_det:
Maximum number of evaluated detection bounding boxes
"""
gt = self.groundtruth_boxes[id]
det = self.detection_boxes[id]
gt = ground_truths[id]
det = detections[id]

gt_label_mask = self.groundtruth_labels[id] == class_id
det_label_mask = self.detection_labels[id] == class_id
if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
Expand All @@ -371,8 +416,7 @@ def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
if len(det) > max_det:
det = det[:max_det]

# generalized_box_iou
ious = box_iou(det, gt)
ious = compute_iou(det, gt)
return ious

def _evaluate_image(
Expand Down

0 comments on commit 214e175

Please sign in to comment.