diff --git a/src/transformers/models/segformer/configuration_segformer.py b/src/transformers/models/segformer/configuration_segformer.py
index d1790634e6e68..bc97dc773e37d 100644
--- a/src/transformers/models/segformer/configuration_segformer.py
+++ b/src/transformers/models/segformer/configuration_segformer.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """ SegFormer model configuration"""
 
+import warnings
+
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 
@@ -78,9 +80,6 @@ class SegformerConfig(PretrainedConfig):
             The epsilon used by the layer normalization layers.
         decoder_hidden_size (`int`, *optional*, defaults to 256):
             The dimension of the all-MLP decode head.
-        reshape_last_stage (`bool`, *optional*, defaults to `True`):
-            Whether to reshape the features of the last stage back to `(batch_size, num_channels, height, width)`. Only
-            required for the semantic segmentation model.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
 
@@ -122,12 +121,18 @@ def __init__(
         layer_norm_eps=1e-6,
         decoder_hidden_size=256,
         is_encoder_decoder=False,
-        reshape_last_stage=True,
         semantic_loss_ignore_index=255,
         **kwargs
     ):
         super().__init__(**kwargs)
 
+        if "reshape_last_stage" in kwargs and kwargs["reshape_last_stage"] is False:
+            warnings.warn(
+                "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be removed, "
+                "as the behaviour will default to that of reshape_last_stage = True.",
+                FutureWarning,
+            )
+
         self.image_size = image_size
         self.num_channels = num_channels
         self.num_encoder_blocks = num_encoder_blocks
@@ -147,5 +152,5 @@ def __init__(
         self.drop_path_rate = drop_path_rate
         self.layer_norm_eps = layer_norm_eps
         self.decoder_hidden_size = decoder_hidden_size
-        self.reshape_last_stage = reshape_last_stage
+        self.reshape_last_stage = kwargs.get("reshape_last_stage", True)
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py
index 3197678a36c4d..d2f17b267102d 100755
--- a/src/transformers/models/segformer/modeling_segformer.py
+++ b/src/transformers/models/segformer/modeling_segformer.py
@@ -45,7 +45,7 @@
 
 # Base docstring
 _CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
-_EXPECTED_OUTPUT_SHAPE = [1, 256, 256]
+_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
 
 # Image classification docstring
 _IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"