[WIP] Disentangle auto modules from other modeling files (#13023)
* Initial work

* All auto models

* All tf auto models

* All flax auto models

* Tokenizers

* Add feature extractors

* Fix typos

* Fix other typo

* Use the right config

* Remove old mapping names and update logic in AutoTokenizer

* Update check_table

* Fix copies and check_repo script

* Fix last test

* Add back name

* clean up

* Update template

* Update template

* Forgot a )

* Use alternative to fixup

* Fix TF model template

* Address review comments

* Address review comments

* Style
sgugger committed Aug 6, 2021
1 parent 2e40823 commit 9870093
Showing 26 changed files with 1,333 additions and 2,400 deletions.
.github/workflows/model-templates.yml (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ jobs:
      - name: Run style changes
        run: |
          git fetch origin master:master
-         make fixup
+         make style && make quality
      - name: Failure short reports
        if: ${{ always() }}
Makefile (1 change: 0 additions & 1 deletion)
@@ -30,7 +30,6 @@ deps_table_check_updated:
# autogenerating code

autogenerate_code: deps_table_update
-	python utils/class_mapping_update.py

# Check that source code meets quality standards

src/transformers/__init__.py (8 changes: 5 additions & 3 deletions)
@@ -213,6 +213,7 @@
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.marian": ["MarianConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.mmbt": ["MMBTConfig"],
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
@@ -315,7 +316,7 @@
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart"].append("MBart50Tokenizer")
_import_structure["models.mbart50"].append("MBart50Tokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
@@ -358,7 +359,7 @@
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.mbart"].append("MBartTokenizerFast")
_import_structure["models.mbart"].append("MBart50TokenizerFast")
_import_structure["models.mbart50"].append("MBart50TokenizerFast")
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
_import_structure["models.mt5"].append("MT5TokenizerFast")
@@ -2021,7 +2022,8 @@
        from .models.led import LEDTokenizerFast
        from .models.longformer import LongformerTokenizerFast
        from .models.lxmert import LxmertTokenizerFast
-       from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
+       from .models.mbart import MBartTokenizerFast
+       from .models.mbart50 import MBart50TokenizerFast
        from .models.mobilebert import MobileBertTokenizerFast
        from .models.mpnet import MPNetTokenizerFast
        from .models.mt5 import MT5TokenizerFast
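For context, the top-level __init__.py builds a lazy import table: _import_structure maps each submodule path to the names it exports, and dependency-guarded blocks append further entries. A minimal sketch of the pattern these hunks touch (entries abbreviated; the mbart50 module name comes straight from this diff):

_import_structure = {
    "models.mbart": ["MBartConfig"],
    "models.mbart50": [],  # registered empty; tokenizer names are appended when deps are present
}

# With sentencepiece installed, MBart50Tokenizer now registers under its own
# mbart50 module instead of piggybacking on models.mbart:
_import_structure["models.mbart50"].append("MBart50Tokenizer")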
src/transformers/modelcard.py (6 changes: 3 additions & 3 deletions)
@@ -41,9 +41,7 @@
    is_tokenizers_available,
    is_torch_available,
)
- from .training_args import ParallelMode
- from .utils import logging
- from .utils.modeling_auto_mapping import (
+ from .models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
@@ -54,6 +52,8 @@
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
+ from .training_args import ParallelMode
+ from .utils import logging


TASK_MAPPING = {
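The *_MAPPING_NAMES constants imported above are plain OrderedDicts of strings (model type to class name), which is what lets modelcard.py drop the autogenerated utils.modeling_auto_mapping module. A hedged sketch with illustrative entries (the real dicts in modeling_auto.py are far longer):

from collections import OrderedDict

# Illustrative entries only.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
    ]
)

Because the values are strings rather than classes, importing these mappings loads no modeling code.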
src/transformers/models/__init__.py (5 changes: 5 additions & 0 deletions)
@@ -37,6 +37,7 @@
    cpm,
    ctrl,
    deberta,
+   deberta_v2,
    deit,
    detr,
    dialogpt,
@@ -50,6 +51,8 @@
    gpt2,
    gpt_neo,
    herbert,
+   hubert,
+   ibert,
    layoutlm,
    led,
    longformer,
@@ -58,6 +61,7 @@
    m2m_100,
    marian,
    mbart,
+   mbart50,
    megatron_bert,
    mmbt,
    mobilebert,
@@ -82,6 +86,7 @@
    vit,
    wav2vec2,
    xlm,
+   xlm_prophetnet,
    xlm_roberta,
    xlnet,
)
src/transformers/models/auto/auto_factory.py (84 changes: 81 additions & 3 deletions)
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
+ import importlib
+ from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from ...utils import logging
- from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
+ from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


logger = logging.get_logger(__name__)
@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
    from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
    from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    from_config.__doc__ = from_config_docstring
-   from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
+   from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
    cls.from_config = classmethod(from_config)

    if name.startswith("TF"):
@@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
    shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
    from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
    from_pretrained.__doc__ = from_pretrained_docstring
-   from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
+   from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
    cls.from_pretrained = classmethod(from_pretrained)
    return cls

@@ -445,3 +447,79 @@ def get_values(model_mapping):
            result.append(model)

    return result


def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some mappings have entries of the form model_type -> object belonging to another model type; in that case,
    # retry the lookup on the top-level transformers module, where every public object is re-exported.
    transformers_module = importlib.import_module("transformers")
    return getattribute_from_module(transformers_module, attr)
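
A hedged usage sketch of the helper above (the module and attribute names are standard transformers ones, chosen for illustration):

import importlib

bert = importlib.import_module(".bert", "transformers.models")

getattribute_from_module(bert, None)                         # None passes through
getattribute_from_module(bert, "BertModel")                  # resolves to the BertModel class
getattribute_from_module(bert, ("BertModel", "BertConfig"))  # tuples resolve element-wise
# An attribute that is missing from models.bert is retried on the top-level
# transformers module, where public classes are re-exported.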


class _LazyAutoMapping(OrderedDict):
    """
    A mapping from config classes to objects (model or tokenizer classes, for instance) that loads its keys and
    values lazily, only when they are accessed.

    Args:
        config_mapping: The map from model type to config class name.
        model_mapping: The map from model type to model (or tokenizer) class name.
    """

    def __init__(self, config_mapping, model_mapping):
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._modules = {}

    def __getitem__(self, key):
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type not in self._model_mapping:
            raise KeyError(key)
        model_name = self._model_mapping[model_type]
        return self._load_attr_from_module(model_type, model_name)

    def _load_attr_from_module(self, model_type, attr):
        module_name = model_type_to_module_name(model_type)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattribute_from_module(self._modules[module_name], attr)

    def keys(self):
        return [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping.keys()
        ]

    def values(self):
        return [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping.keys()
        ]

    def items(self):
        return [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping.keys()
            if key in self._config_mapping.keys()
        ]

    def __iter__(self):
        return iter(self.keys())

    def __contains__(self, item):
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping
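
Putting the pieces together, a hedged sketch of how a lazy mapping is declared and consumed (the two name dicts below are stand-ins; the real ones live in configuration_auto.py and modeling_auto.py):

from collections import OrderedDict

from transformers import BertConfig
from transformers.models.auto.auto_factory import _LazyAutoMapping

CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])

MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)

assert BertConfig in MODEL_MAPPING     # membership works on class names; no model import
model_cls = MODEL_MAPPING[BertConfig]  # first access imports transformers.models.bert
print(model_cls.__name__)              # BertModel

Only string tables exist at import time; the heavy modeling modules load on first lookup, which is the point of disentangling the auto modules from the other modeling files.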
