[ViT, BEiT, DeiT, DPT] Improve code #16799

Merged · 5 commits · Apr 18, 2022
9 changes: 5 additions & 4 deletions src/transformers/models/deit/modeling_deit.py
@@ -542,7 +542,8 @@ def forward(
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -662,7 +663,7 @@ def forward(
         masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

         return MaskedLMOutput(
@@ -775,7 +776,7 @@ def forward(
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return ImageClassifierOutput(
@@ -882,7 +883,7 @@ def forward(
         logits = (cls_logits + distillation_logits) / 2

         if not return_dict:
-            output = (logits, cls_logits, distillation_logits) + outputs[2:]
+            output = (logits, cls_logits, distillation_logits) + outputs[1:]
             return output

         return DeiTForImageClassificationWithTeacherOutput(
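Why the indices shift: these head models build their backbone without a pooler, and with this change the base model's non-`return_dict` tuple no longer reserves a slot for `pooled_output` in that case. A minimal sketch of the convention (the helper name below is illustrative, not part of the diff):

```python
# Sketch of the tuple convention this PR enforces (illustrative helper, not library code).
# With a pooler:    (sequence_output, pooled_output, *encoder_extras)
# Without a pooler: (sequence_output, *encoder_extras)
def base_model_tuple(sequence_output, pooled_output, encoder_outputs):
    # encoder_outputs = (last_hidden_state, hidden_states, attentions)
    head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
    return head_outputs + encoder_outputs[1:]

# A head whose backbone has no pooler now skips only the sequence_output slot:
# output = (logits,) + outputs[1:]   # previously outputs[2:], when a pooler slot existed
```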
15 changes: 8 additions & 7 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -750,7 +750,8 @@ def forward(
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -938,7 +939,7 @@ def forward(
             return_dict=return_dict,
         )

-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]

         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -956,9 +957,9 @@ def forward(

         if not return_dict:
             if output_hidden_states:
-                output = (predicted_depth,) + outputs[2:]
+                output = (predicted_depth,) + outputs[1:]
             else:
-                output = (predicted_depth,) + outputs[3:]
+                output = (predicted_depth,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return DepthEstimatorOutput(
@@ -1083,7 +1084,7 @@ def forward(
             return_dict=return_dict,
         )

-        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]

         # only keep certain features based on config.backbone_out_indices
         # note that the hidden_states also include the initial embeddings
@@ -1120,9 +1121,9 @@ def forward(

         if not return_dict:
             if output_hidden_states:
-                output = (logits,) + outputs[2:]
+                output = (logits,) + outputs[1:]
             else:
-                output = (logits,) + outputs[3:]
+                output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return SemanticSegmenterOutput(
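From the caller's side, the DPT change means the non-`return_dict` tuple of the depth head starts directly at `predicted_depth` (preceded by the loss only when labels are passed). A hedged usage sketch; the checkpoint name is an assumption for illustration, not taken from the diff:

```python
# Illustrative only: consuming DPTForDepthEstimation with return_dict=False after this PR.
import torch
from transformers import DPTForDepthEstimation

model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
pixel_values = torch.randn(1, 3, 384, 384)  # DPT-Large's default input resolution

with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True, return_dict=False)

predicted_depth = outputs[0]  # no pooled-output placeholder before it anymore
hidden_states = outputs[1]    # forwarded because output_hidden_states=True
```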
8 changes: 5 additions & 3 deletions src/transformers/models/vit/modeling_vit.py
@@ -585,7 +585,8 @@ def forward(
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]

         return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
@@ -706,7 +707,7 @@ def forward(
         masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

         if not return_dict:
-            output = (reconstructed_pixel_values,) + outputs[2:]
+            output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

         return MaskedLMOutput(
@@ -798,8 +799,9 @@ def forward(
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
+
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return ImageClassifierOutput(
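Seen from user code, the ViT classifier change looks like this; a minimal sketch, assuming the head builds its ViTModel backbone with `add_pooling_layer=False` (which is what makes the index shift correct) and using a checkpoint name for illustration only:

```python
# Illustrative: reading the non-return_dict tuple from ViTForImageClassification after this PR.
import torch
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values, output_attentions=True, return_dict=False)

logits = outputs[0]      # classification logits
attentions = outputs[1]  # now at index 1; code written against outputs[2:] must be updated
```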
45 changes: 16 additions & 29 deletions tests/beit/test_modeling_beit.py
@@ -41,7 +41,7 @@
     BeitForSemanticSegmentation,
     BeitModel,
 )
-from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST


 if is_vision_available():
@@ -96,6 +96,10 @@ def __init__(
         self.out_indices = out_indices
         self.num_labels = num_labels

+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

@@ -132,22 +136,16 @@ def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
         model = BeitForMaskedImageModeling(config=config)
         model.to(torch_device)
         model.eval()
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = to_2tuple(self.image_size)
-        patch_size = to_2tuple(self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))

     def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
         config.num_labels = self.type_sequence_label_size
@@ -312,16 +310,8 @@ def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        # in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches + 1
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+        # BEiT has a different seq_length
+        seq_len = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -332,7 +322,7 @@ def test_attention_outputs(self):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

             # check that output_attentions also work using config
@@ -349,7 +339,7 @@ def test_attention_outputs(self):

             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )
             out_len = len(outputs)

@@ -369,7 +359,7 @@ def test_attention_outputs(self):
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                [self.model_tester.num_attention_heads, seq_len, seq_len],
             )

     def test_hidden_states_output(self):
@@ -381,18 +371,15 @@ def check_hidden_states_output(inputs_dict, config, model_class):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states

             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
             self.assertEqual(len(hidden_states), expected_num_layers)

             # BEiT has a different seq_length
-            image_size = to_2tuple(self.model_tester.image_size)
-            patch_size = to_2tuple(self.model_tester.patch_size)
-            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-            seq_length = num_patches + 1
+            seq_length = self.model_tester.expected_seq_length

             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
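The factored-out `expected_seq_length` is the patch-grid arithmetic done once in the tester instead of being recomputed in every test. Worked through with illustrative numbers (not the tester's actual defaults):

```python
# Illustrative values; the tester's real image_size/patch_size are much smaller.
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # (224 // 16) ** 2 = 14 ** 2 = 196
expected_seq_length = num_patches + 1          # +1 for the [CLS] token
assert expected_seq_length == 197
# Masked-image-modeling logits cover only the patch tokens, hence the
# `expected_seq_length - 1` in create_and_check_for_masked_lm.
```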
28 changes: 11 additions & 17 deletions tests/beit/test_modeling_flax_beit.py
@@ -75,6 +75,10 @@ def __init__(
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range

+        # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.expected_seq_length = num_patches + 1
+
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

@@ -104,20 +108,14 @@ def create_and_check_model(self, config, pixel_values, labels):

         model = FlaxBeitModel(config=config)
         result = model(pixel_values)
-        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
+        )

     def create_and_check_for_masked_lm(self, config, pixel_values, labels):
         model = FlaxBeitForMaskedImageModeling(config=config)
         result = model(pixel_values)
-        # expected sequence length = num_patches
-        image_size = (self.image_size, self.image_size)
-        patch_size = (self.patch_size, self.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -151,13 +149,11 @@ def test_config(self):
         self.config_tester.run_common_tests()

     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        num_patches = (config.image_size // config.patch_size) ** 2
-        seq_length = num_patches + 1
+        seq_length = self.model_tester.expected_seq_length

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
@@ -209,7 +205,7 @@ def test_forward_signature(self):
         expected_arg_names = ["pixel_values"]
         self.assertListEqual(arg_names[:1], expected_arg_names)

-    # We neeed to override this test because Beit expects pixel_values instead of input_ids
+    # We need to override this test because Beit expects pixel_values instead of input_ids
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -234,12 +230,10 @@ def model_jitted(pixel_values, **kwargs):
                 self.assertEqual(jitted_output.shape, output.shape)

     # We need to override this test because in Beit, the seq_len equals the number of patches + 1
-    # we compute that here
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
-            num_patches = (config.image_size // config.patch_size) ** 2
-            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+            seq_length = self.model_tester.expected_seq_length

             outputs = model(**self._prepare_for_class(inputs_dict, model_class))
             hidden_states = outputs.hidden_states
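Because both testers now expose the same `expected_seq_length` attribute, a shape check reads identically across frameworks. A hedged end-to-end sketch in Flax; every config value here is chosen for illustration and is not taken from the test suite:

```python
# Hedged sketch: verifying the seq_length contract with a tiny randomly
# initialized FlaxBeitModel (no pretrained weights needed).
import jax.numpy as jnp
from transformers import BeitConfig, FlaxBeitModel

config = BeitConfig(
    image_size=30, patch_size=2, hidden_size=32,
    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37,
)
model = FlaxBeitModel(config)
pixel_values = jnp.zeros((1, 3, 30, 30))  # NCHW, as the testers build it

outputs = model(pixel_values)
expected_seq_length = (30 // 2) ** 2 + 1  # 225 patches + 1 [CLS] token = 226
assert outputs.last_hidden_state.shape == (1, expected_seq_length, config.hidden_size)
```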