Skip to content

Commit

Permalink
Add models to main init
Browse files — browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Jul 13, 2021
1 parent 773b1f6 commit 984a3e2
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 22 deletions.
40 changes: 18 additions & 22 deletions src/transformers/__init__.py
Expand Up @@ -43,7 +43,6 @@
from . import dependency_versions_check
from .file_utils import (
_LazyModule,
is_detectron2_available,
is_flax_available,
is_sentencepiece_available,
is_speech_available,
Expand Down Expand Up @@ -99,7 +98,6 @@
"cached_path",
"is_apex_available",
"is_datasets_available",
"is_detectron2_available",
"is_faiss_available",
"is_flax_available",
"is_psutil_available",
Expand Down Expand Up @@ -442,25 +440,6 @@
name for name in dir(dummy_timm_objects) if not name.startswith("_")
]

# Detectron2-backed objects
if is_detectron2_available():
_import_structure["models.layoutlmv2"].extend(
[
"LAYOUTLM_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMv2ForQuestionAnswering",
"LayoutLMv2ForSequenceClassification",
"LayoutLMv2ForTokenClassification",
"LayoutLMv2Model",
"LayoutLMv2PreTrainedModel",
]
)
else:
from .utils import dummy_detectron2_objects

_import_structure["utils.dummy_detectron2_objects"] = [
name for name in dir(dummy_detectron2_objects) if not name.startswith("_")
]

# PyTorch-backed objects
if is_torch_available():
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
Expand Down Expand Up @@ -850,6 +829,16 @@
"LayoutLMPreTrainedModel",
]
)
_import_structure["models.layoutlmv2"].extend(
[
"LAYOUTLM_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMv2ForQuestionAnswering",
"LayoutLMv2ForSequenceClassification",
"LayoutLMv2ForTokenClassification",
"LayoutLMv2Model",
"LayoutLMv2PreTrainedModel",
]
)
_import_structure["models.led"].extend(
[
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1750,7 +1739,6 @@
cached_path,
is_apex_available,
is_datasets_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
is_psutil_available,
Expand Down Expand Up @@ -2387,6 +2375,14 @@
LayoutLMModel,
LayoutLMPreTrainedModel,
)
from .models.layoutlmv2 import (
LAYOUTLM_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMv2ForQuestionAnswering,
LayoutLMv2ForSequenceClassification,
LayoutLMv2ForTokenClassification,
LayoutLMv2Model,
LayoutLMv2PreTrainedModel,
)
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
Expand Down
48 changes: 48 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Expand Up @@ -1960,6 +1960,54 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


# Dummy stand-in for the archive list when torch is unavailable; the real list
# is defined in the torch-backed models.layoutlmv2 module.
LAYOUTLM_V2_PRETRAINED_MODEL_ARCHIVE_LIST = None


class LayoutLMv2ForQuestionAnswering:
    """Dummy placeholder exposed when the "torch" backend is not installed.

    Both construction and ``from_pretrained`` defer to ``requires_backends``,
    which presumably raises an informative error naming the missing backend —
    confirm against its definition in ``file_utils``.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LayoutLMv2ForSequenceClassification:
    """Dummy placeholder exposed when the "torch" backend is not installed.

    Both construction and ``from_pretrained`` defer to ``requires_backends``,
    which presumably raises an informative error naming the missing backend —
    confirm against its definition in ``file_utils``.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LayoutLMv2ForTokenClassification:
    """Dummy placeholder exposed when the "torch" backend is not installed.

    Both construction and ``from_pretrained`` defer to ``requires_backends``,
    which presumably raises an informative error naming the missing backend —
    confirm against its definition in ``file_utils``.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LayoutLMv2Model:
    """Dummy placeholder exposed when the "torch" backend is not installed.

    Both construction and ``from_pretrained`` defer to ``requires_backends``,
    which presumably raises an informative error naming the missing backend —
    confirm against its definition in ``file_utils``.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LayoutLMv2PreTrainedModel:
    """Dummy placeholder exposed when the "torch" backend is not installed.

    Both construction and ``from_pretrained`` defer to ``requires_backends``,
    which presumably raises an informative error naming the missing backend —
    confirm against its definition in ``file_utils``.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


# Dummy stand-in for the archive list when torch is unavailable; the real list
# is defined in the torch-backed models.led module.
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None


Expand Down

0 comments on commit 984a3e2

Please sign in to comment.