Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
NielsRogge and sgugger committed May 18, 2022
1 parent fe257cf commit 8871155
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/cvt/configuration_cvt.py
Expand Up @@ -39,7 +39,7 @@ class CvtConfig(PretrainedConfig):
Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
patch_sizes (`List[int]`, *optional*, defaults to [7, 3, 3]):
patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3]`):
The kernel size of each encoder's patch embedding.
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2]):
The stride size of each encoder's patch embedding.
Expand Down
26 changes: 4 additions & 22 deletions src/transformers/models/cvt/modeling_cvt.py
Expand Up @@ -441,11 +441,7 @@ def forward(self, hidden_state, height, width):


class CvtStage(nn.Module):
def __init__(
self,
config,
stage,
):
def __init__(self, config, stage):
super().__init__()
self.config = config
self.stage = stage
Expand Down Expand Up @@ -513,12 +509,7 @@ def __init__(self, config):
for stage_idx in range(len(config.depth)):
self.stages.append(CvtStage(config, stage_idx))

def forward(
self,
pixel_values,
output_hidden_states=False,
return_dict=True,
):
def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
all_hidden_states = () if output_hidden_states else None
hidden_state = pixel_values

Expand Down Expand Up @@ -625,12 +616,7 @@ class PreTrainedModel
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
output_hidden_states=None,
return_dict=None,
):
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
Expand Down Expand Up @@ -746,8 +732,4 @@ def forward(
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
return ImageClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

0 comments on commit 8871155

Please sign in to comment.