Commit 496f9eb: Update copies

NielsRogge committed Oct 18, 2022 (parent commit: c73afd6)

Showing 3 changed files with 29 additions and 37 deletions.
src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -33,7 +33,6 @@
     add_start_docstrings_to_model_forward,
     is_scipy_available,
     is_timm_available,
-    is_vision_available,
     logging,
     replace_return_docstrings,
     requires_backends,
@@ -44,9 +43,6 @@
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
-if is_vision_available():
-    from .feature_extraction_conditional_detr import center_to_corners_format
-
 if is_timm_available():
     from timm import create_model
 
src/transformers/models/detr/modeling_detr.py (4 changes: 0 additions & 4 deletions)
@@ -33,7 +33,6 @@
     add_start_docstrings_to_model_forward,
     is_scipy_available,
     is_timm_available,
-    is_vision_available,
     logging,
     replace_return_docstrings,
     requires_backends,
@@ -44,9 +43,6 @@
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
-if is_vision_available():
-    from .feature_extraction_detr import center_to_corners_format
-
 if is_timm_available():
     from timm import create_model
 
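Both modeling files drop the vision-gated import of `center_to_corners_format` even though `loss_boxes` further down still calls it, which suggests the helper is now defined in (or imported into) the modeling files by the copy mechanism; that is an inference from this diff, not something it shows. For reference, a minimal sketch of what the helper computes, converting DETR-style (center_x, center_y, width, height) boxes to (x0, y0, x1, y1) corners; illustrative, not necessarily the library's exact code:

import torch

def center_to_corners_format(boxes: torch.Tensor) -> torch.Tensor:
    # boxes: (..., 4) in (center_x, center_y, width, height) format
    center_x, center_y, width, height = boxes.unbind(-1)
    # return: (..., 4) in (x0, y0, x1, y1) corner format
    return torch.stack(
        [center_x - 0.5 * width, center_y - 0.5 * height,
         center_x + 0.5 * width, center_y + 0.5 * height],
        dim=-1,
    )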
@@ -1557,16 +1557,16 @@ def loss_labels(self, outputs, targets, indices, num_boxes):
         """
         if "logits" not in outputs:
             raise KeyError("No logits were found in the outputs")
-        src_logits = outputs["logits"]
+        source_logits = outputs["logits"]
 
-        idx = self._get_src_permutation_idx(indices)
+        idx = self._get_source_permutation_idx(indices)
         target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
         target_classes = torch.full(
-            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
         )
         target_classes[idx] = target_classes_o
 
-        loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
         losses = {"loss_ce": loss_ce}
 
         return losses
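For context on the renamed `loss_labels`: every query is first labeled with the extra no-object class (`self.num_classes`), matched queries are then overwritten with their ground-truth labels, and `self.empty_weight` down-weights the no-object class in the cross-entropy. A minimal standalone sketch; the sizes and the `eos_coef` value are illustrative assumptions, not taken from this diff:

import torch
from torch import nn

num_classes, num_queries, batch_size = 91, 100, 2  # assumed COCO-like setup
eos_coef = 0.1  # assumed relative weight for the no-object class

empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = eos_coef

logits = torch.randn(batch_size, num_queries, num_classes + 1)
# every query defaults to "no object" ...
target_classes = torch.full((batch_size, num_queries), num_classes, dtype=torch.int64)
# ... and matched queries get their ground-truth class (indices from the matcher)
target_classes[0, 17] = 5

# cross_entropy expects (batch, classes, queries), hence the transpose
loss_ce = nn.functional.cross_entropy(logits.transpose(1, 2), target_classes, empty_weight)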
@@ -1596,17 +1596,17 @@ def loss_boxes(self, outputs, targets, indices, num_boxes):
         """
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = outputs["pred_boxes"][idx]
+        idx = self._get_source_permutation_idx(indices)
+        source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
 
-        loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
 
         losses = {}
         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
 
         loss_giou = 1 - torch.diag(
-            generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+            generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
         )
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
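One detail worth noting in `loss_boxes`: the GIoU matrix is pairwise over all matched sources and targets, and `torch.diag` keeps only the entries pairing the i-th prediction with the i-th target, since the matcher already aligned them one-to-one. A minimal pairwise GIoU sketch consistent with that usage (input validation and degenerate-box handling omitted):

import torch

def generalized_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    # corner-format boxes: (N, 4) and (M, 4) -> pairwise GIoU matrix (N, M)
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # (N, M, 2)
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # (N, M, 2)
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]                     # (N, M)
    union = area1[:, None] + area2 - inter
    iou = inter / union
    # smallest box enclosing each pair
    lt_enc = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb_enc = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    wh_enc = (rb_enc - lt_enc).clamp(min=0)
    area_enc = wh_enc[..., 0] * wh_enc[..., 1]
    return iou - (area_enc - union) / area_enc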
@@ -1620,41 +1620,41 @@ def loss_masks(self, outputs, targets, indices, num_boxes):
         if "pred_masks" not in outputs:
             raise KeyError("No predicted masks found in outputs")
 
-        src_idx = self._get_src_permutation_idx(indices)
-        tgt_idx = self._get_tgt_permutation_idx(indices)
-        src_masks = outputs["pred_masks"]
-        src_masks = src_masks[src_idx]
+        source_idx = self._get_source_permutation_idx(indices)
+        target_idx = self._get_target_permutation_idx(indices)
+        source_masks = outputs["pred_masks"]
+        source_masks = source_masks[source_idx]
         masks = [t["masks"] for t in targets]
         # TODO use valid to mask invalid areas due to padding in loss
         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
-        target_masks = target_masks.to(src_masks)
-        target_masks = target_masks[tgt_idx]
+        target_masks = target_masks.to(source_masks)
+        target_masks = target_masks[target_idx]
 
         # upsample predictions to the target size
-        src_masks = nn.functional.interpolate(
-            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+        source_masks = nn.functional.interpolate(
+            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
         )
-        src_masks = src_masks[:, 0].flatten(1)
+        source_masks = source_masks[:, 0].flatten(1)
 
         target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(src_masks.shape)
+        target_masks = target_masks.view(source_masks.shape)
         losses = {
-            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
-            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
         }
         return losses
 
-    def _get_src_permutation_idx(self, indices):
+    def _get_source_permutation_idx(self, indices):
         # permute predictions following indices
-        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
-        src_idx = torch.cat([src for (src, _) in indices])
-        return batch_idx, src_idx
+        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+        source_idx = torch.cat([source for (source, _) in indices])
+        return batch_idx, source_idx
 
-    def _get_tgt_permutation_idx(self, indices):
+    def _get_target_permutation_idx(self, indices):
         # permute targets following indices
-        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
-        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
-        return batch_idx, tgt_idx
+        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+        target_idx = torch.cat([target for (_, target) in indices])
+        return batch_idx, target_idx
 
     def get_loss(self, loss, outputs, targets, indices, num_boxes):
         loss_map = {
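The renamed helpers convert the matcher's per-image index pairs into flat `(batch_idx, source_idx)` tensors suitable for advanced indexing across the whole batch. A small worked example with hypothetical matcher output:

import torch

# hypothetical Hungarian matcher output for a batch of 2 images:
# one (prediction indices, target indices) pair per image
indices = [
    (torch.tensor([2, 0]), torch.tensor([0, 1])),  # image 0: two matched objects
    (torch.tensor([1]), torch.tensor([0])),        # image 1: one matched object
]

batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
# batch_idx  -> tensor([0, 0, 1])
# source_idx -> tensor([2, 0, 1])

# used to gather matched predictions, e.g. in loss_boxes:
pred_boxes = torch.randn(2, 100, 4)          # (batch, queries, 4)
matched = pred_boxes[batch_idx, source_idx]  # (3, 4)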
@@ -1675,7 +1675,7 @@ def forward(self, outputs, targets):
             outputs (`dict`, *optional*):
                 Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
-                List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                 losses applied, see each loss' doc.
         """
         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
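The docstring being touched up expects `targets` as a list of per-image dicts with `len(targets) == batch_size`; judging from the losses in this diff, each dict carries at least `class_labels` and `boxes` (plus `masks` for segmentation). A hypothetical example:

import torch

targets = [
    {  # image 0: two annotated objects
        "class_labels": torch.tensor([17, 3]),
        "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3],    # normalized
                               [0.3, 0.7, 0.1, 0.1]]),  # (center_x, center_y, w, h)
    },
    {  # image 1: one annotated object
        "class_labels": torch.tensor([5]),
        "boxes": torch.tensor([[0.6, 0.4, 0.5, 0.5]]),
    },
]
assert len(targets) == 2  # == batch_size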