Add OWL-ViT model for zero-shot object detection #17938

Merged
merged 87 commits into from Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
bd08fd0
add owlvit model skeleton
alaradirik Jun 16, 2022
cff1597
add class and box predictor heads
alaradirik Jun 17, 2022
3fb93b5
convert modified flax clip to pytorch
alaradirik Jun 21, 2022
6b80535
fix box and class predictors
alaradirik Jun 22, 2022
a57c8c3
add OwlViTImageTextEmbedder
alaradirik Jun 22, 2022
298acc4
convert class and box head checkpoints
alaradirik Jun 23, 2022
aa62cf3
convert image text embedder checkpoints
alaradirik Jun 23, 2022
eed0c47
add object detection head
alaradirik Jun 23, 2022
9dfae2e
fix bugs
alaradirik Jun 27, 2022
12b3554
update conversion script
alaradirik Jun 27, 2022
6e88bdc
update conversion script
alaradirik Jun 27, 2022
d342a81
fix q,v,k,out weight conversion
alaradirik Jun 27, 2022
5a15207
add owlvit object detection output
alaradirik Jun 28, 2022
6adfabd
fix bug in image embedder
alaradirik Jun 28, 2022
ef94525
fix bugs in text embedder
alaradirik Jun 28, 2022
d4315a3
fix positional embeddings
alaradirik Jun 28, 2022
e385e33
fix bug in inference mode vision pooling
alaradirik Jun 29, 2022
985025e
update docs, init tokenizer and processor files
alaradirik Jun 29, 2022
6653465
support batch processing
alaradirik Jun 30, 2022
5e6e8b4
add OwlViTProcessor
alaradirik Jun 30, 2022
2e63dde
remove merge conflicts
alaradirik Jul 1, 2022
79083c5
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 1, 2022
35f9f31
readd owlvit imports
alaradirik Jul 1, 2022
78b7837
fix bug in OwlViTProcessor imports
alaradirik Jul 1, 2022
d919422
fix bugs in processor
alaradirik Jul 1, 2022
4635688
update docs
alaradirik Jul 1, 2022
8a1c825
fix bugs in processor
alaradirik Jul 1, 2022
363f4d5
update owlvit docs
alaradirik Jul 1, 2022
161cb2a
add OwlViTFeatureExtractor
alaradirik Jul 1, 2022
58aa6ce
style changes, add postprocess method to feature extractor
alaradirik Jul 4, 2022
37e3281
add feature extractor and processor tests
alaradirik Jul 4, 2022
261ed39
add object detection tests
alaradirik Jul 4, 2022
cf0591c
update conversion script
alaradirik Jul 5, 2022
02f3a00
update config paths
alaradirik Jul 5, 2022
ab0be98
update config paths
alaradirik Jul 5, 2022
2b215f5
fix configuration paths and bugs
alaradirik Jul 5, 2022
f97d3de
fix bugs in OwlViT tests
alaradirik Jul 5, 2022
1949b63
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 5, 2022
8680f13
add import checks to processor
alaradirik Jul 5, 2022
e6f51de
fix docs and minor issues
alaradirik Jul 6, 2022
e15988d
fix docs and minor issues
alaradirik Jul 6, 2022
b73a66d
fix bugs and issues
alaradirik Jul 7, 2022
68dd41d
fix bugs and issues
alaradirik Jul 7, 2022
11d5928
fix bugs and issues
alaradirik Jul 7, 2022
cef935d
fix bugs and issues
alaradirik Jul 8, 2022
34069b0
update docs and examples
alaradirik Jul 8, 2022
c4aa766
fix bugs and issues
alaradirik Jul 8, 2022
40a6504
update conversion script, fix positional embeddings
alaradirik Jul 8, 2022
9ce1942
process 2D input ids, update tests
alaradirik Jul 11, 2022
b330dfa
fix style and quality issues
alaradirik Jul 11, 2022
051aea6
update docs
alaradirik Jul 11, 2022
bf903f9
update docs and imports
alaradirik Jul 11, 2022
3592af5
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 11, 2022
60749fe
update OWL-ViT index.md
alaradirik Jul 11, 2022
ee007d6
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 12, 2022
6f1aa2d
fix bug in OwlViT feature ext tests
alaradirik Jul 12, 2022
6af7248
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 12, 2022
ba03dbf
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 12, 2022
865510c
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 13, 2022
df9313d
fix code examples, return_dict by default
alaradirik Jul 13, 2022
57d1b68
return_dict by default
alaradirik Jul 13, 2022
253af8b
minor fixes, add tests to processor
alaradirik Jul 13, 2022
3e180da
small fixes
alaradirik Jul 13, 2022
43c04af
add output_attentions arg to main model
alaradirik Jul 13, 2022
efc1ad3
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 13, 2022
8ceea4e
fix bugs
alaradirik Jul 13, 2022
4d416fe
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 14, 2022
4099199
remove output_hidden_states arg from main model
alaradirik Jul 14, 2022
e73b129
update self.config variables
alaradirik Jul 14, 2022
0f3d56f
add option to return last_hidden_states
alaradirik Jul 14, 2022
47c55ea
fix bug in config variables
alaradirik Jul 14, 2022
db70aee
fix copied from statements
alaradirik Jul 14, 2022
ea1452b
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 20, 2022
456bbb3
fix small issues and bugs
alaradirik Jul 20, 2022
c6cd321
fix bugs
alaradirik Jul 20, 2022
57c2cb8
fix bugs, support greyscale images
alaradirik Jul 21, 2022
7ba2c41
run fixup
alaradirik Jul 21, 2022
8c560cb
update repo name
alaradirik Jul 21, 2022
ef2b4f5
merge OwlViTImageTextEmbedder with obj detection head
alaradirik Jul 21, 2022
dfbc6b5
fix merge conflict
alaradirik Jul 21, 2022
27a5ce5
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 21, 2022
405685a
fix merge conflict
alaradirik Jul 21, 2022
a66a879
make fixup
alaradirik Jul 21, 2022
32525bd
fix bugs
alaradirik Jul 22, 2022
1f931eb
fix bugs
alaradirik Jul 22, 2022
1867147
Merge branch 'huggingface:main' into owlvit
alaradirik Jul 22, 2022
75e5ccf
add additional processor test
alaradirik Jul 22, 2022
@@ -182,8 +182,8 @@ def copy_class_merge_token(hf_model, flax_params):

     weight = torch.from_numpy(flax_class_token_params["scale"])
     bias = torch.from_numpy(flax_class_token_params["bias"])
-    hf_model.embedder.layer_norm.weight = nn.Parameter(weight)
-    hf_model.embedder.layer_norm.bias = nn.Parameter(bias)
+    hf_model.layer_norm.weight = nn.Parameter(weight)
+    hf_model.layer_norm.bias = nn.Parameter(bias)


 def copy_class_box_heads(hf_model, flax_params):
@@ -344,7 +344,7 @@ def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dum
     hf_backbone.logit_scale = pt_backbone.logit_scale
     copy_flax_attn_params(hf_backbone, attn_params)

-    hf_model.embedder.clip = hf_backbone
+    hf_model.owlvit = hf_backbone
     copy_class_merge_token(hf_model, flax_params)
     copy_class_box_heads(hf_model, flax_params)
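The conversion above copies Flax LayerNorm parameters, stored under the key "scale", into the PyTorch module's `weight` attribute (the math is the same affine layer norm, only the parameter name differs). A minimal numpy sketch of this mapping, using a hypothetical parameter dict rather than the real checkpoint:

```python
import numpy as np

def layer_norm(x, weight, bias, eps=1e-5):
    """Affine layer norm over the last axis, as in torch.nn.LayerNorm."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

# Hypothetical Flax-style parameter dict: the learnable gain is called "scale".
flax_class_token_params = {
    "scale": np.ones(4, dtype=np.float32),
    "bias": np.zeros(4, dtype=np.float32),
}

# The rename performed by copy_class_merge_token: "scale" -> weight, "bias" -> bias.
weight = flax_class_token_params["scale"]
bias = flax_class_token_params["bias"]

x = np.arange(8, dtype=np.float32).reshape(2, 4)
out = layer_norm(x, weight, bias)
assert out.shape == (2, 4)
# With unit gain and zero bias, each row is normalized to zero mean.
assert np.allclose(out.mean(axis=-1), 0.0, atol=1e-4)
```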
106 changes: 36 additions & 70 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -1150,56 +1150,18 @@ def forward(
         return (pred_logits, image_class_embeds)


-class OwlViTImageTextEmbedder(nn.Module):
-    def __init__(self, config: OwlViTConfig):
-        super().__init__()
-
-        self.clip = OwlViTModel(config)
-        self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size)
-
-    def forward(
-        self,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-
-        image_embeds, text_embeds = None, None
-
-        # Encode text
-        if input_ids is not None:
-            text_embeds = self.clip.get_text_features(
-                input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
-            )
-
-        # Encode image
-        if pixel_values is not None:
-            image_embeds = self.clip.get_image_features(
-                pixel_values, return_projected=False, output_attentions=output_attentions
-            )
-
-            # Resize class token
-            new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-            class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
-
-            # Merge image embedding with class tokens
-            image_embeds = image_embeds[:, 1:, :] * class_token_out
-            image_embeds = self.layer_norm(image_embeds)
-
-        return (image_embeds, text_embeds)
-
-
 class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     config_class = OwlViTConfig
     main_input_name = "pixel_values"

     def __init__(self, config: OwlViTConfig):
         super().__init__(config)

-        self.embedder = OwlViTImageTextEmbedder(config)
+        self.owlvit = OwlViTModel(config)
         self.class_head = OwlViTClassPredictionHead(config)
         self.box_head = OwlViTBoxPredictionHead(config)

+        self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size)
         self.sigmoid = nn.Sigmoid()
def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
@@ -1246,9 +1208,9 @@ def box_predictor(
         """
         Args:
             image_feats:
-                Features extracted from the image, returned by the`embedder` function.
+                Features extracted from the image, returned by the `image_text_embedder` method.
             feature_map:
-                A spatial re-arrangement of image_features, also returned by the `embedder` function.
+                A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
         Returns:
             pred_boxes:
                 List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
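As the docstring above notes, the box predictor emits boxes in cxcywh format (center x, center y, width, height), normalized to [0, 1]. Downstream consumers often want corner coordinates instead; a minimal sketch of that conversion (the function name is illustrative, not part of the PR's API):

```python
import numpy as np

def cxcywh_to_xyxy(boxes):
    """Convert (center_x, center_y, width, height) boxes to
    (x_min, y_min, x_max, y_max). Values stay normalized to [0, 1]."""
    cx, cy, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)

# A box centered at (0.5, 0.5) spanning half the image in each dimension.
pred_boxes = np.array([[0.5, 0.5, 0.5, 0.5]])
assert np.allclose(cxcywh_to_xyxy(pred_boxes), [[0.25, 0.25, 0.75, 0.75]])
```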
@@ -1270,7 +1232,7 @@ def class_predictor(
         """
         Args:
             image_feats:
-                Features extracted from the image embedder.
+                Features extracted from the `image_text_embedder`.
             query_embeds:
                 Text query embeddings.
             query_mask:
@@ -1280,13 +1242,30 @@

         return (pred_logits, image_class_embeds)

-    def image_embedder(
+    def image_text_embedder(
         self,
         pixel_values: torch.FloatTensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = None,
     ) -> torch.FloatTensor:
-        # Returns a 2D map of image features.
-        (image_embeds, _) = self.embedder(pixel_values=pixel_values, output_attentions=output_attentions)
+        # Encode text
+        text_embeds = self.owlvit.get_text_features(
+            input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
+        )
+
+        # Encode image
+        image_embeds = self.owlvit.get_image_features(
+            pixel_values, return_projected=False, output_attentions=output_attentions
+        )
+
+        # Resize class token
+        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+
+        # Merge image embedding with class tokens
+        image_embeds = image_embeds[:, 1:, :] * class_token_out
+        image_embeds = self.layer_norm(image_embeds)

         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
@@ -1297,21 +1276,7 @@ def image_embedder(
         )
         image_embeds = image_embeds.reshape(new_size)

-        return image_embeds
-
-    def text_embedder(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        output_attentions: Optional[bool] = None,
-    ) -> torch.FloatTensor:
-
-        # Returns text embeddings
-        (_, text_feats) = self.embedder(
-            input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
-        )
-
-        return text_feats
+        return (image_embeds, text_embeds)
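The class-token merge in the new `image_text_embedder` broadcasts the CLS token over the patch tokens, multiplies elementwise, and layer-normalizes. A numpy sketch of the shape arithmetic, using toy dimensions rather than the real hidden size (the layer norm step is omitted to keep the focus on shapes):

```python
import numpy as np

batch_size, num_tokens, hidden_size = 2, 5, 8  # toy dims: 1 CLS token + 4 patch tokens
image_embeds = np.random.default_rng(0).normal(size=(batch_size, num_tokens, hidden_size))

# Resize class token: drop one token from the shape -> (2, 4, 8)
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
class_token_out = np.broadcast_to(image_embeds[:, :1, :], new_size)

# Merge image embedding with class tokens (elementwise product with patch tokens)
merged = image_embeds[:, 1:, :] * class_token_out
assert merged.shape == (batch_size, num_tokens - 1, hidden_size)

# Each patch token is scaled by the (broadcast) CLS token.
assert np.allclose(merged[0, 2], image_embeds[0, 3] * image_embeds[0, 0])
```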

@add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
@@ -1362,7 +1327,7 @@ def forward(
         vision_model_last_hidden_states = None

         if output_hidden_states:
-            outputs = self.embedder.clip(
+            outputs = self.owlvit(
                 pixel_values=pixel_values,
                 input_ids=input_ids,
                 attention_mask=attention_mask,

[Review comment from a Contributor, on the `if output_hidden_states:` line] It seems that if a user specifies output_hidden_states, the input_ids and pixel_values are forwarded twice through the model?
@@ -1373,16 +1338,17 @@ def forward(
             text_model_last_hidden_states = outputs[-2][0]
             vision_model_last_hidden_states = outputs[-1][0]

-        # Embed images
-        feature_map = self.image_embedder(pixel_values=pixel_values, output_attentions=output_attentions)
+        # Embed images and text queries
+        feature_map, query_embeds = self.image_text_embedder(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions
+        )

         batch_size, height, width, hidden_dim = feature_map.shape
         image_feats = torch.reshape(feature_map, (batch_size, height * width, hidden_dim))

-        # Embed text queries
-        query_embeds = self.text_embedder(
-            input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
-        )
-
         # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
         max_text_queries = input_ids.shape[0] // batch_size
         query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
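The forward pass above flattens the spatial feature map and then regroups the flat per-query text embeddings by image, inferring the query count from the row count of `input_ids`. A numpy sketch of that bookkeeping, assuming toy dimensions (2 images, 3 text queries each, sequence length 16):

```python
import numpy as np

batch_size, height, width, hidden_dim = 2, 4, 4, 8
feature_map = np.zeros((batch_size, height, width, hidden_dim))

# Flatten the spatial grid: [batch, h, w, dim] -> [batch, h*w, dim]
image_feats = feature_map.reshape(batch_size, height * width, hidden_dim)
assert image_feats.shape == (2, 16, 8)

# The processor tiles text queries across the batch, so input_ids has
# batch_size * max_text_queries rows and the query embeddings come back flat.
input_ids = np.zeros((6, 16), dtype=np.int64)  # 2 images x 3 queries
query_embeds = np.zeros((6, hidden_dim))

# Recover the per-image grouping: [batch * queries, dim] -> [batch, queries, dim]
max_text_queries = input_ids.shape[0] // batch_size
query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
assert max_text_queries == 3
assert query_embeds.shape == (2, 3, 8)
```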