From 7b42d16f7bd837c6a85056a7ec86333340b19070 Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Sat, 5 Feb 2022 18:49:40 +0530 Subject: [PATCH 01/17] Added all files, PoolFormerFeatureExtractor still failing tests --- docs/source/model_doc/poolformer.mdx | 47 ++ src/transformers/__init__.py | 19 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 8 + .../models/poolformer/__init__.py | 66 +++ .../poolformer/configuration_poolformer.py | 132 +++++ .../convert_poolformer_timm_to_pytorch.py | 225 ++++++++ .../feature_extraction_poolformer.py | 166 ++++++ .../models/poolformer/modeling_poolformer.py | 527 ++++++++++++++++++ tests/test_feature_extraction_poolformer.py | 234 ++++++++ tests/test_modeling_poolformer.py | 357 ++++++++++++ utils/check_repo.py | 6 + 13 files changed, 1791 insertions(+) create mode 100644 docs/source/model_doc/poolformer.mdx create mode 100644 src/transformers/models/poolformer/__init__.py create mode 100644 src/transformers/models/poolformer/configuration_poolformer.py create mode 100644 src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py create mode 100644 src/transformers/models/poolformer/feature_extraction_poolformer.py create mode 100755 src/transformers/models/poolformer/modeling_poolformer.py create mode 100644 tests/test_feature_extraction_poolformer.py create mode 100644 tests/test_modeling_poolformer.py diff --git a/docs/source/model_doc/poolformer.mdx b/docs/source/model_doc/poolformer.mdx new file mode 100644 index 0000000000000..19a8a6740748d --- /dev/null +++ b/docs/source/model_doc/poolformer.mdx @@ -0,0 +1,47 @@ + + +# PoolFormer + +## Overview + +The PoolFormer model was proposed in []() by . + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](). The original code can be found [here](). 
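+
+A minimal image-classification sketch (an illustrative example, assuming the `sail/poolformer_s12` checkpoint referenced elsewhere in this patch is available on the Hub):
+
+```python
+from transformers import PoolFormerFeatureExtractor, PoolFormerForImageClassification
+from PIL import Image
+import requests
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+# "sail/poolformer_s12" is the checkpoint name used in this patch's archive maps
+feature_extractor = PoolFormerFeatureExtractor.from_pretrained("sail/poolformer_s12")
+model = PoolFormerForImageClassification.from_pretrained("sail/poolformer_s12")
+
+inputs = feature_extractor(images=image, return_tensors="pt")
+outputs = model(**inputs)
+logits = outputs.logits
+
+# the model predicts one of the 1000 ImageNet classes
+predicted_class_idx = logits.argmax(-1).item()
+print("Predicted class:", model.config.id2label[predicted_class_idx])
+```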
+
+## PoolFormerConfig
+
+[[autodoc]] PoolFormerConfig
+
+
+## PoolFormerModel
+
+[[autodoc]] PoolFormerModel
+    - forward
+
+## PoolFormerFeatureExtractor
+
+[[autodoc]] PoolFormerFeatureExtractor
+    - __call__
+
+## PoolFormerForImageClassification
+
+[[autodoc]] PoolFormerForImageClassification
+    - forward
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ad05486104ee7..51f4f3e3185f2 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -150,6 +150,7 @@
         "load_tf2_weights_in_pytorch_model",
     ],
     # Models
+    "models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"],
     "models": [],
     "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
     "models.auto": [
@@ -512,6 +513,7 @@
 # Vision-specific objects
 if is_vision_available():
     _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
+    _import_structure["models.poolformer"].append("PoolFormerFeatureExtractor")
     _import_structure["models.beit"].append("BeitFeatureExtractor")
     _import_structure["models.clip"].append("CLIPFeatureExtractor")
     _import_structure["models.clip"].append("CLIPProcessor")
@@ -646,6 +648,15 @@
     _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]

     # PyTorch models structure
+
+    _import_structure["models.poolformer"].extend(
+        [
+            "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "PoolFormerModel",
+            "PoolFormerPreTrainedModel",
+            "PoolFormerForImageClassification",
+        ]
+    )
     _import_structure["models.albert"].extend(
         [
             "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2385,6 +2396,7 @@
         load_tf2_weights_in_pytorch_model,
     )
     from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
+    from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
     from .models.auto import (
         ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
         CONFIG_MAPPING,
@@ -2691,6 +2703,7 @@

     if is_vision_available():
         from .image_utils import ImageFeatureExtractionMixin
+        from .models.poolformer import PoolFormerFeatureExtractor
         from .models.beit import BeitFeatureExtractor
         from .models.clip import CLIPFeatureExtractor, CLIPProcessor
         from .models.convnext import ConvNextFeatureExtractor
@@ -2750,6 +2763,13 @@
         from .utils.dummy_pytorch_quantization_and_torch_objects import *

     if is_torch_available():
+
+        from .models.poolformer import (
+            POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            PoolFormerForImageClassification,
+            PoolFormerModel,
+            PoolFormerPreTrainedModel,
+        )
         # Benchmarks
         from .benchmark.benchmark import PyTorchBenchmark
         from .benchmark.benchmark_args import PyTorchBenchmarkArguments
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index d6cf69197ea96..e080784b870e5 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -17,6 +17,7 @@
 # limitations under the License.

 from .
import ( + poolformer, albert, auto, bart, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9223123d3432d..1115ffb7a3663 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -30,6 +30,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here + ("poolformer", "PoolFormerConfig"), ("convnext", "ConvNextConfig"), ("yoso", "YosoConfig"), ("swin", "SwinConfig"), @@ -125,6 +126,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here + ("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -207,6 +209,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("poolformer", "PoolFormer"), ("convnext", "ConvNext"), ("yoso", "YOSO"), ("swin", "Swin"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fe472cc654eac..31cfaed96ac7d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -28,6 +28,7 @@ MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("poolformer", "PoolFormerModel"), ("convnext", "ConvNextModel"), ("yoso", "YosoModel"), ("swin", "SwinModel"), @@ -160,6 +161,8 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping + + ("poolformer", "PoolFormerForConditionalGeneration"), ("yoso", "YosoForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"), ("qdqbert", "QDQBertForMaskedLM"), @@ -214,6 +217,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("poolformer", "PoolFormerForCausalLM"), ("xglm", "XGLMForCausalLM"), ("qdqbert", "QDQBertLMHeadModel"), ("trocr", "TrOCRForCausalLM"), @@ -350,6 +354,8 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + + ("poolformer", "PoolFormerForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"), ("led", "LEDForConditionalGeneration"), @@ -378,6 +384,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping + ("poolformer", "PoolFormerForSequenceClassification"), ("yoso", "YosoForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), @@ -428,6 +435,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping + ("poolformer", "PoolFormerForQuestionAnswering"), ("yoso", "YosoForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), diff --git a/src/transformers/models/poolformer/__init__.py b/src/transformers/models/poolformer/__init__.py new file mode 100644 index 0000000000000..18a4c60e38718 --- /dev/null +++ b/src/transformers/models/poolformer/__init__.py @@ -0,0 +1,66 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...file_utils import _LazyModule, is_vision_available +from ...file_utils import is_torch_available + + +_import_structure = { + "configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_poolformer"] = [ + "PoolFormerFeatureExtractor" + ] + +if is_torch_available(): + _import_structure["modeling_poolformer"] = [ + "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "PoolFormerForImageClassification", + "PoolFormerLayer", + "PoolFormerModel", + "PoolFormerPreTrainedModel", + "load_tf_weights_in_poolformer", + ] + + + +if TYPE_CHECKING: + from .configuration_poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig + + if is_vision_available(): + from .feature_extraction_poolformer import PoolFormerFeatureExtractor + + if is_torch_available(): + from .modeling_poolformer import ( + POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + PoolFormerLayer, + PoolFormerModel, + PoolFormerPreTrainedModel, + PoolFormerForImageClassification, + load_tf_weights_in_poolformer, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) \ No newline at end of file diff --git a/src/transformers/models/poolformer/configuration_poolformer.py b/src/transformers/models/poolformer/configuration_poolformer.py new file mode 100644 index 0000000000000..f8b5412c60dde --- /dev/null +++ b/src/transformers/models/poolformer/configuration_poolformer.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright Sea AI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PoolFormer model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "sail/poolformer_s12": "https://huggingface.co/sail/poolformer_s12/resolve/main/config.json", + # See all PoolFormer models at https://huggingface.co/models?filter=poolformer +} + + +class PoolFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.PoolFormerModel`. + It is used to instantiate an PoolFormer model according to the specified arguments, defining the model + architecture. 
Instantiating a configuration with the defaults will yield a similar configuration to that of + the PoolFormer `sail/poolformer_s12 `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + + Args: + num_channels (:obj:`int`, optional, defaults to 3): + The number of channels in the input image. + patch_size (:obj:`int`, optional, defaults to 16): + The size of the input patch. + stride (:obj:`int`, optional, defaults to 16): + The stride of the input patch. + pool_size (:obj:`int`, optional, defaults to 3): + The size of the pooling window. + mlp_ratio (:obj:`float`, optional, defaults to 4.0): + The ratio of the number of channels in the output of the MLP to the number of channels in the input. + depths (:obj:`list`, optional, defaults to [2, 2, 6, 2]): + The depth of each encoder block. + hidden_sizes (:obj:`list`, optional, defaults to [64, 128, 320, 512]): + The hidden sizes of each encoder block. + patch_sizes (:obj:`list`, optional, defaults to [7, 3, 3, 3]): + The size of the input patch for each encoder block. + strides (:obj:`list`, optional, defaults to [4, 2, 2, 2]): + The stride of the input patch for each encoder block. + padding (:obj:`list`, optional, defaults to [2, 1, 1, 1]): + The padding of the input patch for each encoder block. + num_encoder_blocks (:obj:`int`, optional, defaults to 4): + The number of encoder blocks. + drop_path_rate (:obj:`float`, optional, defaults to 0.0): + The dropout rate for the dropout layers. + hidden_act (:obj:`str`, optional, defaults to "gelu"): + The activation function for the hidden layers. + use_layer_scale (:obj:`bool`, optional, defaults to True): + Whether to use layer scale. + layer_scale_init_value (:obj:`float`, optional, defaults to 1e-5): + The initial value for the layer scale. + initializer_range (:obj:`float`, optional, defaults to 0.02): + The initializer range for the weights. + use_cache (:obj:`bool`, optional, defaults to True): + Whether to use cache. 
+ Example:: + + >>> from transformers import PoolFormerModel, PoolFormerConfig + + >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration + >>> configuration = PoolFormerConfig() + + >>> # Initializing a model from the sail/poolformer_s12 style configuration + >>> model = PoolFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "poolformer" + + + def __init__( + self, + num_channels=3, + patch_size=16, + stride=16, + pool_size=3, + mlp_ratio=4.0, + depths=[2, 2, 6, 2], + hidden_sizes=[64, 128, 320, 512], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + padding=[2, 1, 1, 1], + num_encoder_blocks=4, + drop_path_rate=0.0, + hidden_act="gelu", + use_layer_scale=True, + layer_scale_init_value=1e-5, + initializer_range=0.02, + use_cache=True, + **kwargs + ): + self.num_channels = num_channels + self.patch_size = patch_size + self.stride = stride + self.padding = padding + self.pool_size = pool_size + self.hidden_sizes = hidden_sizes + self.mlp_ratio = mlp_ratio + self.depths = depths + self.patch_sizes = patch_sizes + self.strides = strides + self.num_encoder_blocks = num_encoder_blocks + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_layer_scale = use_layer_scale + self.layer_scale_init_value = layer_scale_init_value + self.initializer_range = initializer_range + self.use_cache = use_cache + super().__init__( + **kwargs + ) \ No newline at end of file diff --git a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py new file mode 100644 index 0000000000000..e29c20d86efb6 --- /dev/null +++ b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
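+#
+# Example invocation (illustrative only; the checkpoint path below is a placeholder
+# for a locally downloaded original/timm PoolFormer checkpoint):
+#
+#   python src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py \
+#       --model_name poolformer_s12 \
+#       --checkpoint_path /path/to/poolformer_s12.pth.tar \
+#       --pytorch_dump_folder_path ./poolformer_s12_converted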
+"""Convert PoolFormer checkpoints.""" + +import math +import argparse +import json +from collections import OrderedDict +from pathlib import Path +from stat import ST_SIZE + +import torch +from PIL import Image + +import requests +from huggingface_hub import cached_download, hf_hub_url +from transformers import ( + PoolFormerConfig, + PoolFormerForImageClassification, + PoolFormerFeatureExtractor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +def replace_key_with_offset(key, offset, original_name, new_name): + """ + Replaces the key by subtracting the offset from the original layer number + """ + to_find = original_name.split(".")[0] + key_list = key.split(".") + orig_block_num = int(key_list[key_list.index(to_find)-2]) + layer_num = int(key_list[key_list.index(to_find)-1]) + new_block_num = orig_block_num - offset + + key = key.replace( + f"{orig_block_num}.{layer_num}.{original_name}", + f"block.{new_block_num}.{layer_num}.{new_name}" + ) + return key + +def rename_keys(state_dict): + new_state_dict = OrderedDict() + total_embed_found, patch_emb_offset = 0, 0 + for key, value in state_dict.items(): + if key.startswith("network"): + key = key.replace("network", "poolformer.encoder") + if "proj" in key: + # Works for the first embedding as well as the internal embedding layers + if key.endswith("bias") and "patch_embed" not in key: + patch_emb_offset += 1 + to_replace = key[:key.find("proj")] + key = key.replace(to_replace, f"patch_embeddings.{total_embed_found}.") + key = key.replace("proj", "projection") + if key.endswith("bias"): + total_embed_found += 1 + if "patch_embeddings" in key: + key = "poolformer.encoder." + key + if "mlp.fc1" in key: + key = replace_key_with_offset(key, patch_emb_offset, "mlp.fc1", "output.conv1") + if "mlp.fc2" in key: + key = replace_key_with_offset(key, patch_emb_offset, "mlp.fc2", "output.conv2") + if "norm1" in key: + key = replace_key_with_offset(key, patch_emb_offset, "norm1", "before_norm") + if "norm2" in key: + key = replace_key_with_offset(key, patch_emb_offset, "norm2", "after_norm") + if "layer_scale_1" in key: + key = replace_key_with_offset(key, patch_emb_offset, "layer_scale_1", "layer_scale_1") + if "layer_scale_2" in key: + key = replace_key_with_offset(key, patch_emb_offset, "layer_scale_2", "layer_scale_2") + if "head" in key: + key = key.replace("head", "classifier") + new_state_dict[key] = value + return new_state_dict + +# We will verify our results on a COCO image +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + return image + +@torch.no_grad() +def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our PoolFormer structure. 
+ """ + + # load default PoolFormer configuration + config = PoolFormerConfig() + + # set attributes based on model_name + repo_id = "datasets/huggingface/label-files" + size = model_name[-3:] + config.num_labels = 1000 + filename = "imagenet-1k-id2label.json" + expected_shape = (1, 1000) + + # set config attributes + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + if size == "s12": + config.depths = [2, 2, 6, 2] + config.hidden_sizes = [64, 128, 320, 512] + config.mlp_ratio = 4.0 + crop_pcent = 0.9 + elif size == "s24": + config.depths = [4, 4, 12, 4] + config.hidden_sizes = [64, 128, 320, 512] + config.mlp_ratio = 4.0 + crop_pcent = 0.9 + elif size == "s36": + config.depths = [6, 6, 18, 6] + config.hidden_sizes = [64, 128, 320, 512] + config.mlp_ratio = 4.0 + config.layer_scale_init_value = 1e-6 + crop_pcent = 0.9 + elif size == "m36": + config.depths = [6, 6, 18, 6] + config.hidden_sizes = [96, 192, 384, 768] + config.mlp_ratio = 4.0 + config.layer_scale_init_value = 1e-6 + crop_pcent = 0.95 + elif size == "m48": + config.depths = [8, 8, 24, 8] + config.hidden_sizes = [96, 192, 384, 768] + config.mlp_ratio = 4.0 + config.layer_scale_init_value = 1e-6 + crop_pcent = 0.95 + else: + raise ValueError(f"Size {size} not supported") + + print(size) + # load feature extractor (only resize + normalize) + img_size = (224, 224) + feature_extractor = DeiTFeatureExtractor( + size=img_size, + resample=Image.BILINEAR, + crop_size=224, + ) + + # Prepare image + image = prepare_img() + pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values + + logger.info(f"Converting model {model_name}...") + + # load original state dict + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) + + # rename keys + state_dict = rename_keys(state_dict) + + # create HuggingFace model and load state dict + model = PoolFormerForImageClassification(config) + model.load_state_dict(state_dict) + model.eval() + + # Define feature extractor + feature_extractor = PoolFormerFeatureExtractor(size=size) + encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + + # forward pass + outputs = model(pixel_values) + logits = outputs.logits + + # define expected logit slices for different models + if size == "s12": + expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869]) + elif size == "s24": + expected_slice = torch.tensor([ 0.4402, -0.1374, -0.8045]) + elif size == "s36": + expected_slice = torch.tensor([-0.6080, -0.5133, -0.5898]) + elif size == "m36": + expected_slice = torch.tensor([ 0.3952, 0.2263, -1.2668]) + elif size == "m48": + expected_slice = torch.tensor([ 0.1167, -0.0656, -0.3423]) + else: + raise ValueError(f"Size {size} not supported") + + # verify logits + assert logits.shape == expected_shape + assert torch.allclose(logits[0, :3], expected_slice, atol=1e-2) + + # finally, save model and feature extractor + logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name", + default="poolformer_s12", + type=str, + help="Name 
of the model you'd like to convert.", + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + args = parser.parse_args() + convert_poolformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path) \ No newline at end of file diff --git a/src/transformers/models/poolformer/feature_extraction_poolformer.py b/src/transformers/models/poolformer/feature_extraction_poolformer.py new file mode 100644 index 0000000000000..392d633d06c32 --- /dev/null +++ b/src/transformers/models/poolformer/feature_extraction_poolformer.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for PoolFormer.""" + +from email.policy import default +import math +from typing import Optional, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...file_utils import TensorType +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ImageFeatureExtractionMixin, + ImageInput, + is_torch_tensor, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PoolFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a PoolFormer feature extractor. + + This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + Args: + do_resize_and_center_crop (`bool`, *optional*, defaults to `True`): + Whether to resize and center crop the input to a certain `size`. + size (`int` or `Tuple(int)`, *optional*, defaults to 224): + Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an + integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize_and_center_crop` is + set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `do_resize` is set to `True`. + crop_pct (`float`, *optional*, defaults to `0.9`): + The percentage of the image to crop from the center. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with `image_mean` and `image_std`. + image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. 
+ image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize_and_center_crop=True, + size=224, + resample=Image.BILINEAR, + crop_pct=0.9, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs + ): + super().__init__(**kwargs) + self.do_resize_and_center_crop = do_resize_and_center_crop + self.size = size + self.resample = resample + self.crop_pct = crop_pct + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + def __call__( + self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, + width). + """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." 
+ ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + center cropping + normalization) + if self.do_resize_and_center_crop and self.size is not None: + if isinstance(self.size, (tuple, list)): + assert len(self.size) == 2 + if self.size[-1] == self.size[-2]: + # fall-back to older behaviour so Resize scales to shortest edge if target is square + scale_size = int(math.floor(self.size[0] / self.crop_pct)) + else: + scale_size = tuple([int(x / self.crop_pct) for x in self.size]) + else: + scale_size = int(math.floor(self.size / self.crop_pct)) + + images = [self.resize(image=image, size=scale_size, resample=self.resample, default_to_square=False) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs \ No newline at end of file diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py new file mode 100755 index 0000000000000..049f92a042ceb --- /dev/null +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -0,0 +1,527 @@ +# coding=utf-8 +# Copyright 2022 Sea AI Lab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch PoolFormer model. """ + + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.functional import block_diag, norm +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_poolformer import PoolFormerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "seaailabs/poolformer_s12" +_CONFIG_FOR_DOC = "PoolFormerConfig" + +POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "sail/poolformer_s12", + # See all PoolFormer models at https://huggingface.co/models?filter=poolformer +] + +# Copied from transformers.models.vit.modeling_vit.to_2tuple +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + +@dataclass +class PoolFormerModelOutput(ModelOutput): + """ + Class for PoolFormer model's outputs, with hidden states. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + """ + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class PoolFormerClassifierOutput(ModelOutput): + """ + Output class for PoolFormer Classifier's Output + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PoolFormerEmbeddings(nn.Module): + """ + Construct Patch Embeddings + """ + def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + + self.projection = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=patch_size, + stride=stride, + padding=padding + ) + self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity() + + def forward(self, pixel_values): + x = self.projection(pixel_values) + x = self.norm(x) + return x + + +class PoolFormerGroupNorm(nn.GroupNorm): + """ + Group Normalization with 1 group. 
+ Input: tensor in shape [B, C, H, W] + """ + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + +class PoolFormerPooling(nn.Module): + def __init__(self, pool_size): + super().__init__() + self.pool = nn.AvgPool2d( + pool_size, + stride=1, + padding=pool_size//2, + count_include_pad=False + ) + + def forward(self, hidden_states): + return self.pool(hidden_states) - hidden_states + + +class PoolFormerOutput(nn.Module): + def __init__(self, config, dropout_prob, hidden_size, intermediate_size): + super().__init__() + self.conv1 = nn.Conv2d(hidden_size, intermediate_size, 1) + self.conv2 = nn.Conv2d(intermediate_size, hidden_size, 1) + self.drop = DropPath(dropout_prob) + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.drop(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.drop(hidden_states) + + return hidden_states + +class PoolFormerLayer(nn.Module): + """This corresponds to the 'PoolFormerBlock' class in the original implementation.""" + + def __init__(self, config, num_channels, pool_size, hidden_size, intermediate_size, drop_path): + super().__init__() + self.pooling = PoolFormerPooling(pool_size) + self.output = PoolFormerOutput(config, drop_path, hidden_size, intermediate_size) + self.before_norm = PoolFormerGroupNorm(num_channels) + self.after_norm = PoolFormerGroupNorm(num_channels) + + # Useful for training neural nets + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.use_layer_scale = config.use_layer_scale + if config.use_layer_scale: + self.layer_scale_1 = nn.Parameter( + config.layer_scale_init_value * torch.ones((num_channels)), requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + config.layer_scale_init_value * torch.ones((num_channels)), requires_grad=True + ) + + def forward(self, hidden_states): + if self.use_layer_scale: + pooling_output = self.pooling(self.before_norm(hidden_states)) + scaled_op = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * pooling_output + # First residual connection + hidden_states = hidden_states + self.drop_path(scaled_op) + outputs = () + + layer_output = self.output(self.after_norm(hidden_states)) + scaled_op = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * layer_output + # Second residual connection + output = hidden_states + self.drop_path(scaled_op) + + outputs = (output,) + outputs + return outputs + + else: + pooling_output = self.drop_path(self.pooling(self.before_norm(hidden_states))) + # First residual connection + hidden_states = pooling_output + hidden_states + outputs = () + + # Second residual connection inside the PoolFormerOutput block + layer_output = self.drop_path(self.output(self.after_norm(hidden_states))) + output = hidden_states + layer_output + + outputs = (output,) + outputs + return outputs + + +class PoolFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] + + # patch embeddings + embeddings = [] + for i in range(config.num_encoder_blocks): + embeddings.append( + PoolFormerEmbeddings( + patch_size=config.patch_sizes[i], + stride=config.strides[i], + padding=config.padding[i], + num_channels=config.num_channels if i 
== 0 else config.hidden_sizes[i - 1], + hidden_size=config.hidden_sizes[i], + ) + ) + self.patch_embeddings = nn.ModuleList(embeddings) + + # Transformer blocks + blocks = [] + cur = 0 + for i in range(config.num_encoder_blocks): + # each block consists of layers + layers = [] + if i != 0: + cur += config.depths[i - 1] + for j in range(config.depths[i]): + layers.append( + PoolFormerLayer( + config, + num_channels=config.hidden_sizes[i], + pool_size=config.pool_size, + hidden_size=config.hidden_sizes[i], + intermediate_size=int(config.hidden_sizes[i] * config.mlp_ratio), + drop_path=dpr[cur + j], + ) + ) + blocks.append(nn.ModuleList(layers)) + + self.block = nn.ModuleList(blocks) + + def forward( + self, + pixel_values, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + + hidden_states = pixel_values + for idx, x in enumerate(zip(self.patch_embeddings, self.block)): + embedding_layer, block_layer = x + # Get patch embeddings from hidden_states + hidden_states = embedding_layer(hidden_states) + # Send the embeddings through the blocks + for i, blk in enumerate(block_layer): + layer_outputs = blk(hidden_states) + hidden_states = layer_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return PoolFormerModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + +class PoolFormerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PoolFormerConfig + base_model_prefix = "poolformer" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, PoolFormerEncoder): + module.gradient_checkpointing = value + + +POOLFORMER_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.PoolFormerConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +POOLFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`PoolFormerFeatureExtractor`]. See + [`PoolFormerFeatureExtractor.__call__`] for details. 
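+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all encoder blocks. See `hidden_states` under returned
+            tensors for more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.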
+""" + + +@add_start_docstrings( + "The bare PoolFormer Model transformer outputting raw hidden-states without any specific head on top.", + POOLFORMER_START_DOCSTRING, +) +class PoolFormerModel(PoolFormerPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.encoder = PoolFormerEncoder(config) + + self.pooler = None + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PoolFormerModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples:: + + >>> from transformers import ViTFeatureExtractor, PoolFormerModel + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') + >>> model = PoolFormerModel.from_pretrained('seaailabs/poolformer_s12') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return PoolFormerModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class PoolFormerFinalPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states): + output = self.dense(hidden_states) + return output + + +@add_start_docstrings( + """ + PoolFormer Model transformer with an image classification head on top + """, + POOLFORMER_START_DOCSTRING, +) +class PoolFormerForImageClassification(PoolFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.poolformer = PoolFormerModel(config) + + # Final norm + self.norm = PoolFormerGroupNorm(config.hidden_sizes[-1]) + # Classifier head + self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PoolFormerClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + labels=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + 
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import ViTFeatureExtractor, PoolFormerForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') + >>> model = PoolFormerForImageClassification.from_pretrained('seaailabs/poolformer_s12') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.poolformer( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(self.norm(sequence_output).mean([-2, -1])) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PoolFormerClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) \ No newline at end of file diff --git a/tests/test_feature_extraction_poolformer.py b/tests/test_feature_extraction_poolformer.py new file mode 100644 index 0000000000000..f2f55575edf3c --- /dev/null +++ b/tests/test_feature_extraction_poolformer.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import unittest + +import numpy as np + +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision + +from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import PoolFormerFeatureExtractor + +def calc_cropped_sizes(images, crop_pct=0.9, size=224, pil=False): + """Calculates and returns a list of the expected sizes of all the cropped images. 
+ """ + size = int(math.floor(size / crop_pct)) + image_shapes = [] + for img in images: + if pil: + height, width = img.size + else: + width, height = img.shape[-2], img.shape[-1] + short, long = (width, height) if width <= height else (height, width) + if short == size: + image_shapes.append((width, height)) + else: + new_short, new_long = size, int(size * long / short) + new_size = (new_short, new_long) if width <= height else (new_long, new_short) + image_shapes.append(new_size) + return image_shapes + +class PoolFormerFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + do_resize_and_center_crop=True, + size=224, + crop_pct=0.9, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize_and_center_crop = do_resize_and_center_crop + self.size = size + self.crop_pct = crop_pct + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_feat_extract_dict(self): + return { + "size": self.size, + "do_resize_and_center_crop": self.do_resize_and_center_crop, + "crop_pct": self.crop_pct, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + } + + +@require_torch +@require_vision +class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = PoolFormerFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = PoolFormerFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize_and_center_crop")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "crop_pct")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Calculate the expected sizes of all the images + expected_sizes = calc_cropped_sizes( + image_inputs, + self.feature_extract_tester.crop_pct, + self.feature_extract_tester.size, + pil=True + ) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + expected_sizes[0][0], + expected_sizes[0][1], + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_sizes[0][0], + expected_sizes[0][1], + ), + ) + + def 
test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Calculate the expected sizes of all the images + expected_sizes = calc_cropped_sizes( + image_inputs, + self.feature_extract_tester.crop_pct, + self.feature_extract_tester.size + ) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + expected_sizes[0][0], + expected_sizes[0][1], + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_sizes[0][0], + expected_sizes[0][1], + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Calculate the expected sizes of all the images + expected_sizes = calc_cropped_sizes( + image_inputs, + self.feature_extract_tester.crop_pct, + self.feature_extract_tester.size + ) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + expected_sizes[0][0], + expected_sizes[0][1], + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_sizes[0][0], + expected_sizes[0][1], + ), + ) \ No newline at end of file diff --git a/tests/test_modeling_poolformer.py b/tests/test_modeling_poolformer.py new file mode 100644 index 0000000000000..4b4a178665f2a --- /dev/null +++ b/tests/test_modeling_poolformer.py @@ -0,0 +1,357 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch PoolFormer model. 
""" + + +import inspect +import unittest + +from typing import List, Tuple, Dict + +from transformers import is_torch_available, is_vision_available +from transformers.models.auto import get_values +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_MAPPING, + PoolFormerConfig, + PoolFormerForImageClassification, + PoolFormerModel, + ) + from transformers.models.poolformer.modeling_poolformer import POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import DeiTFeatureExtractor + + +class PoolFormerConfigTester(ConfigTester): + def create_and_test_config_common_properties(self): + config = self.config_class(**self.inputs_dict) + self.parent.assertTrue(hasattr(config, "hidden_sizes")) + self.parent.assertTrue(hasattr(config, "num_encoder_blocks")) + + +class PoolFormerModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=64, + num_channels=3, + num_encoder_blocks=4, + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + hidden_sizes=[16, 32, 64, 128], + downsampling_rates=[1, 4, 8, 16], + is_training=False, + use_labels=True, + hidden_act="gelu", + hidden_dropout_prob=0.1, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.num_encoder_blocks = num_encoder_blocks + self.sr_ratios = sr_ratios + self.depths = depths + self.hidden_sizes = hidden_sizes + self.downsampling_rates = downsampling_rates + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.initializer_range = initializer_range + self.num_labels = num_labels + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + + config = PoolFormerConfig( + image_size=self.image_size, + num_channels=self.num_channels, + num_encoder_blocks=self.num_encoder_blocks, + depths=self.depths, + hidden_sizes=self.hidden_sizes, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + initializer_range=self.initializer_range, + ) + + return config, pixel_values, labels + + def create_and_check_model(self, config, pixel_values, labels): + model = PoolFormerModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + expected_height = expected_width = self.image_size // 32.0 + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.hidden_sizes[-1], expected_height, expected_width) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + PoolFormerModel, + PoolFormerForImageClassification, + ) + if is_torch_available() + else () + ) + + test_head_masking = False + test_pruning = False + 
test_resize_embeddings = False + test_torchscript = False + + def setUp(self): + self.model_tester = PoolFormerModelTester(self) + self.config_tester = PoolFormerConfigTester(self, config_class=PoolFormerConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip("PoolFormer does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip("PoolFormer does not have get_input_embeddings method and get_output_embeddings methods") + def test_model_common_attributes(self): + pass + + def test_retain_grad_hidden_states_attentions(self): + # Since poolformer doesn't use Attention + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + hidden_states = outputs.hidden_states[0] + + hidden_states.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. 
Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.", + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @unittest.skip("PoolFormer does not have attention") + def test_attention_outputs(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = self.model_tester.num_encoder_blocks + self.assertEqual(len(hidden_states), expected_num_layers) + + # verify the first hidden states (first block) + self.assertListEqual( + list(hidden_states[0].shape[-3:]), + [ + self.model_tester.hidden_sizes[0], + self.model_tester.image_size // 4, + self.model_tester.image_size // 4, + ], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_training(self): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING): + continue + # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + # this can then be incorporated into _prepare_for_class in test_modeling_common.py + if model_class.__name__ == "PoolFormerForSemanticSegmentation": + batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape + inputs_dict["labels"] = torch.zeros( + [self.model_tester.batch_size, height, width], 
device=torch_device + ).long() + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + + @slow + def test_model_from_pretrained(self): + for model_name in POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = PoolFormerModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +class PoolFormerModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_image_classification_head(self): + model = PoolFormerForImageClassification.from_pretrained("sail/poolformer_s12").to(torch_device) + + img_size = (224, 224) + feature_extractor = PoolFormerFeatureExtractor( + size=img_size, + ) + + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) \ No newline at end of file diff --git a/utils/check_repo.py b/utils/check_repo.py index 9ee2266ca7366..056a9a8abf2f4 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -44,6 +44,9 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested +"PoolFormerEncoder", # Building part of bigger (tested) model. + "PoolFormerDecoder", # Building part of bigger (tested) model. + "PoolFormerDecoderWrapper", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model. @@ -108,6 +111,9 @@ # should **not** be the rule. 
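 # As with IGNORE_NON_TESTED above, this is an explicit exception list: models that are deliberately
 # left out of the auto-model mappings checked by this script.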
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping +"PoolFormerEncoder", + "PoolFormerDecoder", + "PoolFormerDecoderWrapper", "ViltForQuestionAnswering", "ViltForImagesAndTextClassification", "ViltForImageAndTextRetrieval", From 3f34b12b017cb420871e4710805d030e55f65b8b Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Tue, 8 Feb 2022 11:12:44 +0530 Subject: [PATCH 02/17] Fixed PoolFormerFeatureExtractor not being able to import --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 51f4f3e3185f2..288437b7159c1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -513,7 +513,7 @@ # Vision-specific objects if is_vision_available(): _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] - _import_structure["models.poolformer"].append("PoolerFeatureExtractor") + _import_structure["models.poolformer"].append("PoolFormerFeatureExtractor") _import_structure["models.beit"].append("BeitFeatureExtractor") _import_structure["models.clip"].append("CLIPFeatureExtractor") _import_structure["models.clip"].append("CLIPProcessor") From 0c7463fae3c0d785084cf936a769412b0c227435 Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Tue, 8 Feb 2022 15:06:08 +0530 Subject: [PATCH 03/17] Completed Poolformer doc --- docs/source/model_doc/poolformer.mdx | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/docs/source/model_doc/poolformer.mdx b/docs/source/model_doc/poolformer.mdx index 19a8a6740748d..35422a32db62c 100644 --- a/docs/source/model_doc/poolformer.mdx +++ b/docs/source/model_doc/poolformer.mdx @@ -14,17 +14,17 @@ specific language governing permissions and limitations under the License. ## Overview -The PoolFormer model was proposed in []() by . +The PoolFormer model was proposed in [MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418) by Sea AI Labs. Instead of designing a complicated token mixer to achieve SOTA performance, the goal of this work is to demonstrate that the competence of transformer models largely stems from the general architecture MetaFormer. The abstract from the paper is the following: -** +*Transformers have shown great potential in computer vision tasks. A common belief is their attention-based token mixer module contributes most to their competence. However, recent works show the attention-based module in transformers can be replaced by spatial MLPs and the resulted models still perform quite well. Based on this observation, we hypothesize that the general architecture of the transformers, instead of the specific token mixer module, is more essential to the model's performance. To verify this, we deliberately replace the attention module in transformers with an embarrassingly simple spatial pooling operator to conduct only the most basic token mixing. Surprisingly, we observe that the derived model, termed as PoolFormer, achieves competitive performance on multiple computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. The effectiveness of PoolFormer verifies our hypothesis and urges us to initiate the concept of "MetaFormer", a general architecture abstracted from transformers without specifying the token mixer. 
Based on the extensive experiments, we argue that MetaFormer is the key player in achieving superior results for recent transformer and MLP-like models on vision tasks. This work calls for more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Additionally, our proposed PoolFormer could serve as a starting baseline for future MetaFormer architecture design.* Tips: -This model was contributed by [INSERT YOUR HF USERNAME HERE](). The original code can be found [here](). +This model was contributed by [heytanay]( Date: Tue, 8 Feb 2022 20:10:48 +0530 Subject: [PATCH 04/17] Applied Suggested fixes --- src/transformers/__init__.py | 34 ++++++++--------- src/transformers/models/__init__.py | 2 +- .../models/poolformer/__init__.py | 4 -- .../poolformer/configuration_poolformer.py | 5 +-- .../convert_poolformer_timm_to_pytorch.py | 2 +- .../models/poolformer/modeling_poolformer.py | 38 +++++++------------ 6 files changed, 34 insertions(+), 51 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 288437b7159c1..5fc467c2c020b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -150,7 +150,6 @@ "load_tf2_weights_in_pytorch_model", ], # Models - "models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"], "models": [], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ @@ -264,6 +263,7 @@ "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], "models.phobert": ["PhobertTokenizer"], + "models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"], "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], "models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], @@ -649,14 +649,6 @@ # PyTorch models structure - _import_structure["models.poolformer"].extend( - [ - "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", - "PoolFormerModel", - "PoolFormerPreTrainedModel", - "PoolFormerForImageClassification", - ] - ) _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1225,6 +1217,14 @@ "PerceiverPreTrainedModel", ] ) + _import_structure["models.poolformer"].extend( + [ + "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "PoolFormerModel", + "PoolFormerPreTrainedModel", + "PoolFormerForImageClassification", + ] + ) _import_structure["models.prophetnet"].extend( [ "PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2396,7 +2396,6 @@ load_tf2_weights_in_pytorch_model, ) from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig - from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, @@ -2493,6 +2492,7 @@ from .models.nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer + from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, 
PerceiverTokenizer from .models.phobert import PhobertTokenizer from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer @@ -2703,7 +2703,6 @@ if is_vision_available(): from .image_utils import ImageFeatureExtractionMixin - from .models.poolformer import PoolFormerFeatureExtractor from .models.beit import BeitFeatureExtractor from .models.clip import CLIPFeatureExtractor, CLIPProcessor from .models.convnext import ConvNextFeatureExtractor @@ -2713,6 +2712,7 @@ from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor from .models.layoutxlm import LayoutXLMProcessor from .models.perceiver import PerceiverFeatureExtractor + from .models.poolformer import PoolFormerFeatureExtractor from .models.segformer import SegformerFeatureExtractor from .models.vilt import ViltFeatureExtractor, ViltProcessor from .models.vit import ViTFeatureExtractor @@ -2763,12 +2763,6 @@ from .utils.dummy_pytorch_quantization_and_torch_objects import * if is_torch_available(): - - from .models.poolformer import ( - POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, - PoolFormerModel, - PoolFormerPreTrainedModel, - ) # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments @@ -3291,6 +3285,12 @@ PerceiverModel, PerceiverPreTrainedModel, ) + from .models.poolformer import ( + POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + PoolFormerModel, + PoolFormerPreTrainedModel, + PoolFormerForImageClassification, + ) from .models.prophetnet import ( PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, ProphetNetDecoder, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e080784b870e5..50d287c61efcd 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -17,7 +17,6 @@ # limitations under the License. from . import ( - poolformer, albert, auto, bart, @@ -84,6 +83,7 @@ pegasus, perceiver, phobert, + poolformer, prophetnet, qdqbert, rag, diff --git a/src/transformers/models/poolformer/__init__.py b/src/transformers/models/poolformer/__init__.py index 18a4c60e38718..2d885814777af 100644 --- a/src/transformers/models/poolformer/__init__.py +++ b/src/transformers/models/poolformer/__init__.py @@ -35,10 +35,8 @@ _import_structure["modeling_poolformer"] = [ "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", "PoolFormerForImageClassification", - "PoolFormerLayer", "PoolFormerModel", "PoolFormerPreTrainedModel", - "load_tf_weights_in_poolformer", ] @@ -52,11 +50,9 @@ if is_torch_available(): from .modeling_poolformer import ( POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, - PoolFormerLayer, PoolFormerModel, PoolFormerPreTrainedModel, PoolFormerForImageClassification, - load_tf_weights_in_poolformer, ) diff --git a/src/transformers/models/poolformer/configuration_poolformer.py b/src/transformers/models/poolformer/configuration_poolformer.py index f8b5412c60dde..8b8490befa4ff 100644 --- a/src/transformers/models/poolformer/configuration_poolformer.py +++ b/src/transformers/models/poolformer/configuration_poolformer.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright Sea AI Labs and The HuggingFace Inc. team. All rights reserved. +# Copyright 2022 Sea AI Labs and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
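For orientation while reading the configuration hunks, here is a minimal sketch of how the conversion script elsewhere in this series builds the S12 variant from these fields (depths, hidden sizes and MLP ratio mirror its `s12` branch; all remaining arguments are left at their documented defaults):

```python
from transformers import PoolFormerConfig, PoolFormerForImageClassification

# PoolFormer-S12 settings, matching the `size == "s12"` branch of convert_poolformer_timm_to_pytorch.py
config = PoolFormerConfig(depths=[2, 2, 6, 2], hidden_sizes=[64, 128, 320, 512], mlp_ratio=4.0)

# Randomly initialized model with this architecture (no pretrained weights are loaded here)
model = PoolFormerForImageClassification(config)
```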
@@ -17,7 +17,6 @@ from ...configuration_utils import PretrainedConfig from ...utils import logging - logger = logging.get_logger(__name__) POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { @@ -29,7 +28,7 @@ class PoolFormerConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a :class:`~transformers.PoolFormerModel`. - It is used to instantiate an PoolFormer model according to the specified arguments, defining the model + It is used to instantiate a PoolFormer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the PoolFormer `sail/poolformer_s12 `__ architecture. diff --git a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py index e29c20d86efb6..fb8e95fd13987 100644 --- a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py +++ b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert PoolFormer checkpoints.""" +"""Convert PoolFormer checkpoints from the original repository. URL: https://github.com/sail-sg/poolformer""" import math import argparse diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 049f92a042ceb..5160a734fc840 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -52,7 +52,7 @@ def to_2tuple(x): @dataclass class PoolFormerModelOutput(ModelOutput): """ - Class for PoolFormer model's outputs, with hidden states. + Class for PoolFormerModel's outputs, with potential hidden states. Args: last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): @@ -69,7 +69,7 @@ class PoolFormerModelOutput(ModelOutput): @dataclass class PoolFormerClassifierOutput(ModelOutput): """ - Output class for PoolFormer Classifier's Output + Class for PoolformerForImageClassification's outputs. Args: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): @@ -78,7 +78,7 @@ class PoolFormerClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) scores (before SoftMax). hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. + of shape :obj:`(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
""" @@ -104,7 +104,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): return output -class DropPath(nn.Module): +class PoolFormerDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): @@ -168,7 +168,7 @@ def __init__(self, config, dropout_prob, hidden_size, intermediate_size): super().__init__() self.conv1 = nn.Conv2d(hidden_size, intermediate_size, 1) self.conv2 = nn.Conv2d(intermediate_size, hidden_size, 1) - self.drop = DropPath(dropout_prob) + self.drop = PoolFormerDropPath(dropout_prob) if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: @@ -194,7 +194,7 @@ def __init__(self, config, num_channels, pool_size, hidden_size, intermediate_si self.after_norm = PoolFormerGroupNorm(num_channels) # Useful for training neural nets - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = PoolFormerDropPath(drop_path) if drop_path > 0. else nn.Identity() self.use_layer_scale = config.use_layer_scale if config.use_layer_scale: self.layer_scale_1 = nn.Parameter( @@ -325,10 +325,6 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -362,42 +358,35 @@ def _set_gradient_checkpointing(self, module, value=False): POOLFORMER_START_DOCSTRING, ) class PoolFormerModel(PoolFormerPreTrainedModel): - def __init__(self, config, add_pooling_layer=True): + def __init__(self, config): super().__init__(config) self.config = config self.encoder = PoolFormerEncoder(config) - self.pooler = None # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.patch_embeddings - # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PoolFormerModelOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - pixel_values=None, - output_hidden_states=None, - return_dict=None, - ): + def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None): r""" Returns: Examples:: - >>> from transformers import ViTFeatureExtractor, PoolFormerModel + >>> from transformers import PoolFormerFeatureExtractor, PoolFormerModel >>> from PIL import Image >>> import requests >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' >>> image = Image.open(requests.get(url, stream=True).raw) - >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') - >>> model = PoolFormerModel.from_pretrained('seaailabs/poolformer_s12') + >>> feature_extractor = PoolFormerFeatureExtractor() + >>> model = PoolFormerModel.from_pretrained('sail/poolformer_s12') >>> inputs = feature_extractor(images=image, return_tensors="pt") >>> outputs = model(**inputs) @@ -417,10 +406,9 @@ def forward( return_dict=return_dict, ) sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - + if not return_dict: - return (sequence_output, pooled_output) + 
encoder_outputs[1:] + return (sequence_output, None) + encoder_outputs[1:] return PoolFormerModelOutput( last_hidden_state=sequence_output, From f2bce98d2262b6c0ab544b82ec93d2c203448644 Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Tue, 8 Feb 2022 20:28:35 +0530 Subject: [PATCH 05/17] Fixed errors in modeling_auto.py --- src/transformers/__init__.py | 1 - src/transformers/models/auto/modeling_auto.py | 5 ----- tests/test_modeling_poolformer.py | 2 +- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5fc467c2c020b..9f7d2b02871fe 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -648,7 +648,6 @@ _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure - _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 31cfaed96ac7d..ee38a1fbd81f9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -162,7 +162,6 @@ [ # Model with LM heads mapping - ("poolformer", "PoolFormerForConditionalGeneration"), ("yoso", "YosoForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"), ("qdqbert", "QDQBertForMaskedLM"), @@ -217,7 +216,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - ("poolformer", "PoolFormerForCausalLM"), ("xglm", "XGLMForCausalLM"), ("qdqbert", "QDQBertLMHeadModel"), ("trocr", "TrOCRForCausalLM"), @@ -355,7 +353,6 @@ [ # Model for Seq2Seq Causal LM mapping - ("poolformer", "PoolFormerForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"), ("led", "LEDForConditionalGeneration"), @@ -384,7 +381,6 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - ("poolformer", "PoolFormerForSequenceClassification"), ("yoso", "YosoForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), @@ -435,7 +431,6 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - ("poolformer", "PoolFormerForQuestionAnswering"), ("yoso", "YosoForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), diff --git a/tests/test_modeling_poolformer.py b/tests/test_modeling_poolformer.py index 4b4a178665f2a..399bd71427095 100644 --- a/tests/test_modeling_poolformer.py +++ b/tests/test_modeling_poolformer.py @@ -43,7 +43,7 @@ if is_vision_available(): from PIL import Image - from transformers import DeiTFeatureExtractor + from transformers import PoolFormerFeatureExtractor class PoolFormerConfigTester(ConfigTester): From f2c4758874e0e926c02a4c221844965c50a9efd4 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 9 Feb 2022 10:59:41 -0500 Subject: [PATCH 06/17] Fix feature extractor, convert docs to Markdown, styling of code --- docs/source/index.mdx | 1 + docs/source/model_doc/poolformer.mdx | 12 +- src/transformers/__init__.py | 8 +- src/transformers/models/auto/modeling_auto.py | 3 +- .../models/poolformer/__init__.py | 14 +- .../poolformer/configuration_poolformer.py | 74 +++---- .../convert_poolformer_timm_to_pytorch.py | 61 +++--- .../feature_extraction_poolformer.py | 
32 +-- .../models/poolformer/modeling_poolformer.py | 187 ++++++++---------- src/transformers/utils/dummy_pt_objects.py | 24 +++ .../utils/dummy_vision_objects.py | 7 + tests/test_feature_extraction_poolformer.py | 69 ++----- tests/test_modeling_poolformer.py | 12 +- utils/check_repo.py | 6 +- 14 files changed, 227 insertions(+), 283 deletions(-) diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 9ee4377110cd8..3f93100ce9ef8 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow. | OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ | | Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ | | Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ | +| PoolFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | | QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ | | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/model_doc/poolformer.mdx b/docs/source/model_doc/poolformer.mdx index 35422a32db62c..d4e24158b8ebe 100644 --- a/docs/source/model_doc/poolformer.mdx +++ b/docs/source/model_doc/poolformer.mdx @@ -27,18 +27,16 @@ Tips: This model was contributed by [heytanay](`__ architecture. + This is the configuration class to store the configuration of [`PoolFormerModel`]. It is used to instantiate a + PoolFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the PoolFormer + [sail/poolformer_s12](https://huggingface.co/sail/poolformer_s12) architecture. - Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used - to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` - for more information. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Args: - num_channels (:obj:`int`, optional, defaults to 3): + num_channels (`int`, *optional*, defaults to 3): The number of channels in the input image. - patch_size (:obj:`int`, optional, defaults to 16): + patch_size (`int`, *optional*, defaults to 16): The size of the input patch. - stride (:obj:`int`, optional, defaults to 16): + stride (`int`, *optional*, defaults to 16): The stride of the input patch. - pool_size (:obj:`int`, optional, defaults to 3): + pool_size (`int`, *optional*, defaults to 3): The size of the pooling window. - mlp_ratio (:obj:`float`, optional, defaults to 4.0): + mlp_ratio (`float`, *optional*, defaults to 4.0): The ratio of the number of channels in the output of the MLP to the number of channels in the input. - depths (:obj:`list`, optional, defaults to [2, 2, 6, 2]): + depths (`list`, *optional*, defaults to [2, 2, 6, 2]): The depth of each encoder block. - hidden_sizes (:obj:`list`, optional, defaults to [64, 128, 320, 512]): + hidden_sizes (`list`, *optional*, defaults to [64, 128, 320, 512]): The hidden sizes of each encoder block. - patch_sizes (:obj:`list`, optional, defaults to [7, 3, 3, 3]): + patch_sizes (`list`, *optional*, defaults to [7, 3, 3, 3]): The size of the input patch for each encoder block. - strides (:obj:`list`, optional, defaults to [4, 2, 2, 2]): + strides (`list`, *optional*, defaults to [4, 2, 2, 2]): The stride of the input patch for each encoder block. - padding (:obj:`list`, optional, defaults to [2, 1, 1, 1]): + padding (`list`, *optional*, defaults to [2, 1, 1, 1]): The padding of the input patch for each encoder block. 
- num_encoder_blocks (:obj:`int`, optional, defaults to 4): + num_encoder_blocks (`int`, *optional*, defaults to 4): The number of encoder blocks. - drop_path_rate (:obj:`float`, optional, defaults to 0.0): + drop_path_rate (`float`, *optional*, defaults to 0.0): The dropout rate for the dropout layers. - hidden_act (:obj:`str`, optional, defaults to "gelu"): + hidden_act (`str`, *optional*, defaults to "gelu"): The activation function for the hidden layers. - use_layer_scale (:obj:`bool`, optional, defaults to True): + use_layer_scale (`bool`, *optional*, defaults to True): Whether to use layer scale. - layer_scale_init_value (:obj:`float`, optional, defaults to 1e-5): + layer_scale_init_value (`float`, *optional*, defaults to 1e-5): The initial value for the layer scale. - initializer_range (:obj:`float`, optional, defaults to 0.02): + initializer_range (`float`, *optional*, defaults to 0.02): The initializer range for the weights. - use_cache (:obj:`bool`, optional, defaults to True): - Whether to use cache. - Example:: - - >>> from transformers import PoolFormerModel, PoolFormerConfig - - >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration - >>> configuration = PoolFormerConfig() - >>> # Initializing a model from the sail/poolformer_s12 style configuration - >>> model = PoolFormerModel(configuration) + Example: - >>> # Accessing the model configuration - >>> configuration = model.config - """ + ```python + >>> from transformers import PoolFormerModel, PoolFormerConfig + >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration + >>> configuration = PoolFormerConfig() + >>> # Initializing a model from the sail/poolformer_s12 style configuration + >>> model = PoolFormerModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" model_type = "poolformer" - def __init__( self, @@ -106,7 +98,6 @@ def __init__( use_layer_scale=True, layer_scale_init_value=1e-5, initializer_range=0.02, - use_cache=True, **kwargs ): self.num_channels = num_channels @@ -125,7 +116,4 @@ def __init__( self.use_layer_scale = use_layer_scale self.layer_scale_init_value = layer_scale_init_value self.initializer_range = initializer_range - self.use_cache = use_cache - super().__init__( - **kwargs - ) \ No newline at end of file + super().__init__(**kwargs) diff --git a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py index fb8e95fd13987..eebc8b0c5e713 100644 --- a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py +++ b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py @@ -14,45 +14,38 @@ # limitations under the License. """Convert PoolFormer checkpoints from the original repository. 
URL: https://github.com/sail-sg/poolformer""" -import math import argparse import json from collections import OrderedDict from pathlib import Path -from stat import ST_SIZE import torch from PIL import Image import requests from huggingface_hub import cached_download, hf_hub_url -from transformers import ( - PoolFormerConfig, - PoolFormerForImageClassification, - PoolFormerFeatureExtractor, -) +from transformers import PoolFormerConfig, PoolFormerFeatureExtractor, PoolFormerForImageClassification from transformers.utils import logging logging.set_verbosity_info() logger = logging.get_logger(__name__) + def replace_key_with_offset(key, offset, original_name, new_name): """ Replaces the key by subtracting the offset from the original layer number """ to_find = original_name.split(".")[0] key_list = key.split(".") - orig_block_num = int(key_list[key_list.index(to_find)-2]) - layer_num = int(key_list[key_list.index(to_find)-1]) + orig_block_num = int(key_list[key_list.index(to_find) - 2]) + layer_num = int(key_list[key_list.index(to_find) - 1]) new_block_num = orig_block_num - offset - - key = key.replace( - f"{orig_block_num}.{layer_num}.{original_name}", - f"block.{new_block_num}.{layer_num}.{new_name}" - ) + + key = key.replace(f"{orig_block_num}.{layer_num}.{original_name}", f"block.{new_block_num}.{layer_num}.{new_name}") return key + def rename_keys(state_dict): new_state_dict = OrderedDict() total_embed_found, patch_emb_offset = 0, 0 @@ -63,7 +56,7 @@ def rename_keys(state_dict): # Works for the first embedding as well as the internal embedding layers if key.endswith("bias") and "patch_embed" not in key: patch_emb_offset += 1 - to_replace = key[:key.find("proj")] + to_replace = key[: key.find("proj")] key = key.replace(to_replace, f"patch_embeddings.{total_embed_found}.") key = key.replace("proj", "projection") if key.endswith("bias"): @@ -87,6 +80,7 @@ def rename_keys(state_dict): new_state_dict[key] = value return new_state_dict + # We will verify our results on a COCO image def prepare_img(): url = "http://images.cocodataset.org/val2017/000000039769.jpg" @@ -94,6 +88,7 @@ def prepare_img(): return image + @torch.no_grad() def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path): """ @@ -119,41 +114,35 @@ def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_fold config.depths = [2, 2, 6, 2] config.hidden_sizes = [64, 128, 320, 512] config.mlp_ratio = 4.0 - crop_pcent = 0.9 + crop_pct = 0.9 elif size == "s24": config.depths = [4, 4, 12, 4] config.hidden_sizes = [64, 128, 320, 512] config.mlp_ratio = 4.0 - crop_pcent = 0.9 + crop_pct = 0.9 elif size == "s36": config.depths = [6, 6, 18, 6] config.hidden_sizes = [64, 128, 320, 512] config.mlp_ratio = 4.0 config.layer_scale_init_value = 1e-6 - crop_pcent = 0.9 + crop_pct = 0.9 elif size == "m36": config.depths = [6, 6, 18, 6] config.hidden_sizes = [96, 192, 384, 768] config.mlp_ratio = 4.0 config.layer_scale_init_value = 1e-6 - crop_pcent = 0.95 + crop_pct = 0.95 elif size == "m48": config.depths = [8, 8, 24, 8] config.hidden_sizes = [96, 192, 384, 768] config.mlp_ratio = 4.0 config.layer_scale_init_value = 1e-6 - crop_pcent = 0.95 + crop_pct = 0.95 else: raise ValueError(f"Size {size} not supported") - print(size) - # load feature extractor (only resize + normalize) - img_size = (224, 224) - feature_extractor = DeiTFeatureExtractor( - size=img_size, - resample=Image.BILINEAR, - crop_size=224, - ) + # load feature extractor + feature_extractor = 
PoolFormerFeatureExtractor(crop_pct=crop_pct) # Prepare image image = prepare_img() @@ -166,31 +155,31 @@ def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_fold # rename keys state_dict = rename_keys(state_dict) - + # create HuggingFace model and load state dict model = PoolFormerForImageClassification(config) model.load_state_dict(state_dict) model.eval() # Define feature extractor - feature_extractor = PoolFormerFeatureExtractor(size=size) - encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + feature_extractor = PoolFormerFeatureExtractor(crop_pct=crop_pct) + pixel_values = feature_extractor(images=prepare_img(), return_tensors="pt").pixel_values # forward pass outputs = model(pixel_values) logits = outputs.logits - # define expected logit slices for different models + # define expected logit slices for different models if size == "s12": expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869]) elif size == "s24": - expected_slice = torch.tensor([ 0.4402, -0.1374, -0.8045]) + expected_slice = torch.tensor([0.4402, -0.1374, -0.8045]) elif size == "s36": expected_slice = torch.tensor([-0.6080, -0.5133, -0.5898]) elif size == "m36": - expected_slice = torch.tensor([ 0.3952, 0.2263, -1.2668]) + expected_slice = torch.tensor([0.3952, 0.2263, -1.2668]) elif size == "m48": - expected_slice = torch.tensor([ 0.1167, -0.0656, -0.3423]) + expected_slice = torch.tensor([0.1167, -0.0656, -0.3423]) else: raise ValueError(f"Size {size} not supported") @@ -222,4 +211,4 @@ def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_fold "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." ) args = parser.parse_args() - convert_poolformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path) \ No newline at end of file + convert_poolformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/poolformer/feature_extraction_poolformer.py b/src/transformers/models/poolformer/feature_extraction_poolformer.py index 392d633d06c32..b7d44e2226519 100644 --- a/src/transformers/models/poolformer/feature_extraction_poolformer.py +++ b/src/transformers/models/poolformer/feature_extraction_poolformer.py @@ -14,7 +14,6 @@ # limitations under the License. """Feature extractor class for PoolFormer.""" -from email.policy import default import math from typing import Optional, Union @@ -45,17 +44,18 @@ class PoolFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM Args: do_resize_and_center_crop (`bool`, *optional*, defaults to `True`): - Whether to resize and center crop the input to a certain `size`. + Whether to resize the shortest edge of the image and center crop the input to a certain `size`. size (`int` or `Tuple(int)`, *optional*, defaults to 224): - Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an - integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize_and_center_crop` is - set to `True`. - resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): + Center crop the input to the given size. If a tuple is provided, it should be (width, height). If only an + integer is provided, then the input will be center cropped to (size, size). Only has an effect if + `do_resize_and_center_crop` is set to `True`. 
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`): An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect - if `do_resize` is set to `True`. + if `do_resize_and_center_crop` is set to `True`. crop_pct (`float`, *optional*, defaults to `0.9`): - The percentage of the image to crop from the center. + The percentage of the image to crop from the center. Only has an effect if `do_resize_and_center_crop` is + set to `True`. do_normalize (`bool`, *optional*, defaults to `True`): Whether or not to normalize the input with `image_mean` and `image_std`. image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`): @@ -70,7 +70,7 @@ def __init__( self, do_resize_and_center_crop=True, size=224, - resample=Image.BILINEAR, + resample=Image.BICUBIC, crop_pct=0.9, do_normalize=True, image_mean=None, @@ -144,18 +144,24 @@ def __call__( images = [images] # transformations (resizing + center cropping + normalization) - if self.do_resize_and_center_crop and self.size is not None: + if self.do_resize_and_center_crop and self.size is not None and self.crop_pct is not None: if isinstance(self.size, (tuple, list)): assert len(self.size) == 2 if self.size[-1] == self.size[-2]: - # fall-back to older behaviour so Resize scales to shortest edge if target is square scale_size = int(math.floor(self.size[0] / self.crop_pct)) else: scale_size = tuple([int(x / self.crop_pct) for x in self.size]) else: scale_size = int(math.floor(self.size / self.crop_pct)) - images = [self.resize(image=image, size=scale_size, resample=self.resample, default_to_square=False) for image in images] + # resize shortest edge of the image + images = [ + self.resize(image=image, size=scale_size, resample=self.resample, default_to_square=False) + for image in images + ] + # center crop + images = [self.center_crop(image, size=self.size) for image in images] + if self.do_normalize: images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] @@ -163,4 +169,4 @@ def __call__( data = {"pixel_values": images} encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) - return encoded_inputs \ No newline at end of file + return encoded_inputs diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 5160a734fc840..7ad8e66762dfd 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -12,22 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch PoolFormer model. 
""" +""" PyTorch PoolFormer model.""" import collections.abc -import math from dataclasses import dataclass from typing import Optional, Tuple import torch -from torch.functional import block_diag, norm import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_poolformer import PoolFormerConfig @@ -35,66 +38,80 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "seaailabs/poolformer_s12" +# General docstring _CONFIG_FOR_DOC = "PoolFormerConfig" +_FEAT_EXTRACTOR_FOR_DOC = "PoolFormerFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "sail/poolformer_s12" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "sail/poolformer_s12" +_IMAGE_CLASS_EXPECTED_OUTPUT = "'Egyptian cat'" POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "sail/poolformer_s12", # See all PoolFormer models at https://huggingface.co/models?filter=poolformer ] + # Copied from transformers.models.vit.modeling_vit.to_2tuple def to_2tuple(x): if isinstance(x, collections.abc.Iterable): return x return (x, x) + @dataclass class PoolFormerModelOutput(ModelOutput): """ Class for PoolFormerModel's outputs, with potential hidden states. Args: - last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the initial embedding outputs. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. """ + last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None + @dataclass class PoolFormerClassifierOutput(ModelOutput): """ Class for PoolformerForImageClassification's outputs. Args: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Classification (or regression if config.num_labels==1) loss. - logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). 
- hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, num_channels, height, width)`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. """ + loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None -def drop_path(x, drop_prob: float = 0., training: bool = False): + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is + misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: + https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and + argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ - if drop_prob == 0. or not training: + if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets @@ -117,21 +134,16 @@ def forward(self, x): class PoolFormerEmbeddings(nn.Module): """ - Construct Patch Embeddings + Construct Patch Embeddings. """ + def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None): super().__init__() patch_size = to_2tuple(patch_size) stride = to_2tuple(stride) padding = to_2tuple(padding) - self.projection = nn.Conv2d( - num_channels, - hidden_size, - kernel_size=patch_size, - stride=stride, - padding=padding - ) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding) self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity() def forward(self, pixel_values): @@ -142,9 +154,9 @@ def forward(self, pixel_values): class PoolFormerGroupNorm(nn.GroupNorm): """ - Group Normalization with 1 group. - Input: tensor in shape [B, C, H, W] + Group Normalization with 1 group. 
Input: tensor in shape [B, C, H, W] """ + def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) @@ -152,12 +164,7 @@ def __init__(self, num_channels, **kwargs): class PoolFormerPooling(nn.Module): def __init__(self, pool_size): super().__init__() - self.pool = nn.AvgPool2d( - pool_size, - stride=1, - padding=pool_size//2, - count_include_pad=False - ) + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) def forward(self, hidden_states): return self.pool(hidden_states) - hidden_states @@ -173,7 +180,7 @@ def __init__(self, config, dropout_prob, hidden_size, intermediate_size): self.act_fn = ACT2FN[config.hidden_act] else: self.act_fn = config.hidden_act - + def forward(self, hidden_states): hidden_states = self.conv1(hidden_states) hidden_states = self.act_fn(hidden_states) @@ -183,6 +190,7 @@ def forward(self, hidden_states): return hidden_states + class PoolFormerLayer(nn.Module): """This corresponds to the 'PoolFormerBlock' class in the original implementation.""" @@ -192,9 +200,9 @@ def __init__(self, config, num_channels, pool_size, hidden_size, intermediate_si self.output = PoolFormerOutput(config, drop_path, hidden_size, intermediate_size) self.before_norm = PoolFormerGroupNorm(num_channels) self.after_norm = PoolFormerGroupNorm(num_channels) - + # Useful for training neural nets - self.drop_path = PoolFormerDropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = PoolFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.use_layer_scale = config.use_layer_scale if config.use_layer_scale: self.layer_scale_1 = nn.Parameter( @@ -219,7 +227,7 @@ def forward(self, hidden_states): outputs = (output,) + outputs return outputs - + else: pooling_output = self.drop_path(self.pooling(self.before_norm(hidden_states))) # First residual connection @@ -229,7 +237,7 @@ def forward(self, hidden_states): # Second residual connection inside the PoolFormerOutput block layer_output = self.drop_path(self.output(self.after_norm(hidden_states))) output = hidden_states + layer_output - + outputs = (output,) + outputs return outputs @@ -295,10 +303,10 @@ def forward( for i, blk in enumerate(block_layer): layer_outputs = blk(hidden_states) hidden_states = layer_outputs[0] - + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - + if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) @@ -335,14 +343,14 @@ def _set_gradient_checkpointing(self, module, value=False): POOLFORMER_START_DOCSTRING = r""" - This model is a PyTorch `torch.nn.Module `_ sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. Parameters: - config (:class:`~transformers.PoolFormerConfig`): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. + config ([`PoolFormerConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ POOLFORMER_INPUTS_DOCSTRING = r""" @@ -371,27 +379,15 @@ def get_input_embeddings(self): return self.embeddings.patch_embeddings @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=PoolFormerModelOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=PoolFormerModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None): - r""" - Returns: - - Examples:: - - >>> from transformers import PoolFormerFeatureExtractor, PoolFormerModel - >>> from PIL import Image - >>> import requests - - >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> feature_extractor = PoolFormerFeatureExtractor() - >>> model = PoolFormerModel.from_pretrained('sail/poolformer_s12') - - >>> inputs = feature_extractor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -406,7 +402,7 @@ def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None return_dict=return_dict, ) sequence_output = encoder_outputs[0] - + if not return_dict: return (sequence_output, None) + encoder_outputs[1:] @@ -441,13 +437,21 @@ def __init__(self, config): # Final norm self.norm = PoolFormerGroupNorm(config.hidden_sizes[-1]) # Classifier head - self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + self.classifier = ( + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=PoolFormerClassifierOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=PoolFormerClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) def forward( self, pixel_values=None, @@ -456,31 +460,10 @@ def forward( return_dict=None, ): r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., - config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- - Returns: - - Examples:: - - >>> from transformers import ViTFeatureExtractor, PoolFormerForImageClassification - >>> from PIL import Image - >>> import requests - - >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') - >>> model = PoolFormerForImageClassification.from_pretrained('seaailabs/poolformer_s12') - - >>> inputs = feature_extractor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = logits.argmax(-1).item() - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -491,7 +474,7 @@ def forward( ) sequence_output = outputs[0] - + logits = self.classifier(self.norm(sequence_output).mean([-2, -1])) loss = None @@ -512,4 +495,4 @@ def forward( loss=loss, logits=logits, hidden_states=outputs.hidden_states, - ) \ No newline at end of file + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 0741e42861efa..be08c8f0e29be 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2769,6 +2769,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class PoolFormerForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PoolFormerModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PoolFormerPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 1084d2fc4d0c6..e49088d8b88ee 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -87,6 +87,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class PoolFormerFeatureExtractor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class SegformerFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/test_feature_extraction_poolformer.py b/tests/test_feature_extraction_poolformer.py index f2f55575edf3c..cec912846c68c 100644 --- a/tests/test_feature_extraction_poolformer.py +++ b/tests/test_feature_extraction_poolformer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math import unittest import numpy as np @@ -32,24 +31,6 @@ from transformers import PoolFormerFeatureExtractor -def calc_cropped_sizes(images, crop_pct=0.9, size=224, pil=False): - """Calculates and returns a list of the expected sizes of all the cropped images. - """ - size = int(math.floor(size / crop_pct)) - image_shapes = [] - for img in images: - if pil: - height, width = img.size - else: - width, height = img.shape[-2], img.shape[-1] - short, long = (width, height) if width <= height else (height, width) - if short == size: - image_shapes.append((width, height)) - else: - new_short, new_long = size, int(size * long / short) - new_size = (new_short, new_long) if width <= height else (new_long, new_short) - image_shapes.append(new_size) - return image_shapes class PoolFormerFeatureExtractionTester(unittest.TestCase): def __init__( @@ -60,7 +41,7 @@ def __init__( min_resolution=30, max_resolution=400, do_resize_and_center_crop=True, - size=224, + size=30, crop_pct=0.9, do_normalize=True, image_mean=[0.5, 0.5, 0.5], @@ -122,14 +103,6 @@ def test_call_pil(self): for image in image_inputs: self.assertIsInstance(image, Image.Image) - # Calculate the expected sizes of all the images - expected_sizes = calc_cropped_sizes( - image_inputs, - self.feature_extract_tester.crop_pct, - self.feature_extract_tester.size, - pil=True - ) - # Test not batched input encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values @@ -138,8 +111,8 @@ def test_call_pil(self): ( 1, self.feature_extract_tester.num_channels, - expected_sizes[0][0], - expected_sizes[0][1], + self.feature_extract_tester.size, + self.feature_extract_tester.size, ), ) @@ -150,8 +123,8 @@ def test_call_pil(self): ( self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, - expected_sizes[0][0], - expected_sizes[0][1], + self.feature_extract_tester.size, + self.feature_extract_tester.size, ), ) @@ -163,13 +136,6 @@ def test_call_numpy(self): for image in image_inputs: self.assertIsInstance(image, np.ndarray) - # Calculate the expected sizes of all the images - expected_sizes = calc_cropped_sizes( - image_inputs, - self.feature_extract_tester.crop_pct, - self.feature_extract_tester.size - ) - # Test not batched input encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values self.assertEqual( @@ -177,8 +143,8 @@ def test_call_numpy(self): ( 1, self.feature_extract_tester.num_channels, - expected_sizes[0][0], - expected_sizes[0][1], + self.feature_extract_tester.size, + self.feature_extract_tester.size, ), ) @@ -189,8 +155,8 @@ def test_call_numpy(self): ( self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, - expected_sizes[0][0], - expected_sizes[0][1], + self.feature_extract_tester.size, + self.feature_extract_tester.size, ), ) @@ -202,13 +168,6 @@ def test_call_pytorch(self): for image in image_inputs: self.assertIsInstance(image, torch.Tensor) - # Calculate the expected sizes of all the images - expected_sizes = calc_cropped_sizes( - image_inputs, - self.feature_extract_tester.crop_pct, - self.feature_extract_tester.size - ) - # Test not batched input encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values self.assertEqual( @@ -216,8 +175,8 @@ def test_call_pytorch(self): ( 1, self.feature_extract_tester.num_channels, - expected_sizes[0][0], - expected_sizes[0][1], + self.feature_extract_tester.size, + self.feature_extract_tester.size, ), ) @@ -228,7 +187,7 @@ def test_call_pytorch(self): ( 
self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, - expected_sizes[0][0], - expected_sizes[0][1], + self.feature_extract_tester.size, + self.feature_extract_tester.size, ), - ) \ No newline at end of file + ) diff --git a/tests/test_modeling_poolformer.py b/tests/test_modeling_poolformer.py index 399bd71427095..085a2e27e9998 100644 --- a/tests/test_modeling_poolformer.py +++ b/tests/test_modeling_poolformer.py @@ -17,8 +17,7 @@ import inspect import unittest - -from typing import List, Tuple, Dict +from typing import Dict, List, Tuple from transformers import is_torch_available, is_vision_available from transformers.models.auto import get_values @@ -31,12 +30,7 @@ if is_torch_available(): import torch - from transformers import ( - MODEL_MAPPING, - PoolFormerConfig, - PoolFormerForImageClassification, - PoolFormerModel, - ) + from transformers import MODEL_MAPPING, PoolFormerConfig, PoolFormerForImageClassification, PoolFormerModel from transformers.models.poolformer.modeling_poolformer import POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST @@ -354,4 +348,4 @@ def test_inference_image_classification_head(self): expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869]).to(torch_device) - self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) \ No newline at end of file + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 056a9a8abf2f4..49c7641584589 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -44,9 +44,9 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested -"PoolFormerEncoder", # Building part of bigger (tested) model. + "PoolFormerEncoder", # Building part of bigger (tested) model. "PoolFormerDecoder", # Building part of bigger (tested) model. - "PoolFormerDecoderWrapper", # Building part of bigger (tested) model. + "PoolFormerDecoderWrapper", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model. @@ -111,7 +111,7 @@ # should **not** be the rule. 
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping -"PoolFormerEncoder", + "PoolFormerEncoder", "PoolFormerDecoder", "PoolFormerDecoderWrapper", "ViltForQuestionAnswering", From 1e4cc06ff58cdf8b233d1894add886becada492d Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 9 Feb 2022 11:20:04 -0500 Subject: [PATCH 07/17] Remove PoolFormer from check_repo and fix integration test --- .../models/poolformer/inference.py | 19 +++++++++++++++++++ .../models/poolformer/modeling_poolformer.py | 2 +- tests/test_modeling_poolformer.py | 12 +++--------- utils/check_repo.py | 2 -- 4 files changed, 23 insertions(+), 12 deletions(-) create mode 100644 src/transformers/models/poolformer/inference.py diff --git a/src/transformers/models/poolformer/inference.py b/src/transformers/models/poolformer/inference.py new file mode 100644 index 0000000000000..2106eacaca1bf --- /dev/null +++ b/src/transformers/models/poolformer/inference.py @@ -0,0 +1,19 @@ +from transformers import PoolFormerFeatureExtractor, PoolFormerForImageClassification +import requests +from PIL import Image + +feature_extractor = PoolFormerFeatureExtractor() + +model = PoolFormerForImageClassification.from_pretrained("sail/poolformer_s12") + +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) + +inputs = feature_extractor(images=image, return_tensors="pt") + +outputs = model(**inputs) + +predicted_class = outputs.logits.argmax(-1).item() +print(model.config.id2label[predicted_class]) + +print(outputs.logits[0,:3]) \ No newline at end of file diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 7ad8e66762dfd..90a797a7485a2 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -48,7 +48,7 @@ # Image classification docstring _IMAGE_CLASS_CHECKPOINT = "sail/poolformer_s12" -_IMAGE_CLASS_EXPECTED_OUTPUT = "'Egyptian cat'" +_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'" POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "sail/poolformer_s12", diff --git a/tests/test_modeling_poolformer.py b/tests/test_modeling_poolformer.py index 085a2e27e9998..64e9f6755051d 100644 --- a/tests/test_modeling_poolformer.py +++ b/tests/test_modeling_poolformer.py @@ -328,15 +328,10 @@ def prepare_img(): class PoolFormerModelIntegrationTest(unittest.TestCase): @slow def test_inference_image_classification_head(self): + feature_extractor = PoolFormerFeatureExtractor() model = PoolFormerForImageClassification.from_pretrained("sail/poolformer_s12").to(torch_device) - img_size = (224, 224) - feature_extractor = PoolFormerFeatureExtractor( - size=img_size, - ) - - image = prepare_img() - inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + inputs = feature_extractor(images=prepare_img(), return_tensors="pt").to(torch_device) # forward pass with torch.no_grad(): @@ -346,6 +341,5 @@ def test_inference_image_classification_head(self): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869]).to(torch_device) - + expected_slice = torch.tensor([-0.6113, 0.1685, -0.0492]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 49c7641584589..5013662a599a6 100644 --- 
a/utils/check_repo.py +++ b/utils/check_repo.py @@ -44,8 +44,6 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested - "PoolFormerEncoder", # Building part of bigger (tested) model. - "PoolFormerDecoder", # Building part of bigger (tested) model. "PoolFormerDecoderWrapper", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model. From 749425a767343e75f580e493064648808576ed61 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 9 Feb 2022 11:22:22 -0500 Subject: [PATCH 08/17] Remove Poolformer from check_repo --- src/transformers/models/poolformer/inference.py | 10 ++++++---- utils/check_repo.py | 4 ---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/poolformer/inference.py b/src/transformers/models/poolformer/inference.py index 2106eacaca1bf..dca74b60b86a3 100644 --- a/src/transformers/models/poolformer/inference.py +++ b/src/transformers/models/poolformer/inference.py @@ -1,12 +1,14 @@ -from transformers import PoolFormerFeatureExtractor, PoolFormerForImageClassification -import requests from PIL import Image +import requests +from transformers import PoolFormerFeatureExtractor, PoolFormerForImageClassification + + feature_extractor = PoolFormerFeatureExtractor() model = PoolFormerForImageClassification.from_pretrained("sail/poolformer_s12") -url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) inputs = feature_extractor(images=image, return_tensors="pt") @@ -16,4 +18,4 @@ predicted_class = outputs.logits.argmax(-1).item() print(model.config.id2label[predicted_class]) -print(outputs.logits[0,:3]) \ No newline at end of file +print(outputs.logits[0, :3]) diff --git a/utils/check_repo.py b/utils/check_repo.py index 5013662a599a6..9ee2266ca7366 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -44,7 +44,6 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested - "PoolFormerDecoderWrapper", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model. @@ -109,9 +108,6 @@ # should **not** be the rule. 
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping - "PoolFormerEncoder", - "PoolFormerDecoder", - "PoolFormerDecoderWrapper", "ViltForQuestionAnswering", "ViltForImagesAndTextClassification", "ViltForImageAndTextRetrieval", From 6aec08494a94ed4c8445c374bc46d2ddb2521d96 Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Wed, 9 Feb 2022 23:09:55 +0530 Subject: [PATCH 09/17] Fixed configuration_poolformer.py docs and removed inference.py from poolformer --- src/transformers/__init__.py | 103 +++--------------- src/transformers/models/auto/modeling_auto.py | 10 +- .../poolformer/configuration_poolformer.py | 16 ++- .../convert_poolformer_timm_to_pytorch.py | 5 +- .../models/poolformer/inference.py | 21 ---- .../models/poolformer/modeling_poolformer.py | 37 ++----- tests/test_modeling_poolformer.py | 9 +- 7 files changed, 40 insertions(+), 161 deletions(-) delete mode 100644 src/transformers/models/poolformer/inference.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8efc520dc0136..98acebd6688c4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -179,10 +179,7 @@ "models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], "models.bertweet": ["BertweetTokenizer"], "models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"], - "models.bigbird_pegasus": [ - "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", - "BigBirdPegasusConfig", - ], + "models.bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig",], "models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"], "models.blenderbot_small": [ "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -255,10 +252,7 @@ "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"], "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"], "models.mt5": ["MT5Config"], - "models.nystromformer": [ - "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", - "NystromformerConfig", - ], + "models.nystromformer": ["NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "NystromformerConfig",], "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"], "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], @@ -277,10 +271,7 @@ "models.sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"], "models.sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"], "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"], - "models.speech_to_text": [ - "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", - "Speech2TextConfig", - ], + "models.speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig",], "models.speech_to_text_2": [ "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2Text2Config", @@ -298,19 +289,9 @@ "TransfoXLCorpus", "TransfoXLTokenizer", ], - "models.trocr": [ - "TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", - "TrOCRConfig", - "TrOCRProcessor", - ], - "models.unispeech": [ - "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", - "UniSpeechConfig", - ], - "models.unispeech_sat": [ - "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP", - "UniSpeechSatConfig", - ], + "models.trocr": ["TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", "TrOCRConfig", "TrOCRProcessor",], 
+ "models.unispeech": ["UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechConfig",], + "models.unispeech_sat": ["UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechSatConfig",], "models.vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig", "ViltFeatureExtractor", "ViltProcessor"], "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"], "models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"], @@ -327,10 +308,7 @@ ], "models.wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"], "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"], - "models.wavlm": [ - "WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", - "WavLMConfig", - ], + "models.wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig",], "models.xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"], "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], @@ -826,13 +804,7 @@ ] ) _import_structure["models.clip"].extend( - [ - "CLIP_PRETRAINED_MODEL_ARCHIVE_LIST", - "CLIPModel", - "CLIPPreTrainedModel", - "CLIPTextModel", - "CLIPVisionModel", - ] + ["CLIP_PRETRAINED_MODEL_ARCHIVE_LIST", "CLIPModel", "CLIPPreTrainedModel", "CLIPTextModel", "CLIPVisionModel",] ) _import_structure["models.convbert"].extend( [ @@ -1372,12 +1344,7 @@ ] ) _import_structure["models.swin"].extend( - [ - "SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", - "SwinForImageClassification", - "SwinModel", - "SwinPreTrainedModel", - ] + ["SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", "SwinForImageClassification", "SwinModel", "SwinPreTrainedModel",] ) _import_structure["models.t5"].extend( [ @@ -1453,12 +1420,7 @@ ] ) _import_structure["models.vit"].extend( - [ - "VIT_PRETRAINED_MODEL_ARCHIVE_LIST", - "ViTForImageClassification", - "ViTModel", - "ViTPreTrainedModel", - ] + ["VIT_PRETRAINED_MODEL_ARCHIVE_LIST", "ViTForImageClassification", "ViTModel", "ViTPreTrainedModel",] ) _import_structure["models.vit_mae"].extend( [ @@ -1494,12 +1456,7 @@ ] ) _import_structure["models.xglm"].extend( - [ - "XGLM_PRETRAINED_MODEL_ARCHIVE_LIST", - "XGLMForCausalLM", - "XGLMModel", - "XGLMPreTrainedModel", - ] + ["XGLM_PRETRAINED_MODEL_ARCHIVE_LIST", "XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel",] ) _import_structure["models.xlm"].extend( [ @@ -1833,12 +1790,7 @@ ] ) _import_structure["models.hubert"].extend( - [ - "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", - "TFHubertForCTC", - "TFHubertModel", - "TFHubertPreTrainedModel", - ] + ["TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TFHubertForCTC", "TFHubertModel", "TFHubertPreTrainedModel",] ) _import_structure["models.layoutlm"].extend( [ @@ -1923,12 +1875,7 @@ ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"] ) _import_structure["models.rag"].extend( - [ - "TFRagModel", - "TFRagPreTrainedModel", - "TFRagSequenceForGeneration", - "TFRagTokenForGeneration", - ] + ["TFRagModel", "TFRagPreTrainedModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration",] ) _import_structure["models.rembert"].extend( [ @@ -2012,11 +1959,7 @@ ) _import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"]) _import_structure["models.vit"].extend( - [ - "TFViTForImageClassification", - "TFViTModel", - "TFViTPreTrainedModel", - ] + ["TFViTForImageClassification", "TFViTModel", "TFViTPreTrainedModel",] ) _import_structure["models.wav2vec2"].extend( [ @@ -2224,11 +2167,7 @@ ) 
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) _import_structure["models.marian"].extend( - [ - "FlaxMarianModel", - "FlaxMarianMTModel", - "FlaxMarianPreTrainedModel", - ] + ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel",] ) _import_structure["models.mbart"].extend( [ @@ -2241,11 +2180,7 @@ ) _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) _import_structure["models.pegasus"].extend( - [ - "FlaxPegasusForConditionalGeneration", - "FlaxPegasusModel", - "FlaxPegasusPreTrainedModel", - ] + ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel",] ) _import_structure["models.roberta"].extend( [ @@ -2277,11 +2212,7 @@ ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] ) _import_structure["models.xglm"].extend( - [ - "FlaxXGLMForCausalLM", - "FlaxXGLMModel", - "FlaxXGLMPreTrainedModel", - ] + ["FlaxXGLMForCausalLM", "FlaxXGLMModel", "FlaxXGLMPreTrainedModel",] ) else: from .utils import dummy_flax_objects diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5e8c8149dd44f..a6769c0345dbc 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -253,9 +253,7 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( # Model for Causal Image Modeling mapping - [ - ("imagegpt", "ImageGPTForCausalImageModeling"), - ] + [("imagegpt", "ImageGPTForCausalImageModeling"),] ) MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( @@ -296,11 +294,7 @@ ] ) -MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( - [ - ("vision-encoder-decoder", "VisionEncoderDecoderModel"), - ] -) +MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict([("vision-encoder-decoder", "VisionEncoderDecoderModel"),]) MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ diff --git a/src/transformers/models/poolformer/configuration_poolformer.py b/src/transformers/models/poolformer/configuration_poolformer.py index bde4575de4960..2e131a5f31d34 100644 --- a/src/transformers/models/poolformer/configuration_poolformer.py +++ b/src/transformers/models/poolformer/configuration_poolformer.py @@ -74,10 +74,18 @@ class PoolFormerConfig(PretrainedConfig): Example: ```python - >>> from transformers import PoolFormerModel, PoolFormerConfig >>> # Initializing a PoolFormer sail/poolformer_s12 - style configuration >>> configuration = PoolFormerConfig() >>> # Initializing a model from the sail/poolformer_s12 - style configuration >>> model = PoolFormerModel(configuration) >>> # Accessing the model configuration >>> - configuration = model.config""" + >>> from transformers import PoolFormerModel, PoolFormerConfig + + >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration + >>> configuration = PoolFormerConfig() + + >>> # Initializing a model from the sail/poolformer_s12 style configuration + >>> model = PoolFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ model_type = "poolformer" def __init__( diff --git a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py index eebc8b0c5e713..2af22f1200d42 100644 --- a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py +++ 
b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py @@ -199,10 +199,7 @@ def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_fold parser = argparse.ArgumentParser() parser.add_argument( - "--model_name", - default="poolformer_s12", - type=str, - help="Name of the model you'd like to convert.", + "--model_name", default="poolformer_s12", type=str, help="Name of the model you'd like to convert.", ) parser.add_argument( "--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)." diff --git a/src/transformers/models/poolformer/inference.py b/src/transformers/models/poolformer/inference.py deleted file mode 100644 index dca74b60b86a3..0000000000000 --- a/src/transformers/models/poolformer/inference.py +++ /dev/null @@ -1,21 +0,0 @@ -from PIL import Image - -import requests -from transformers import PoolFormerFeatureExtractor, PoolFormerForImageClassification - - -feature_extractor = PoolFormerFeatureExtractor() - -model = PoolFormerForImageClassification.from_pretrained("sail/poolformer_s12") - -url = "http://images.cocodataset.org/val2017/000000039769.jpg" -image = Image.open(requests.get(url, stream=True).raw) - -inputs = feature_extractor(images=image, return_tensors="pt") - -outputs = model(**inputs) - -predicted_class = outputs.logits.argmax(-1).item() -print(model.config.id2label[predicted_class]) - -print(outputs.logits[0, :3]) diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 90a797a7485a2..ddb69264868c5 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -287,10 +287,7 @@ def __init__(self, config): self.block = nn.ModuleList(blocks) def forward( - self, - pixel_values, - output_hidden_states=False, - return_dict=True, + self, pixel_values, output_hidden_states=False, return_dict=True, ): all_hidden_states = () if output_hidden_states else None @@ -310,10 +307,7 @@ def forward( if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) - return PoolFormerModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - ) + return PoolFormerModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states,) class PoolFormerPreTrainedModel(PreTrainedModel): @@ -397,19 +391,14 @@ def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None raise ValueError("You have to specify pixel_values") encoder_outputs = self.encoder( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = encoder_outputs[0] if not return_dict: return (sequence_output, None) + encoder_outputs[1:] - return PoolFormerModelOutput( - last_hidden_state=sequence_output, - hidden_states=encoder_outputs.hidden_states, - ) + return PoolFormerModelOutput(last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states,) class PoolFormerFinalPooler(nn.Module): @@ -453,11 +442,7 @@ def __init__(self, config): expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) def forward( - self, - pixel_values=None, - labels=None, - output_hidden_states=None, - return_dict=None, + self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -467,11 +452,7 
@@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.poolformer( - pixel_values, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + outputs = self.poolformer(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict,) sequence_output = outputs[0] @@ -491,8 +472,4 @@ def forward( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return PoolFormerClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - ) + return PoolFormerClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,) diff --git a/tests/test_modeling_poolformer.py b/tests/test_modeling_poolformer.py index 64e9f6755051d..0ab8df008dc3a 100644 --- a/tests/test_modeling_poolformer.py +++ b/tests/test_modeling_poolformer.py @@ -124,14 +124,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - PoolFormerModel, - PoolFormerForImageClassification, - ) - if is_torch_available() - else () - ) + all_model_classes = (PoolFormerModel, PoolFormerForImageClassification,) if is_torch_available() else () test_head_masking = False test_pruning = False From 2d18e0e2cb1fa060b3c1c683bcef5014de5ed07e Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Wed, 9 Feb 2022 23:20:23 +0530 Subject: [PATCH 10/17] Ran with black v22 --- src/transformers/__init__.py | 103 +++++++++++++++--- src/transformers/models/auto/modeling_auto.py | 10 +- .../poolformer/configuration_poolformer.py | 12 +- .../convert_poolformer_timm_to_pytorch.py | 5 +- .../models/poolformer/modeling_poolformer.py | 37 +++++-- tests/test_modeling_poolformer.py | 9 +- 6 files changed, 142 insertions(+), 34 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 98acebd6688c4..8efc520dc0136 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -179,7 +179,10 @@ "models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], "models.bertweet": ["BertweetTokenizer"], "models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"], - "models.bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig",], + "models.bigbird_pegasus": [ + "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", + "BigBirdPegasusConfig", + ], "models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"], "models.blenderbot_small": [ "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -252,7 +255,10 @@ "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"], "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"], "models.mt5": ["MT5Config"], - "models.nystromformer": ["NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "NystromformerConfig",], + "models.nystromformer": [ + "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", + "NystromformerConfig", + ], "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"], "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], @@ -271,7 +277,10 @@ "models.sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"], 
"models.sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"], "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"], - "models.speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig",], + "models.speech_to_text": [ + "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Speech2TextConfig", + ], "models.speech_to_text_2": [ "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2Text2Config", @@ -289,9 +298,19 @@ "TransfoXLCorpus", "TransfoXLTokenizer", ], - "models.trocr": ["TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", "TrOCRConfig", "TrOCRProcessor",], - "models.unispeech": ["UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechConfig",], - "models.unispeech_sat": ["UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechSatConfig",], + "models.trocr": [ + "TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", + "TrOCRConfig", + "TrOCRProcessor", + ], + "models.unispeech": [ + "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", + "UniSpeechConfig", + ], + "models.unispeech_sat": [ + "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "UniSpeechSatConfig", + ], "models.vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig", "ViltFeatureExtractor", "ViltProcessor"], "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"], "models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"], @@ -308,7 +327,10 @@ ], "models.wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"], "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"], - "models.wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig",], + "models.wavlm": [ + "WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", + "WavLMConfig", + ], "models.xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"], "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], @@ -804,7 +826,13 @@ ] ) _import_structure["models.clip"].extend( - ["CLIP_PRETRAINED_MODEL_ARCHIVE_LIST", "CLIPModel", "CLIPPreTrainedModel", "CLIPTextModel", "CLIPVisionModel",] + [ + "CLIP_PRETRAINED_MODEL_ARCHIVE_LIST", + "CLIPModel", + "CLIPPreTrainedModel", + "CLIPTextModel", + "CLIPVisionModel", + ] ) _import_structure["models.convbert"].extend( [ @@ -1344,7 +1372,12 @@ ] ) _import_structure["models.swin"].extend( - ["SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", "SwinForImageClassification", "SwinModel", "SwinPreTrainedModel",] + [ + "SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", + "SwinForImageClassification", + "SwinModel", + "SwinPreTrainedModel", + ] ) _import_structure["models.t5"].extend( [ @@ -1420,7 +1453,12 @@ ] ) _import_structure["models.vit"].extend( - ["VIT_PRETRAINED_MODEL_ARCHIVE_LIST", "ViTForImageClassification", "ViTModel", "ViTPreTrainedModel",] + [ + "VIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "ViTForImageClassification", + "ViTModel", + "ViTPreTrainedModel", + ] ) _import_structure["models.vit_mae"].extend( [ @@ -1456,7 +1494,12 @@ ] ) _import_structure["models.xglm"].extend( - ["XGLM_PRETRAINED_MODEL_ARCHIVE_LIST", "XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel",] + [ + "XGLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "XGLMForCausalLM", + "XGLMModel", + "XGLMPreTrainedModel", + ] ) _import_structure["models.xlm"].extend( [ @@ -1790,7 +1833,12 @@ ] ) _import_structure["models.hubert"].extend( - ["TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TFHubertForCTC", "TFHubertModel", "TFHubertPreTrainedModel",] + [ + "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFHubertForCTC", + "TFHubertModel", + 
"TFHubertPreTrainedModel", + ] ) _import_structure["models.layoutlm"].extend( [ @@ -1875,7 +1923,12 @@ ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"] ) _import_structure["models.rag"].extend( - ["TFRagModel", "TFRagPreTrainedModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration",] + [ + "TFRagModel", + "TFRagPreTrainedModel", + "TFRagSequenceForGeneration", + "TFRagTokenForGeneration", + ] ) _import_structure["models.rembert"].extend( [ @@ -1959,7 +2012,11 @@ ) _import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"]) _import_structure["models.vit"].extend( - ["TFViTForImageClassification", "TFViTModel", "TFViTPreTrainedModel",] + [ + "TFViTForImageClassification", + "TFViTModel", + "TFViTPreTrainedModel", + ] ) _import_structure["models.wav2vec2"].extend( [ @@ -2167,7 +2224,11 @@ ) _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) _import_structure["models.marian"].extend( - ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel",] + [ + "FlaxMarianModel", + "FlaxMarianMTModel", + "FlaxMarianPreTrainedModel", + ] ) _import_structure["models.mbart"].extend( [ @@ -2180,7 +2241,11 @@ ) _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) _import_structure["models.pegasus"].extend( - ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel",] + [ + "FlaxPegasusForConditionalGeneration", + "FlaxPegasusModel", + "FlaxPegasusPreTrainedModel", + ] ) _import_structure["models.roberta"].extend( [ @@ -2212,7 +2277,11 @@ ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] ) _import_structure["models.xglm"].extend( - ["FlaxXGLMForCausalLM", "FlaxXGLMModel", "FlaxXGLMPreTrainedModel",] + [ + "FlaxXGLMForCausalLM", + "FlaxXGLMModel", + "FlaxXGLMPreTrainedModel", + ] ) else: from .utils import dummy_flax_objects diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a6769c0345dbc..5e8c8149dd44f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -253,7 +253,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( # Model for Causal Image Modeling mapping - [("imagegpt", "ImageGPTForCausalImageModeling"),] + [ + ("imagegpt", "ImageGPTForCausalImageModeling"), + ] ) MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( @@ -294,7 +296,11 @@ ] ) -MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict([("vision-encoder-decoder", "VisionEncoderDecoderModel"),]) +MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("vision-encoder-decoder", "VisionEncoderDecoderModel"), + ] +) MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ diff --git a/src/transformers/models/poolformer/configuration_poolformer.py b/src/transformers/models/poolformer/configuration_poolformer.py index 2e131a5f31d34..ebff75b4f42c5 100644 --- a/src/transformers/models/poolformer/configuration_poolformer.py +++ b/src/transformers/models/poolformer/configuration_poolformer.py @@ -74,15 +74,15 @@ class PoolFormerConfig(PretrainedConfig): Example: ```python - >>> from transformers import PoolFormerModel, PoolFormerConfig + >>> from transformers import PoolFormerModel, PoolFormerConfig - >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration - >>> configuration = PoolFormerConfig() + >>> # Initializing a PoolFormer 
sail/poolformer_s12 style configuration + >>> configuration = PoolFormerConfig() - >>> # Initializing a model from the sail/poolformer_s12 style configuration - >>> model = PoolFormerModel(configuration) + >>> # Initializing a model from the sail/poolformer_s12 style configuration + >>> model = PoolFormerModel(configuration) - >>> # Accessing the model configuration + >>> # Accessing the model configuration >>> configuration = model.config ``` """ diff --git a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py index 2af22f1200d42..eebc8b0c5e713 100644 --- a/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py +++ b/src/transformers/models/poolformer/convert_poolformer_timm_to_pytorch.py @@ -199,7 +199,10 @@ def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_fold parser = argparse.ArgumentParser() parser.add_argument( - "--model_name", default="poolformer_s12", type=str, help="Name of the model you'd like to convert.", + "--model_name", + default="poolformer_s12", + type=str, + help="Name of the model you'd like to convert.", ) parser.add_argument( "--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)." diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index ddb69264868c5..90a797a7485a2 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -287,7 +287,10 @@ def __init__(self, config): self.block = nn.ModuleList(blocks) def forward( - self, pixel_values, output_hidden_states=False, return_dict=True, + self, + pixel_values, + output_hidden_states=False, + return_dict=True, ): all_hidden_states = () if output_hidden_states else None @@ -307,7 +310,10 @@ def forward( if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) - return PoolFormerModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states,) + return PoolFormerModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) class PoolFormerPreTrainedModel(PreTrainedModel): @@ -391,14 +397,19 @@ def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None raise ValueError("You have to specify pixel_values") encoder_outputs = self.encoder( - pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) sequence_output = encoder_outputs[0] if not return_dict: return (sequence_output, None) + encoder_outputs[1:] - return PoolFormerModelOutput(last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states,) + return PoolFormerModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + ) class PoolFormerFinalPooler(nn.Module): @@ -442,7 +453,11 @@ def __init__(self, config): expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) def forward( - self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None, + self, + pixel_values=None, + labels=None, + output_hidden_states=None, + return_dict=None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -452,7 +467,11 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = 
self.poolformer(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict,) + outputs = self.poolformer( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) sequence_output = outputs[0] @@ -472,4 +491,8 @@ def forward( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return PoolFormerClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,) + return PoolFormerClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) diff --git a/tests/test_modeling_poolformer.py b/tests/test_modeling_poolformer.py index 0ab8df008dc3a..64e9f6755051d 100644 --- a/tests/test_modeling_poolformer.py +++ b/tests/test_modeling_poolformer.py @@ -124,7 +124,14 @@ def prepare_config_and_inputs_for_common(self): @require_torch class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (PoolFormerModel, PoolFormerForImageClassification,) if is_torch_available() else () + all_model_classes = ( + ( + PoolFormerModel, + PoolFormerForImageClassification, + ) + if is_torch_available() + else () + ) test_head_masking = False test_pruning = False From f3233ad1ecc87c9ff67909eb25c3e73b608d344c Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Wed, 9 Feb 2022 23:33:17 +0530 Subject: [PATCH 11/17] Added PoolFormer to _toctree.yml --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 70d2455d0aeaa..0d23eff7ddd39 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -246,6 +246,8 @@ title: Pegasus - local: model_doc/phobert title: PhoBERT + - local: model_doc/poolformer + title: PoolFormer - local: model_doc/prophetnet title: ProphetNet - local: model_doc/qdqbert From d056cbc4ce352814486479a0d9988ce30de5ff3a Mon Sep 17 00:00:00 2001 From: Tanay Mehta Date: Wed, 9 Feb 2022 23:49:44 +0530 Subject: [PATCH 12/17] Updated poolformer doc --- docs/source/model_doc/poolformer.mdx | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/model_doc/poolformer.mdx b/docs/source/model_doc/poolformer.mdx index d4e24158b8ebe..bbed3761ddc2f 100644 --- a/docs/source/model_doc/poolformer.mdx +++ b/docs/source/model_doc/poolformer.mdx @@ -20,9 +20,24 @@ The abstract from the paper is the following: *Transformers have shown great potential in computer vision tasks. A common belief is their attention-based token mixer module contributes most to their competence. However, recent works show the attention-based module in transformers can be replaced by spatial MLPs and the resulted models still perform quite well. Based on this observation, we hypothesize that the general architecture of the transformers, instead of the specific token mixer module, is more essential to the model's performance. To verify this, we deliberately replace the attention module in transformers with an embarrassingly simple spatial pooling operator to conduct only the most basic token mixing. Surprisingly, we observe that the derived model, termed as PoolFormer, achieves competitive performance on multiple computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. 
The effectiveness of PoolFormer verifies our hypothesis and urges us to initiate the concept of "MetaFormer", a general architecture abstracted from transformers without specifying the token mixer. Based on the extensive experiments, we argue that MetaFormer is the key player in achieving superior results for recent transformer and MLP-like models on vision tasks. This work calls for more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Additionally, our proposed PoolFormer could serve as a starting baseline for future MetaFormer architecture design.*

+The figure below illustrates the architecture of PoolFormer. Taken from the [original paper](https://arxiv.org/abs/2111.11418).
+
+
+
 Tips:
 
-
+- PoolFormer has a hierarchical architecture, where instead of Attention, a simple Average Pooling layer is present. All checkpoints of the model can be found on the [hub](https://huggingface.co/models?other=poolformer).
+- One can use [`PoolFormerFeatureExtractor`] to prepare images for the model.
+- Like most models, PoolFormer comes in different sizes, the details of which can be found in the table below.
+
+| **Model variant** | **Depths** | **Hidden sizes** | **Params (M)** | **ImageNet-1k Top 1** |
+| :---------------: | ------------- | ------------------- | :------------: | :-------------------: |
+| s12 | [2, 2, 6, 2] | [64, 128, 320, 512] | 12 | 77.2 |
+| s24 | [4, 4, 12, 4] | [64, 128, 320, 512] | 21 | 80.3 |
+| s36 | [6, 6, 18, 6] | [64, 128, 320, 512] | 31 | 81.4 |
+| m36 | [6, 6, 18, 6] | [96, 192, 384, 768] | 56 | 82.1 |
+| m48 | [8, 8, 24, 8] | [96, 192, 384, 768] | 73 | 82.5 |

This model was contributed by [heytanay]( Date: Thu, 10 Feb 2022 10:52:02 +0530 Subject: [PATCH 13/17] Applied suggested fixes and added on README.md --- README.md | 1 + docs/source/model_doc/poolformer.mdx | 6 ++- .../poolformer/configuration_poolformer.py | 18 ++++---- .../models/poolformer/modeling_poolformer.py | 41 +++++++++++-------- tests/test_modeling_poolformer.py | 13 +---- 5 files changed, 40 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 908d2e690faad..ce9b272bd2dde 100644 --- a/README.md +++ b/README.md @@ -290,6 +290,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. +1.
**[PoolFormer](https://huggingface.co/docs/transformers/model_doc/poolformer)** (from Sea AI Labs) released with the paper [MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418) by Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng. 1. **[ProphetNet](https://huggingface.co/docs/transformers/model_doc/prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. 1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. 1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. diff --git a/docs/source/model_doc/poolformer.mdx b/docs/source/model_doc/poolformer.mdx index bbed3761ddc2f..a3f9a3b7ba11a 100644 --- a/docs/source/model_doc/poolformer.mdx +++ b/docs/source/model_doc/poolformer.mdx @@ -39,19 +39,23 @@ Tips: | m36 | [6, 6, 18, 6] | [96, 192, 384, 768] | 56 | 82.1 | | m48 | [8, 8, 24, 8] | [96, 192, 384, 768] | 73 | 82.5 | -This model was contributed by [heytanay](