Skip to content

Commit

Permalink
Replace dummies by requires_backends
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Jul 13, 2021
1 parent c3fd865 commit d655367
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/layoutlmv2/__init__.py
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_detectron2_available, is_tokenizers_available
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
Expand All @@ -29,7 +29,7 @@
if is_tokenizers_available():
_import_structure["tokenization_layoutlmv2_fast"] = ["LayoutLMv2TokenizerFast"]

if is_detectron2_available():
if is_torch_available():
_import_structure["modeling_layoutlmv2"] = [
"LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMv2ForQuestionAnswering",
Expand All @@ -47,7 +47,7 @@
if is_tokenizers_available():
from .tokenization_layoutlmv2_fast import LayoutLMv2TokenizerFast

if is_detectron2_available():
if is_torch_available():
from .modeling_layoutlmv2 import (
LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMv2ForQuestionAnswering,
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
Expand Up @@ -28,6 +28,7 @@
add_start_docstrings_to_model_forward,
is_detectron2_available,
replace_return_docstrings,
requires_backends,
)
from ...modeling_outputs import (
BaseModelOutput,
Expand All @@ -41,6 +42,7 @@
from .configuration_layoutlmv2 import LayoutLMv2Config


# soft dependency
if is_detectron2_available():
import detectron2
from detectron2.modeling import META_ARCH_REGISTRY
Expand Down Expand Up @@ -683,7 +685,8 @@ def forward(self, hidden_states):
)
class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
def __init__(self, config):
super(LayoutLMv2Model, self).__init__(config)
requires_backends(self, "detectron2")
super().__init__(config)
self.config = config
self.has_visual_segment_embedding = config.has_visual_segment_embedding
self.embeddings = LayoutLMv2Embeddings(config)
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
Expand Up @@ -31,6 +31,7 @@
from .deepspeed import is_deepspeed_available
from .file_utils import (
is_datasets_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
is_keras2onnx_available,
Expand Down Expand Up @@ -456,6 +457,14 @@ def require_datasets(test_case):
return test_case


def require_detectron2(test_case):
    """
    Decorator marking a test that requires detectron2.

    The decorated test is skipped (with a "test requires `detectron2`" reason)
    when detectron2 is not installed; otherwise it is left untouched.
    """
    # skipUnless is the stdlib idiom for a conditional skip: it returns the
    # skip decorator when the condition is false and the identity otherwise,
    # which is exactly what the manual if/else spelled out.
    return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)


def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not is_faiss_available():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modeling_layoutlmv2.py
Expand Up @@ -18,7 +18,7 @@
import unittest

from transformers.file_utils import is_detectron2_available, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_detectron2, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
Expand All @@ -36,7 +36,6 @@
)
from transformers.models.layoutlmv2.modeling_layoutlmv2 import LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST


if is_detectron2_available():
from detectron2.structures.image_list import ImageList

Expand Down Expand Up @@ -247,6 +246,7 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
@require_detectron2
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):

test_pruning = False
Expand Down

0 comments on commit d655367

Please sign in to comment.