Address comments from code review
Niels Rogge authored and Niels Rogge committed May 17, 2022
1 parent 41bfc40 commit fe257cf
Showing 11 changed files with 43 additions and 59 deletions.
7 changes: 6 additions & 1 deletion docs/source/en/index.mdx
@@ -50,6 +50,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
### Supported models

<!--This list is updated automatically from the README with _make fix-copies_. Do not update manually! -->

1. **[ALBERT](model_doc/albert)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
1. **[BART](model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BARThez](model_doc/barthez)** (from École polytechnique) released with the paper [BARThez: a Skilled Pretrained French Sequence-to-Sequence Model](https://arxiv.org/abs/2010.12321) by Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis.
@@ -164,13 +165,16 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
1. **[XLS-R](model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[YOLOS](model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
1. **[YOSO](model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.


### Supported frameworks

The table below represents the current support in the library for each of those models, whether they have a Python
tokenizer (called "slow"), a "fast" tokenizer backed by the 🤗 Tokenizers library, and whether they have support in Jax (via
Flax), PyTorch, and/or TensorFlow.

<!--This table is updated automatically from the auto modules with _make fix-copies_. Do not update manually!-->

| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
|:---------------------------:|:--------------:|:--------------:|:---------------:|:------------------:|:------------:|
| ALBERT | | | | | |
@@ -279,4 +283,5 @@ Flax), PyTorch, and/or TensorFlow.
| XLNet | | | | | |
| YOLOS | | | | | |
| YOSO | | | | | |
<!-- End table-->

<!-- End table-->
10 changes: 3 additions & 7 deletions docs/source/en/model_doc/cvt.mdx
@@ -12,13 +12,10 @@ specific language governing permissions and limitations under the License.

# Convolutional Vision Transformer (CvT)


## Overview

The CvT model was proposed in [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan and Lei Zhang. The Convolutional vision Transformer (CvT) improves the [Vision Transformer (ViT)](vit) in performance and efficiency by introducing convolutions into ViT to yield the best of both designs.



The abstract from the paper is the following:

*We present in this paper a new architecture, named Convolutional vision Transformer (CvT), that improves Vision Transformer (ViT)
@@ -34,10 +31,9 @@ a crucial component in existing Vision Transformers, can be safely removed in ou

Tips:

- CvT models are regular Vision Transformers, but trained with convolutions. They outperform the [original model (ViT)](vit) when fine-tuned on ImageNet-1K and CIFAR-100. You can check out demo notebooks regarding inference as well as fine-tuning on custom data [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer) (you can just replace [`ViTFeatureExtractor`] by [`AutoFeatureExtractor`] and [`ViTForImageClassification`] by [`CvtForImageClassification`]).
- The available checkpoints are either (1) pre-trained on [ImageNet-22k](http://www.image-net.org/) (a collection of

14 million images and 22k classes) only, (2) also fine-tuned on ImageNet-22k or (3) also fine-tuned on [ImageNet-1k](http://www.image-net.org/challenges/LSVRC/2012/) (also referred to as ILSVRC 2012, a collection of 1.3 million
- CvT models are regular Vision Transformers, but trained with convolutions. They outperform the [original model (ViT)](vit) when fine-tuned on ImageNet-1K and CIFAR-100.
- You can check out demo notebooks regarding inference as well as fine-tuning on custom data [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer) (you can just replace [`ViTFeatureExtractor`] by [`AutoFeatureExtractor`] and [`ViTForImageClassification`] by [`CvtForImageClassification`]).
- The available checkpoints are either (1) pre-trained on [ImageNet-22k](http://www.image-net.org/) (a collection of 14 million images and 22k classes) only, (2) also fine-tuned on ImageNet-22k or (3) also fine-tuned on [ImageNet-1k](http://www.image-net.org/challenges/LSVRC/2012/) (also referred to as ILSVRC 2012, a collection of 1.3 million
images and 1,000 classes).

This model was contributed by [anugunj](https://huggingface.co/anugunj). The original code can be found [here](https://github.com/microsoft/CvT).
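As a rough illustration of the tip above about swapping in the CvT classes, here is a minimal inference sketch; it assumes the `microsoft/cvt-13` checkpoint referenced elsewhere in this diff and an arbitrary example image URL:

```python
import requests
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, CvtForImageClassification

# Assumed example image (two cats from COCO); any RGB image works the same way
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
model = CvtForImageClassification.from_pretrained("microsoft/cvt-13")

# Preprocess the image and run a forward pass without tracking gradients
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# The classification head of microsoft/cvt-13 is fine-tuned on ImageNet-1k (1,000 classes)
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```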
4 changes: 3 additions & 1 deletion docs/source/en/serialization.mdx
@@ -44,6 +44,7 @@ and are designed to be easily extendable to other architectures.
Ready-made configurations include the following architectures:

<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->

- ALBERT
- BART
- BEiT
@@ -76,6 +77,7 @@ Ready-made configurations include the following architectures:
- ViT
- XLM-RoBERTa
- XLM-RoBERTa-XL

In the next two sections, we'll show you how to:

* Export a supported model using the `transformers.onnx` package.
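A rough sketch of that export path, assuming the programmatic `transformers.onnx.export` helper and the ready-made `DistilBertOnnxConfig` available in this version of the library; the checkpoint name is only an example:

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export

checkpoint = "distilbert-base-uncased"  # assumed example checkpoint
model = AutoModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Ready-made ONNX configuration for one of the supported architectures listed above
onnx_config = DistilBertOnnxConfig(model.config)

# Export and get back the names of the ONNX inputs/outputs that were matched
onnx_path = Path("model.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```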
@@ -665,4 +667,4 @@ torch.neuron.trace(model, [token_tensor, segments_tensors])
This change enables Neuron SDK to trace the model and optimize it to run in Inf1 instances.

To learn more about AWS Neuron SDK features, tools, example tutorials and latest updates,
please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).
please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).
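For context, a minimal end-to-end sketch of the `torch.neuron.trace` call shown above, assuming the `torch-neuron` package from the AWS Neuron SDK is installed; the BERT checkpoint and the sample sentence are example inputs only:

```python
import torch
import torch.neuron  # provided by the AWS Neuron SDK (torch-neuron package)

from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# torchscript=True makes the model return tuples, which tracing expects
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

encoded = tokenizer("A sample sentence for tracing.", return_tensors="pt")
token_tensor = encoded["input_ids"]
segments_tensors = encoded["token_type_ids"]

# Same call as in the documentation snippet above; compiles the graph for Inf1 instances
neuron_model = torch.neuron.trace(model, [token_tensor, segments_tensors])
```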
2 changes: 1 addition & 1 deletion src/transformers/models/auto/configuration_auto.py
@@ -734,4 +734,4 @@ def register(model_type, config):
f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
"match!"
)
CONFIG_MAPPING.register(model_type, config)
CONFIG_MAPPING.register(model_type, config)
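For reference, a small sketch of how the `register` API checked above is used through `AutoConfig`; the `my-new-model` type and `MyNewModelConfig` class are hypothetical:

```python
from transformers import AutoConfig, PretrainedConfig


class MyNewModelConfig(PretrainedConfig):
    # model_type must match the first argument passed to register,
    # otherwise the check above raises a ValueError
    model_type = "my-new-model"


AutoConfig.register("my-new-model", MyNewModelConfig)

# The registered type is now resolvable through the auto API
config = AutoConfig.for_model("my-new-model")
print(type(config).__name__)  # MyNewModelConfig
```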
1 change: 1 addition & 0 deletions src/transformers/models/auto/feature_extraction_auto.py
@@ -40,6 +40,7 @@
("beit", "BeitFeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("cvt", "ConvNextFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
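The new `("cvt", "ConvNextFeatureExtractor")` entry above means CvT checkpoints reuse the ConvNeXT preprocessing. A quick sketch of the effect, assuming the `microsoft/cvt-13` checkpoint:

```python
from transformers import AutoFeatureExtractor

# Resolves through the mapping above; CvT shares its image preprocessing with ConvNeXT
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
print(type(feature_extractor).__name__)  # ConvNextFeatureExtractor
```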
2 changes: 1 addition & 1 deletion src/transformers/models/cvt/__init__.py
@@ -2,7 +2,7 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
27 changes: 8 additions & 19 deletions src/transformers/models/cvt/configuration_cvt.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 AnugunjNaman and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Cvt model configuration"""
""" CvT model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -28,22 +28,17 @@

class CvtConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a Cvt model
This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Cvt
defaults will yield a similar configuration to that of the CvT
[microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_stages (`int`, *optional*, defaults to 3):
The number of encoder blocks (i.e. stages in the Mix Transformer encoder).
patch_sizes (`List[int]`, *optional*, defaults to [7, 3, 3]):
The kernel size of each encoder's patch embedding.
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2]):
@@ -68,11 +63,7 @@ class CvtConfig(PretrainedConfig):
qkv_bias (`List[bool]`, *optional*, defaults to [True, True, True]]):
The bias bool for query, key and value in attentions
cls_token (`List[bool]`, *optional*, defaults to [False, False, True]]):
The bool for classification token
pos_embed (`List[bool]`, *optional*, defaults to [False, False, True]]):
The bool for position embeddings
cls_token (`List[bool]`, *optional*, defaults to [False, False, True]]):
The bool for cls_token
Whether or not to add a classification token to the output of each of the last 3 stages.
qkv_projection_method (`List[string]`, *optional*, defaults to 'dw_bn', 'dw_bn', 'dw_bn']]):
The projection method for query, key and value Default is depth-wise convolutions with batch norm. For
Linear projection use "avg".
@@ -90,7 +81,9 @@ class CvtConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
Example:
Example:
```python
>>> from transformers import CvtModel, CvtConfig
@@ -108,7 +101,6 @@ class CvtConfig(PretrainedConfig):

def __init__(
self,
image_size=224,
num_channels=3,
patch_sizes=[7, 3, 3],
patch_stride=[4, 2, 2],
@@ -122,7 +114,6 @@ def __init__(
drop_path_rate=[0.0, 0.0, 0.1],
qkv_bias=[True, True, True],
cls_token=[False, False, True],
pos_embed=[False, False, False],
qkv_projection_method=["dw_bn", "dw_bn", "dw_bn"],
kernel_qkv=[3, 3, 3],
padding_kv=[1, 1, 1],
@@ -134,7 +125,6 @@ def __init__(
**kwargs
):
super().__init__(**kwargs)
self.image_size = image_size
self.num_channels = num_channels
self.patch_sizes = patch_sizes
self.patch_stride = patch_stride
@@ -148,7 +138,6 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.qkv_bias = qkv_bias
self.cls_token = cls_token
self.pos_embed = pos_embed
self.qkv_projection_method = qkv_projection_method
self.kernel_qkv = kernel_qkv
self.padding_kv = padding_kv
src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Cvt checkpoints from the original repository.
"""Convert CvT checkpoints from the original repository.
URL: https://github.com/microsoft/CvT"""

@@ -294,17 +294,14 @@ def convert_cvt_checkpoint(cvt_file, pytorch_dump_folder):

# For depth size 13 (13 = 1+2+10)
if cvt_file.rsplit("/", 1)[-1][4:6] == "13":
config.image_size = int(cvt_file.rsplit("/", 1)[-1][7:10])
config.depth = [1, 2, 10]

# For depth size 21 (21 = 1+4+16)
elif cvt_file.rsplit("/", 1)[-1][4:6] == "21":
config.image_size = int(cvt_file.rsplit("/", 1)[-1][7:10])
config.depth = [1, 4, 16]

# For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 + 20)
else:
config.image_size = 384
config.depth = [2, 2, 20]
config.num_heads = [3, 12, 16]
config.embed_dim = [192, 768, 1024]
39 changes: 16 additions & 23 deletions src/transformers/models/cvt/modeling_cvt.py
@@ -38,21 +38,21 @@
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"

# Base docstring
_CHECKPOINT_FOR_DOC = "anugunj/cvt-13"
_CHECKPOINT_FOR_DOC = "microsoft/cvt-13"
_EXPECTED_OUTPUT_SHAPE = [1, 384, 14, 14]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "anugunj/cvt-13"
_IMAGE_CLASS_CHECKPOINT = "microsoft/cvt-13"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


CVT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"anugunj/cvt-13",
"anugunj/cvt-13-384-1k",
"anugunj/cvt-13-384-22k",
"anugunj/cvt-21",
"anugunj/cvt-21-384-1k",
"anugunj/cvt-21-384-22k",
"microsoft/cvt-13",
"microsoft/cvt-13-384-1k",
"microsoft/cvt-13-384-22k",
"microsoft/cvt-21",
"microsoft/cvt-21-384-1k",
"microsoft/cvt-21-384-22k",
# See all Cvt models at https://huggingface.co/models?filter=cvt
]

@@ -78,15 +78,7 @@ class BaseModelOutputWithCLSToken(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None


# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)


# Stochastic depth implementation
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
# Copied from transformers.models.convnext.modeling_convnext.drop_path
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
@@ -105,14 +97,15 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
return output


# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath
class CvtDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
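Since the body of `drop_path` is collapsed in this view, here is a rough standalone sketch of the stochastic depth operation it implements, based on the timm implementation the original comment pointed to (an approximation, not a verbatim copy of the file):

```python
import torch


def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Zero out whole samples of a residual branch with probability drop_prob."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcast over all remaining dimensions
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
    # Rescale survivors so the expected value of the output matches the input
    return x.div(keep_prob) * random_tensor
```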


@@ -123,7 +116,7 @@ class CvtEmbeddings(nn.Module):

def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate):
super().__init__()
self.convolution_embeddings = ConvEmbeddings(
self.convolution_embeddings = CvtConvEmbeddings(
patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding
)
self.dropout = nn.Dropout(dropout_rate)
@@ -134,14 +127,14 @@ def forward(self, pixel_values):
return hidden_state


class ConvEmbeddings(nn.Module):
class CvtConvEmbeddings(nn.Module):
"""
Image to Conv Embedding.
"""

def __init__(self, patch_size, num_channels, embed_dim, stride, padding):
super().__init__()
patch_size = to_2tuple(patch_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
self.patch_size = patch_size
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
self.normalization = nn.LayerNorm(embed_dim)
@@ -468,7 +461,7 @@ def __init__(
dropout_rate=config.drop_rate[self.stage],
)

dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])]
drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])]

self.layers = nn.Sequential(
*[
@@ -484,7 +477,7 @@ def __init__(
qkv_bias=config.qkv_bias[self.stage],
attention_drop_rate=config.attention_drop_rate[self.stage],
drop_rate=config.drop_rate[self.stage],
drop_path_rate=dpr[self.stage],
drop_path_rate=drop_path_rates[self.stage],
mlp_ratio=config.mlp_ratio[self.stage],
with_cls_token=config.cls_token[self.stage],
)
4 changes: 2 additions & 2 deletions tests/models/cvt/test_modeling_cvt.py
@@ -31,7 +31,7 @@
import torch

from transformers import CvtForImageClassification, CvtModel
from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
@@ -120,7 +120,7 @@ def create_and_check_model(self, config, pixel_values, labels):
model.to(torch_device)
model.eval()
result = model(pixel_values)
image_size = to_2tuple(self.image_size)
image_size = (self.image_size, self.image_size)
height, width = image_size[0], image_size[1]
for i in range(len(self.depth)):
height = floor(((height + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1)
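As a quick sanity check of the spatial-size recurrence used above, here is the first stage worked out by hand; the values (224 input resolution, patch size 7, padding 2, stride 4) are taken from the default CvT configuration and are assumptions for this example:

```python
from math import floor

height = 224  # assumed input resolution
patch_size, padding, stride = 7, 2, 4  # assumed first-stage defaults

# floor((224 + 2 * 2 - 7) / 4 + 1) = floor(56.25) = 56
height = floor(((height + 2 * padding - patch_size) / stride) + 1)
print(height)  # 56
```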
1 change: 1 addition & 0 deletions utils/documentation_tests.txt
@@ -21,6 +21,7 @@ src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/convnext/modeling_convnext.py
src/transformers/models/ctrl/modeling_ctrl.py
src/transformers/models/cvt/modeling_cvt.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deit/modeling_deit.py