Skip to content

Commit

Permalink
Move import of ASTFeatureExtractor under a is_speech_available
Browse files Browse the repository at this point in the history
  • Loading branch information
Niels Rogge authored and Niels Rogge committed Nov 15, 2022
1 parent bae1f3c commit 5b4eea6
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Expand Up @@ -672,6 +672,7 @@
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
else:
_import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor")
_import_structure["models.mctct"].append("MCTCTFeatureExtractor")
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")

Expand Down Expand Up @@ -857,7 +858,6 @@
"ASTForSequenceClassification",
"ASTModel",
"ASTPreTrainedModel",
"ASTFeatureExtractor",
]
)
_import_structure["models.albert"].extend(
Expand Down Expand Up @@ -3760,6 +3760,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_speech_objects import *
else:
from .models.audio_spectrogram_transformer import ASTFeatureExtractor
from .models.mctct import MCTCTFeatureExtractor
from .models.speech_to_text import Speech2TextFeatureExtractor

Expand Down Expand Up @@ -3917,7 +3918,6 @@
# PyTorch model imports
from .models.audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ASTFeatureExtractor,
ASTForSequenceClassification,
ASTModel,
ASTPreTrainedModel,
Expand Down
20 changes: 17 additions & 3 deletions src/transformers/models/audio_spectrogram_transformer/__init__.py
Expand Up @@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available


_import_structure = {
Expand All @@ -33,14 +33,21 @@
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"]
_import_structure["modeling_audio_spectrogram_transformer"] = [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"ASTForSequenceClassification",
"ASTModel",
"ASTPreTrainedModel",
]

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"]

if TYPE_CHECKING:
from .configuration_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
Expand All @@ -53,14 +60,21 @@
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
from .modeling_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ASTForSequenceClassification,
ASTModel,
ASTPreTrainedModel,
)

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor


else:
import sys
Expand Down
7 changes: 0 additions & 7 deletions src/transformers/utils/dummy_pt_objects.py
Expand Up @@ -353,13 +353,6 @@ def load_tf_weights_in_albert(*args, **kwargs):
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class ASTFeatureExtractor(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class ASTForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_speech_objects.py
Expand Up @@ -3,6 +3,13 @@
from ..utils import DummyObject, requires_backends


class ASTFeatureExtractor(metaclass=DummyObject):
_backends = ["speech"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["speech"])


class MCTCTFeatureExtractor(metaclass=DummyObject):
_backends = ["speech"]

Expand Down

0 comments on commit 5b4eea6

Please sign in to comment.