From 1ffbaff117d7b2198ceff3b9bfea0203a6b489f5 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 26 Apr 2022 11:51:10 +0200 Subject: [PATCH] Apply suggestions from code review --- .../models/auto/feature_extraction_auto.py | 1 + .../models/yolos/configuration_yolos.py | 8 +- .../models/yolos/convert_yolos_to_pytorch.py | 13 +- .../models/yolos/modeling_yolos.py | 264 ++++++++++++++---- tests/yolos/test_modeling_yolos.py | 25 +- 5 files changed, 232 insertions(+), 79 deletions(-) diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 79ebbf8015ec7..a8bb2019737fe 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -61,6 +61,7 @@ ("data2vec-vision", "BeitFeatureExtractor"), ("dpt", "DPTFeatureExtractor"), ("glpn", "GLPNFeatureExtractor"), + ("yolos", "DetrFeatureExtractor"), ] ) diff --git a/src/transformers/models/yolos/configuration_yolos.py b/src/transformers/models/yolos/configuration_yolos.py index 24dc8edb0eedb..cd3414a7f26ee 100644 --- a/src/transformers/models/yolos/configuration_yolos.py +++ b/src/transformers/models/yolos/configuration_yolos.py @@ -21,7 +21,7 @@ logger = logging.get_logger(__name__) YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "google/yolos-s": "https://huggingface.co/yolos-s/resolve/main/config.json", + "hustvl/yolos-small": "https://huggingface.co/hustvl/yolos-small/resolve/main/config.json", # See all YOLOS models at https://huggingface.co/models?filter=yolos } @@ -31,7 +31,7 @@ class YolosConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`YolosModel`]. It is used to instantiate a YOLOS model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the YOLOS - [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) architecture. + [hustvl/yolos-base](https://huggingface.co/hustvl/yolos-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. 
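A note on the feature_extraction_auto.py hunk above: registering `("yolos", "DetrFeatureExtractor")` means `AutoFeatureExtractor` will resolve YOLOS checkpoints to DETR's feature extractor, since YOLOS reuses DETR's preprocessing. A minimal sketch of what this enables, assuming vision dependencies are installed and using the `hustvl/yolos-small` checkpoint referenced elsewhere in this patch (requires network access):

```python
from transformers import AutoFeatureExtractor, DetrFeatureExtractor

# model_type "yolos" now maps to DetrFeatureExtractor in the auto
# feature-extractor mapping, so the auto class returns a DETR extractor.
feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small")
assert isinstance(feature_extractor, DetrFeatureExtractor)
```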
@@ -88,10 +88,10 @@ class YolosConfig(PretrainedConfig): ```python >>> from transformers import YolosModel, YolosConfig - >>> # Initializing a YOLOS vit-base-patch16-224 style configuration + >>> # Initializing a YOLOS hustvl/yolos-base style configuration >>> configuration = YolosConfig() - >>> # Initializing a model from the vit-base-patch16-224 style configuration + >>> # Initializing a model from the hustvl/yolos-base style configuration >>> model = YolosModel(configuration) >>> # Accessing the model configuration diff --git a/src/transformers/models/yolos/convert_yolos_to_pytorch.py b/src/transformers/models/yolos/convert_yolos_to_pytorch.py index 036818423828a..a02c5604503e1 100644 --- a/src/transformers/models/yolos/convert_yolos_to_pytorch.py +++ b/src/transformers/models/yolos/convert_yolos_to_pytorch.py @@ -225,9 +225,18 @@ def convert_yolos_checkpoint(yolos_name, checkpoint_path, pytorch_dump_folder_pa feature_extractor.save_pretrained(pytorch_dump_folder_path) if push_to_hub: + model_mapping = { + "yolos_ti": "yolos-tiny", + "yolos_s_200_pre": "yolos-small", + "yolos_s_300_pre": "yolos-small-300", + "yolos_s_dWr": "yolos-small-dwr", + "yolos_base": "yolos-base", + } + print("Pushing to the hub...") - feature_extractor.push_to_hub("nielsr/yolos-s") - model.push_to_hub("nielsr/yolos-s") + model_name = model_mapping[yolos_name] + feature_extractor.push_to_hub(model_name, organization="hustvl") + model.push_to_hub(model_name, organization="hustvl") if __name__ == "__main__": diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 2390b5cad8803..4120e54ca0521 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -61,8 +61,7 @@ YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - # TODO update nielsr to organization - "nielsr/yolos-s", + "hustvl/yolos-small", # See all YOLOS models at https://huggingface.co/models?filter=yolos ] @@ -70,7 +69,7 @@ @dataclass class YolosObjectDetectionOutput(ModelOutput): """ - Output type of [`DetrForObjectDetection`]. + Output type of [`YolosForObjectDetection`]. 
     Args:
         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
@@ -112,7 +111,7 @@ class YolosObjectDetectionOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None


-# copied from transformers.models.vit.modeling_vit.to_2tuple
+# Copied from transformers.models.vit.modeling_vit.to_2tuple
 def to_2tuple(x):
     if isinstance(x, collections.abc.Iterable):
         return x
@@ -256,7 +255,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings


-# copied from transformers.models.vit.modeling_vit.ViTSelfAttention
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos
 class YolosSelfAttention(nn.Module):
     def __init__(self, config: YolosConfig) -> None:
         super().__init__()
@@ -317,7 +316,7 @@ def forward(
         return outputs


-# copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
 class YolosSelfOutput(nn.Module):
     """
     The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
@@ -337,7 +336,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
     return hidden_states


-# copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Yolos
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Yolos
 class YolosAttention(nn.Module):
     def __init__(self, config: YolosConfig) -> None:
         super().__init__()
@@ -377,7 +376,7 @@ def forward(
     return outputs


-# copied from transformers.models.vit.modeling_vit.ViTIntermediate
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
 class YolosIntermediate(nn.Module):
     def __init__(self, config: YolosConfig) -> None:
         super().__init__()
@@ -395,7 +394,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     return hidden_states


-# copied from transformers.models.vit.modeling_vit.ViTOutput
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Yolos
 class YolosOutput(nn.Module):
     def __init__(self, config: YolosConfig) -> None:
         super().__init__()
@@ -411,7 +410,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
     return hidden_states


-# copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos
 class YolosLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""

@@ -432,7 +431,7 @@ def forward(
         output_attentions: bool = False,
     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         self_attention_outputs = self.attention(
-            self.layernorm_before(hidden_states),  # in YOLOS, layernorm is applied before self-attention
+            self.layernorm_before(hidden_states),  # in Yolos, layernorm is applied before self-attention
             head_mask,
             output_attentions=output_attentions,
         )
@@ -442,7 +441,7 @@ def forward(
         # first residual connection
         hidden_states = attention_output + hidden_states

-        # in YOLOS, layernorm is also applied after self-attention
+        # in Yolos, layernorm is also applied after self-attention
         layer_output = self.layernorm_after(hidden_states)

         layer_output = self.intermediate(layer_output)
@@ -764,8 +763,8 @@ def forward(
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)

-        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
-        >>> model = YolosForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small")
+        >>> model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")

         >>> inputs = feature_extractor(images=image, return_tensors="pt")

@@ -850,28 +849,82 @@ def forward(
 )


-# copied from transformers.models.detr.modeling_detr.DetrLoss with Detr->Yolos
+# Copied from transformers.models.detr.modeling_detr.dice_loss
+def dice_loss(inputs, targets, num_boxes):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs (0 for the negative class and 1 for the positive
+                 class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * (inputs * targets).sum(1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_boxes
+
+
+# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs (0 for the negative class and 1 for the positive
+                 class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+               positive vs negative examples. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.mean(1).sum() / num_boxes
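As a quick, hedged sanity check on the two loss helpers above (assuming both functions are in scope; the tensors are dummies shaped the way the mask loss flattens them, and near-saturated logits should drive both losses toward zero):

```python
import torch

# Two predicted "masks" of four pixels each, already flattened to (num_boxes, num_pixels).
inputs = torch.tensor([[10.0, -10.0, 10.0, -10.0], [-10.0, -10.0, -10.0, -10.0]])
targets = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]])

# Near-perfect logits, so both losses come out close to zero.
print(sigmoid_focal_loss(inputs, targets, num_boxes=2))  # ~0
print(dice_loss(inputs, targets, num_boxes=2))           # ~0
```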
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrLoss with Detr->Yolos
 class YolosLoss(nn.Module):
     """
-    This class computes the losses for YolosForObjectDetection. The process happens in two steps: 1) we compute
-    hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched
-    ground-truth / prediction (supervise class and box)
+    This class computes the losses for YolosForObjectDetection/YolosForSegmentation. The process happens in two steps:
+    1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each
+    pair of matched ground-truth / prediction (supervise class and box)
     """

     def __init__(self, matcher, num_classes, eos_coef, losses):
         """
-        Parameters:
-        Create the criterion. A note on the num_classes parameter (copied from original repo in detr.py): "the naming
-        of the `num_classes` parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id +
-        1`, where max_obj_id is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90,
-        so we pass `num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you
-        should pass `num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
+        Create the criterion.
+
+        A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
+        parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
+        is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
+        `num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
+        `num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
         https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
-            matcher: module able to compute a matching between targets and proposals. num_classes: number of object
-            categories, omitting the special no-object category. weight_dict: dict containing as key the names of the
-            losses and as values their relative weight. eos_coef: relative classification weight applied to the
-            no-object category. losses: list of all the losses to be applied. See get_loss for list of available
-            losses.
+
+        Parameters:
+            matcher: module able to compute a matching between targets and proposals.
+            num_classes: number of object categories, omitting the special no-object category.
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category.
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
         """
         super().__init__()
         self.num_classes = num_classes
@@ -907,8 +960,9 @@ def loss_labels(self, outputs, targets, indices, num_boxes):

     @torch.no_grad()
     def loss_cardinality(self, outputs, targets, indices, num_boxes):
         """
-        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. This is not
-        really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
+        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+
+        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
         """
         logits = outputs["logits"]
         device = logits.device
@@ -921,9 +975,10 @@ def loss_cardinality(self, outputs, targets, indices, num_boxes):

     def loss_boxes(self, outputs, targets, indices, num_boxes):
         """
-        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. Targets dicts must
-        contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in
-        format (center_x, center_y, w, h), normalized by the image size.
+        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
+
+        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
+        are expected in format (center_x, center_y, w, h), normalized by the image size.
         """
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
@@ -942,6 +997,39 @@ def loss_boxes(self, outputs, targets, indices, num_boxes):
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses

+    def loss_masks(self, outputs, targets, indices, num_boxes):
+        """
+        Compute the losses related to the masks: the focal loss and the dice loss.
+
+        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
+        """
+        if "pred_masks" not in outputs:
+            raise KeyError("No predicted masks found in outputs")
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+        src_masks = outputs["pred_masks"]
+        src_masks = src_masks[src_idx]
+        masks = [t["masks"] for t in targets]
+        # TODO use valid to mask invalid areas due to padding in loss
+        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+        target_masks = target_masks.to(src_masks)
+        target_masks = target_masks[tgt_idx]
+
+        # upsample predictions to the target size
+        src_masks = nn.functional.interpolate(
+            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+        )
+        src_masks = src_masks[:, 0].flatten(1)
+
+        target_masks = target_masks.flatten(1)
+        target_masks = target_masks.view(src_masks.shape)
+        losses = {
+            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+        }
+        return losses
+
     def _get_src_permutation_idx(self, indices):
         # permute predictions following indices
         batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
@@ -959,6 +1047,7 @@ def get_loss(self, loss, outputs, targets, indices, num_boxes):
             "labels": self.loss_labels,
             "cardinality": self.loss_cardinality,
             "boxes": self.loss_boxes,
+            "masks": self.loss_masks,
         }
         if loss not in loss_map:
             raise ValueError(f"Loss {loss} not supported")
@@ -966,10 +1055,11 @@ def get_loss(self, loss, outputs, targets, indices, num_boxes):

     def forward(self, outputs, targets):
         """
-        Parameters: This performs the loss computation.
-            outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts,
-            such that len(targets) == batch_size.
+        This performs the loss computation.
+
+        Parameters:
+            outputs: dict of tensors, see the output specification of the model for the format
+            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depends on the losses applied, see each loss' doc
         """
         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -996,6 +1086,9 @@ def forward(self, outputs, targets):
         for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
             indices = self.matcher(auxiliary_outputs, targets)
             for loss in self.losses:
+                if loss == "masks":
+                    # Intermediate masks losses are too costly to compute, we ignore them.
+                    continue
                 l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
                 l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                 losses.update(l_dict)
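`YolosLoss.forward` returns an unweighted dict (`loss_ce`, `loss_bbox`, `loss_giou`, plus `_{i}`-suffixed auxiliary entries); the caller is expected to weight and sum these. A minimal sketch of that reduction, with illustrative weights rather than the defaults the model wires up internally:

```python
# Illustrative weights only; YolosForObjectDetection builds its own weight_dict.
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}

def reduce_loss_dict(loss_dict, weight_dict):
    # Keys without a weight (e.g. the logging-only cardinality error) are skipped.
    return sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
```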
@@ -1003,12 +1096,14 @@ def forward(self, outputs, targets):
         return losses


-# copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
+# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
 class YolosMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
-    height and width of a bounding box w.r.t. an image. Copied from
-    https://github.com/facebookresearch/detr/blob/master/models/detr.py
+    height and width of a bounding box w.r.t. an image.
+
+    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+
     """

     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -1023,20 +1118,23 @@ def forward(self, x):
         return x


-# copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->Yolos
+# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->Yolos
 class YolosHungarianMatcher(nn.Module):
     """
-    This class computes an assignment between the targets and the predictions of the network. For efficiency reasons,
-    the targets don't include the no_object. Because of this, in general, there are more predictions than targets. In
-    this case, we do a 1-to-1 matching of the best predictions, while the others are un-matched (and thus treated as
-    non-objects).
+    This class computes an assignment between the targets and the predictions of the network.
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
+    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
+    un-matched (and thus treated as non-objects).
     """

     def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
         """
-        Params: Creates the matcher.
-            class_cost: This is the relative weight of the classification error in the matching cost bbox_cost:
+        Creates the matcher.
+
+        Params:
+            class_cost: This is the relative weight of the classification error in the matching cost
+            bbox_cost:
                This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
             giou_cost: This is the relative weight of the giou loss of the bounding box in the matching cost
         """
@@ -1053,8 +1151,9 @@ def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float

     @torch.no_grad()
     def forward(self, outputs, targets):
         """
-        Params: Performs the matching.
+        Performs the matching.
+
+        Params:
             outputs: This is a dict that contains at least these entries:
                  "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                  "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
@@ -1062,8 +1161,10 @@ def forward(self, outputs, targets):
                  "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                  objects in the target) containing the class labels
                  "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
         Returns:
             A list of size batch_size, containing tuples of (index_i, index_j) where:
+
             - index_i is the indices of the selected predictions (in order)
             - index_j is the indices of the corresponding selected targets (in order)
             For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
@@ -1098,7 +1199,7 @@ def forward(self, outputs, targets):
         return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


-# copied from transformers.models.detr.modeling_detr._upcast
+# Copied from transformers.models.detr.modeling_detr._upcast
 def _upcast(t: Tensor) -> Tensor:
     # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
     if t.is_floating_point():
@@ -1107,13 +1208,15 @@ def _upcast(t: Tensor) -> Tensor:
     return t if t.dtype in (torch.int32, torch.int64) else t.int()


-# copied from transformers.models.detr.modeling_detr.box_area
+# Copied from transformers.models.detr.modeling_detr.box_area
 def box_area(boxes: Tensor) -> Tensor:
     """
-    Args: Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
+    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
+
+    Args:
         boxes (Tensor[N, 4]): boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2)
             format with `0 <= x1 < x2` and `0 <= y1 < y2`.
+
     Returns:
         area (Tensor[N]): area for each box
     """
@@ -1121,7 +1224,7 @@ def box_area(boxes: Tensor) -> Tensor:
     return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


-# copied from transformers.models.detr.modeling_detr.box_iou
+# Copied from transformers.models.detr.modeling_detr.box_iou
 def box_iou(boxes1, boxes2):
     area1 = box_area(boxes1)
     area2 = box_area(boxes2)
@@ -1138,10 +1241,12 @@ def box_iou(boxes1, boxes2):
     return iou, union
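A hand-checked illustration of `box_iou` (corner-format boxes, with numbers chosen so the result is easy to verify; assumes `box_iou` and its helpers above are in scope):

```python
import torch

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])  # 2x2 box at the origin
boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])  # 2x2 box shifted by (1, 1)

iou, union = box_iou(boxes1, boxes2)
# intersection = 1, union = 4 + 4 - 1 = 7, so IoU = 1/7
print(iou)    # tensor([[0.1429]])
print(union)  # tensor([[7.]])
```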


-# copied from transformers.models.detr.modeling_detr.generalized_box_iou
+# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
 def generalized_box_iou(boxes1, boxes2):
     """
-    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. Returns:
+    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
+
+    Returns:
         a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
     """
     # degenerate boxes gives inf / nan results
@@ -1157,3 +1262,54 @@ def generalized_box_iou(boxes1, boxes2):
     area = wh[:, :, 0] * wh[:, :, 1]

     return iou - (area - union) / area
+
+
+# Copied from transformers.models.detr.modeling_detr._max_by_axis
+def _max_by_axis(the_list):
+    # type: (List[List[int]]) -> List[int]
+    maxes = the_list[0]
+    for sublist in the_list[1:]:
+        for index, item in enumerate(sublist):
+            maxes[index] = max(maxes[index], item)
+    return maxes
+
+
+# Copied from transformers.models.detr.modeling_detr.NestedTensor
+class NestedTensor(object):
+    def __init__(self, tensors, mask: Optional[Tensor]):
+        self.tensors = tensors
+        self.mask = mask
+
+    def to(self, device):
+        # type: (Device) -> NestedTensor # noqa
+        cast_tensor = self.tensors.to(device)
+        mask = self.mask
+        if mask is not None:
+            cast_mask = mask.to(device)
+        else:
+            cast_mask = None
+        return NestedTensor(cast_tensor, cast_mask)
+
+    def decompose(self):
+        return self.tensors, self.mask
+
+    def __repr__(self):
+        return str(self.tensors)
+
+
+# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    if tensor_list[0].ndim == 3:
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], : img.shape[2]] = False
+    else:
+        raise ValueError("Only 3-dimensional tensors are supported")
+    return NestedTensor(tensor, mask)
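To make the padding semantics concrete, a small hedged example of `nested_tensor_from_tensor_list` (sizes are arbitrary; in the returned mask, `False` marks real pixels and `True` marks padding):

```python
import torch

# Two 3-channel "images" of different spatial sizes.
images = [torch.ones(3, 4, 6), torch.ones(3, 5, 5)]
tensors, mask = nested_tensor_from_tensor_list(images).decompose()

print(tensors.shape)          # torch.Size([2, 3, 5, 6]) -- padded to the per-axis maximum
print(mask.shape)             # torch.Size([2, 5, 6])
print(mask[0, :4, :6].any())  # tensor(False) -- the first image's real area is unmasked
```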
""" - all_model_classes = ( - ( - YolosModel, - YolosForObjectDetection, - ) - if is_torch_available() - else () - ) + all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else () test_pruning = False test_resize_embeddings = False @@ -339,7 +328,7 @@ def test_for_object_detection(self): @slow def test_model_from_pretrained(self): - for model_name in VIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + for model_name in YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = YolosModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -355,13 +344,11 @@ def prepare_img(): class YolosModelIntegrationTest(unittest.TestCase): @cached_property def default_feature_extractor(self): - # TODO rename nielsr to organization - return AutoFeatureExtractor.from_pretrained("nielsr/yolos-s") if is_vision_available() else None + return AutoFeatureExtractor.from_pretrained("hustvl/yolos-small") if is_vision_available() else None @slow def test_inference_object_detection_head(self): - # TODO rename nielsr to organization - model = YolosForObjectDetection.from_pretrained("nielsr/yolos-s").to(torch_device) + model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small").to(torch_device) feature_extractor = self.default_feature_extractor image = prepare_img()