diff --git a/torchmetrics/detection/mean_ap.py b/torchmetrics/detection/mean_ap.py index 557135aaad4..d81e9556904 100644 --- a/torchmetrics/detection/mean_ap.py +++ b/torchmetrics/detection/mean_ap.py @@ -482,9 +482,11 @@ def __evaluate_image_gt_no_preds( ) -> Dict[str, Any]: """Some GT but no predictions.""" # GTs - gt = gt[gt_label_mask] + gt = np.asarray(gt)[gt_label_mask] + if len(gt_label_mask) == 1: + gt = np.expand_dims(gt, axis=0) nb_gt = len(gt) - areas = box_area(gt) + areas = compute_area(gt, iou_type=self.iou_type).to(self.device) ignore_area = (areas < area_range[0]) | (areas > area_range[1]) gt_ignore, _ = torch.sort(ignore_area.to(torch.uint8)) gt_ignore = gt_ignore.to(torch.bool)