[ViT, BEiT, DeiT, DPT] Improve code (huggingface#16799)
* Improve code

* Fix bugs

* Fix another bug

* Clean up DPT as well

* Update DPT model outputs

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
2 people authored and elusenji committed Jun 12, 2022
1 parent 347b318 commit 86a58b0
Showing 10 changed files with 130 additions and 226 deletions.
9 changes: 5 additions & 4 deletions src/transformers/models/deit/modeling_deit.py
@@ -542,7 +542,8 @@ def forward(
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -662,7 +663,7 @@ def forward(
             masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

         return MaskedLMOutput(
@@ -775,7 +776,7 @@ def forward(
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return ImageClassifierOutput(
@@ -882,7 +883,7 @@ def forward(
         logits = (cls_logits + distillation_logits) / 2

         if not return_dict:
-            output = (logits, cls_logits, distillation_logits) + outputs[2:]
+            output = (logits, cls_logits, distillation_logits) + outputs[1:]
             return output

         return DeiTForImageClassificationWithTeacherOutput(
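Note on the pattern above: heads such as DeiTForImageClassification build the backbone with add_pooling_layer=False, and the old code still returned `None` at index 1 of the tuple output, which is why every head sliced `outputs[2:]`. After this commit the placeholder is dropped, so hidden states and attentions shift left by one position, hence `outputs[1:]`. A minimal, self-contained sketch of the shift (plain strings stand in for the real tensors):

# Illustrative only; placeholder values, not real model outputs.
sequence_output, hidden_states, attentions = "seq", "hs", "attn"

# Before: a pooler-less model still returned a None placeholder at index 1.
old_outputs = (sequence_output, None, hidden_states, attentions)
assert old_outputs[2:] == (hidden_states, attentions)

# After: the placeholder is dropped, so everything shifts left by one.
new_outputs = (sequence_output, hidden_states, attentions)
assert new_outputs[1:] == (hidden_states, attentions)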
15 changes: 8 additions & 7 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -750,7 +750,8 @@ def forward(
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -938,7 +939,7 @@ def forward(
             return_dict=return_dict,
         )

-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]

         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -956,9 +957,9 @@

         if not return_dict:
             if output_hidden_states:
-                output = (predicted_depth,) + outputs[2:]
+                output = (predicted_depth,) + outputs[1:]
             else:
-                output = (predicted_depth,) + outputs[3:]
+                output = (predicted_depth,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return DepthEstimatorOutput(
@@ -1083,7 +1084,7 @@ def forward(
             return_dict=return_dict,
         )

-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]

         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -1120,9 +1121,9 @@

         if not return_dict:
             if output_hidden_states:
-                output = (logits,) + outputs[2:]
+                output = (logits,) + outputs[1:]
             else:
-                output = (logits,) + outputs[3:]
+                output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return SemanticSegmenterOutput(
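The DPT indices follow from the same shift. The backbone runs without a pooler and, as the unconditional `outputs[1]` read above implies, always computes hidden states because the neck needs them, so its tuple output is now (sequence_output, hidden_states, attentions). The heads forward `outputs[1:]` when the caller asked for hidden states and skip to `outputs[2:]` when the hidden states were only needed internally. A small sketch with placeholder tuples:

# Placeholder tuple standing in for the backbone output when return_dict=False.
predicted_depth = "depth"
backbone_outputs = ("seq", "hs", "attn")  # hidden states now at index 1

def head_return(output_hidden_states):
    if output_hidden_states:
        # caller requested hidden states: keep index 1 onwards
        return (predicted_depth,) + backbone_outputs[1:]
    # hidden states were only computed for the neck: drop them
    return (predicted_depth,) + backbone_outputs[2:]

assert head_return(True) == ("depth", "hs", "attn")
assert head_return(False) == ("depth", "attn")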
8 changes: 5 additions & 3 deletions src/transformers/models/vit/modeling_vit.py
@@ -585,7 +585,8 @@ def forward(
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -706,7 +707,7 @@ def forward(
             masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

         return MaskedLMOutput(
@@ -798,8 +799,9 @@ def forward(
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
+
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return ImageClassifierOutput(
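For callers that rely on tuple outputs, the visible effect is sketched below: a minimal example, assuming the standard `ViTModel(config, add_pooling_layer=...)` constructor, with randomly initialized weights:

import torch
from transformers import ViTConfig, ViTModel

config = ViTConfig()
pixel_values = torch.randn(1, 3, config.image_size, config.image_size)

# With a pooler the tuple still starts (sequence_output, pooled_output, ...).
model = ViTModel(config)
sequence_output, pooled_output = model(pixel_values, return_dict=False)[:2]

# Without one, index 1 is no longer a None placeholder after this change.
model = ViTModel(config, add_pooling_layer=False)
outputs = model(pixel_values, return_dict=False)
assert len(outputs) == 1  # just (sequence_output,) when no extras are requested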
45 changes: 16 additions & 29 deletions tests/beit/test_modeling_beit.py
@@ -41,7 +41,7 @@
         BeitForSemanticSegmentation,
         BeitModel,
     )
-    from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+    from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST


 if is_vision_available():
@@ -96,6 +96,10 @@ def __init__(
         self.out_indices = out_indices
         self.num_labels = num_labels

+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

@@ -132,22 +136,16 @@ def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
         model = BeitForMaskedImageModeling(config=config)
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))

     def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
         config.num_labels = self.type_sequence_label_size
@@ -312,16 +310,8 @@ def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        # in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        # BEiT has a different seq_length
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -332,7 +322,7 @@
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
@@ -349,7 +339,7 @@

             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
             out_len = len(outputs)

@@ -369,7 +359,7 @@
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )

     def test_hidden_states_output(self):
@@ -381,18 +371,15 @@ def check_hidden_states_output(inputs_dict, config, model_class):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states

             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
             self.assertEqual(len(hidden_states), expected_num_layers)

             # BEiT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
+            seq_length = self.model_tester.expected_seq_length

             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
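The tester now computes the sequence length once in `__init__` and every overridden test reads `expected_seq_length` instead of re-deriving it. A worked example of that arithmetic (values illustrative, not necessarily the tester's defaults):

image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 patches
expected_seq_length = num_patches + 1          # + [CLS] token -> 197

# BeitForMaskedImageModeling predicts one logit per patch and none for [CLS],
# hence the `expected_seq_length - 1` in the logits-shape assertion above.
assert expected_seq_length - 1 == num_patches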
28 changes: 11 additions & 17 deletions tests/beit/test_modeling_flax_beit.py
@@ -75,6 +75,10 @@ def __init__(
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range

+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

@@ -104,20 +108,14 @@ def create_and_check_model(self, config, pixel_values, labels):

         model = FlaxBeitModel(config=config)
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_masked_lm(self, config, pixel_values, labels):
         model = FlaxBeitForMaskedImageModeling(config=config)
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -151,13 +149,11 @@ def test_config(self):
         self.config_tester.run_common_tests()

     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_length = num_patches + 1
+        seq_length = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -209,7 +205,7 @@ def test_forward_signature(self):
             expected_arg_names = ["pixel_values"]
             self.assertListEqual(arg_names[:1], expected_arg_names)

-    # We neeed to override this test because Beit expects pixel_values instead of input_ids
+    # We need to override this test because Beit expects pixel_values instead of input_ids
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -234,12 +230,10 @@ def model_jitted(pixel_values, **kwargs):
                     self.assertEqual(jitted_output.shape, output.shape)

     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+            seq_length = self.model_tester.expected_seq_length

             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.hidden_states
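For reference, a minimal sketch of the jit pattern that the overridden `test_jit_compilation` exercises, with `pixel_values` as the sole input; `make_jitted` is a hypothetical helper, not test code:

import jax

def make_jitted(model):
    # Wrap the model call so jax.jit traces a pixel_values -> tuple function,
    # mirroring the model_jitted closure shown in the hunk header above.
    @jax.jit
    def model_jitted(pixel_values):
        return model(pixel_values=pixel_values).to_tuple()

    return model_jitted

The test then compares the jitted and eager outputs shape by shape, as the `self.assertEqual(jitted_output.shape, output.shape)` context line shows.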
