Skip to content

Commit

Permalink
[fix] Add DeformableDetrFeatureExtractor (huggingface#19140)
Browse files Browse the repository at this point in the history
* Add DeformableDetrFeatureExtractor

* Fix post_process

* Fix name

* Add tests for feature extractor

* Fix doc tests

* Fix name

* Address comments

* Apply same fix to DETR and YOLOS as well

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
  • Loading branch information
2 people authored and oneraghavan committed Sep 26, 2022
1 parent 1907023 commit e29a559
Show file tree
Hide file tree
Showing 13 changed files with 1,380 additions and 53 deletions.
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/deformable_detr.mdx
Expand Up @@ -33,6 +33,16 @@ alt="drawing" width="600"/>

This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/fundamentalvision/Deformable-DETR).

## DeformableDetrFeatureExtractor

[[autodoc]] DeformableDetrFeatureExtractor
- __call__
- pad_and_create_pixel_mask
- post_process
- post_process_segmentation
- post_process_panoptic


## DeformableDetrConfig

[[autodoc]] DeformableDetrConfig
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -659,6 +659,7 @@
_import_structure["models.beit"].append("BeitFeatureExtractor")
_import_structure["models.clip"].append("CLIPFeatureExtractor")
_import_structure["models.convnext"].append("ConvNextFeatureExtractor")
_import_structure["models.deformable_detr"].append("DeformableDetrFeatureExtractor")
_import_structure["models.deit"].append("DeiTFeatureExtractor")
_import_structure["models.detr"].append("DetrFeatureExtractor")
_import_structure["models.conditional_detr"].append("ConditionalDetrFeatureExtractor")
Expand Down Expand Up @@ -3512,6 +3513,7 @@
from .models.clip import CLIPFeatureExtractor
from .models.conditional_detr import ConditionalDetrFeatureExtractor
from .models.convnext import ConvNextFeatureExtractor
from .models.deformable_detr import DeformableDetrFeatureExtractor
from .models.deit import DeiTFeatureExtractor
from .models.detr import DetrFeatureExtractor
from .models.donut import DonutFeatureExtractor
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/feature_extraction_auto.py
Expand Up @@ -44,7 +44,7 @@
("cvt", "ConvNextFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deformable_detr", "DetrFeatureExtractor"),
("deformable_detr", "DeformableDetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("donut", "DonutFeatureExtractor"),
Expand Down
18 changes: 17 additions & 1 deletion src/transformers/models/deformable_detr/__init__.py
Expand Up @@ -18,13 +18,21 @@

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available


_import_structure = {
"configuration_deformable_detr": ["DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeformableDetrConfig"],
}

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_deformable_detr"] = ["DeformableDetrFeatureExtractor"]

try:
if not is_timm_available():
raise OptionalDependencyNotAvailable()
Expand All @@ -42,6 +50,14 @@
if TYPE_CHECKING:
from .configuration_deformable_detr import DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DeformableDetrConfig

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_deformable_detr import DeformableDetrFeatureExtractor

try:
if not is_timm_available():
raise OptionalDependencyNotAvailable()
Expand Down
Expand Up @@ -24,7 +24,7 @@

import requests
from huggingface_hub import cached_download, hf_hub_url
from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection, DetrFeatureExtractor
from transformers import DeformableDetrConfig, DeformableDetrFeatureExtractor, DeformableDetrForObjectDetection
from transformers.utils import logging


Expand Down Expand Up @@ -116,7 +116,7 @@ def convert_deformable_detr_checkpoint(
config.label2id = {v: k for k, v in id2label.items()}

# load feature extractor
feature_extractor = DetrFeatureExtractor(format="coco_detection")
feature_extractor = DeformableDetrFeatureExtractor(format="coco_detection")

# prepare image
img = prepare_img()
Expand Down

0 comments on commit e29a559

Please sign in to comment.