time series forecasting model #17965

Merged: 172 commits, Sep 30, 2022

Commits
2ebd575
initial files
kashif Jun 30, 2022
cc0b364
initial model via cli
kashif Jun 30, 2022
cf60b95
typos
kashif Jul 1, 2022
b18d345
make a start on the model config
kashif Jul 1, 2022
eb328f5
ready with configuation
kashif Jul 1, 2022
8e664f4
remove tokenizer ref.
kashif Jul 1, 2022
e078e09
init the transformer
kashif Jul 1, 2022
972cc89
added initial model forward to return dec_output
kashif Jul 1, 2022
cf74608
require gluonts
kashif Jul 3, 2022
c7b8158
update dep. ver table and add as extra
kashif Jul 3, 2022
dba022c
fixed typo
kashif Jul 3, 2022
44c9944
Merge remote-tracking branch 'upstream/main' into time-series
kashif Jul 4, 2022
29bcc72
add type for prediction_length
kashif Jul 4, 2022
6f3f627
use num_time_features
kashif Jul 4, 2022
97fbcf4
use config
kashif Jul 4, 2022
62286aa
more config
kashif Jul 4, 2022
88e071b
typos
kashif Jul 4, 2022
3e01a69
opps another typo
kashif Jul 4, 2022
202dab0
freq can be none
kashif Jul 5, 2022
489d7f1
default via transformation is 1
kashif Jul 5, 2022
d5f6eb3
initial transformations
kashif Jul 5, 2022
302a387
fix imports
kashif Jul 5, 2022
669aeeb
added transform_start_field
kashif Jul 5, 2022
6c339d8
add helper to create pytorch dataloader
kashif Jul 6, 2022
43ffcf1
added inital val and test data loader
kashif Jul 7, 2022
1525a1a
added initial distr head and loss
kashif Jul 7, 2022
3fa6d48
training working
kashif Jul 7, 2022
956f369
remove TimeSeriesTransformerTokenizer
kashif Jul 8, 2022
9322bb1
Update src/transformers/__init__.py
kashif Jul 8, 2022
2babbe9
Update src/transformers/models/time_series_transformer/__init__.py
kashif Jul 8, 2022
1dc5e85
fixed copyright
kashif Jul 11, 2022
7d6732f
removed docs
kashif Jul 11, 2022
e113de5
remove time series tokenizer
kashif Jul 11, 2022
800b452
fixed docs
kashif Jul 12, 2022
16109a7
fix text
kashif Jul 12, 2022
bcff605
fix second
kashif Jul 12, 2022
cce6973
fix default
kashif Jul 12, 2022
216c206
fix order
kashif Jul 13, 2022
8dae3cf
use config directly
kashif Jul 13, 2022
9afbac2
undo change
kashif Jul 13, 2022
2c06f97
fix comment
kashif Jul 13, 2022
050b8b5
fix year
kashif Jul 13, 2022
d086185
fix import
kashif Jul 15, 2022
dc2acd4
add additional arguments for training vs. test
kashif Jul 15, 2022
758b90d
initial greedy inference loop
kashif Jul 15, 2022
c9f8540
fix inference
kashif Jul 15, 2022
97d6743
comment out token inputs to enc dec
kashif Jul 21, 2022
b512c50
Use HF encoder/decoder
kashif Aug 7, 2022
7f332f2
fix inference
kashif Aug 7, 2022
3bce7cc
Use Seq2SeqTSModelOutput output
kashif Aug 9, 2022
3577729
return Seq2SeqTSPredictionOutput
kashif Aug 10, 2022
8f7d803
added default arguments
kashif Aug 10, 2022
c90c126
fix return_dict true
kashif Aug 10, 2022
63d709a
scale is a tensor
kashif Aug 10, 2022
bcd566f
output static_features for inference
kashif Aug 12, 2022
6c6a57b
clean up some unused bits
kashif Aug 12, 2022
5346712
fixed typo
kashif Aug 12, 2022
6ef3a59
set return_dict if none
kashif Aug 12, 2022
ec4ab0f
call model once for both train/predict
kashif Aug 12, 2022
87a8f1b
use cache if future_target is none
kashif Aug 12, 2022
01343f7
initial generate func
kashif Aug 14, 2022
be1f132
generate arguments
kashif Aug 15, 2022
3cbfd7b
future_time_feat is required
kashif Aug 15, 2022
89e9279
return SampleTSPredictionOutput
kashif Aug 17, 2022
686db78
removed unneeded classes
kashif Aug 18, 2022
336fb95
fix when params is none
kashif Aug 18, 2022
4623d2d
fix return dict
kashif Aug 18, 2022
035f016
fix num_attention_heads
kashif Aug 18, 2022
8e2c2fa
fix arguments
kashif Aug 18, 2022
ccdf048
remove unused shift_tokens_right
kashif Aug 18, 2022
9cbb500
add different dropout configs
kashif Aug 19, 2022
996911a
implement FeatureEmbedder, Scaler and weighted_average
kashif Sep 3, 2022
54460c7
remove gluonts dependency
kashif Sep 5, 2022
21dc3cd
fix class names
kashif Sep 5, 2022
9763a3e
avoid _variable names
kashif Sep 5, 2022
3dab2e1
remove gluonts dependency
kashif Sep 6, 2022
27f1349
fix imports
kashif Sep 6, 2022
72f9a27
remove gluonts from configuration
kashif Sep 6, 2022
949b82a
fix docs
kashif Sep 6, 2022
2eaf401
fixed typo
kashif Sep 8, 2022
924f1bd
move utils to examples
kashif Sep 8, 2022
4b955b4
add example requirements
kashif Sep 8, 2022
5dfa7c7
config has no freq
kashif Sep 8, 2022
7e02a57
initial run_ts_no_trainer
kashif Sep 12, 2022
1b166ab
remove from ignore
kashif Sep 22, 2022
de106a9
Merge remote-tracking branch 'upstream/main' into time-series
kashif Sep 22, 2022
727982a
fix output_attentions and removed unsued getters/setters
kashif Sep 22, 2022
84c77ad
removed unsed tests
kashif Sep 22, 2022
06679dd
add dec seq len
kashif Sep 22, 2022
8656275
add test_attention_outputs
kashif Sep 22, 2022
f4465b7
set has_text_modality=False
kashif Sep 22, 2022
93d0289
add config attribute_map
kashif Sep 22, 2022
19f188f
make style
kashif Sep 23, 2022
444ba89
make fix-copies
kashif Sep 23, 2022
d639388
add encoder_outputs to TimeSeriesTransformerForPrediction forward
kashif Sep 23, 2022
8cbe681
Improve docs, add model to README
NielsRogge Sep 23, 2022
c658f03
added test_forward_signature
kashif Sep 23, 2022
13de9c6
More improvements
NielsRogge Sep 23, 2022
649373f
Add more copied from
NielsRogge Sep 23, 2022
08d81f5
Fix README
NielsRogge Sep 23, 2022
390db66
Fix remaining quality issues
NielsRogge Sep 23, 2022
ca37eda
Merge branch 'time-series' into time-series
kashif Sep 23, 2022
ff878eb
Merge pull request #1 from NielsRogge/time-series
kashif Sep 23, 2022
973baf3
updated encoder and decoder
kashif Sep 23, 2022
ea138ff
fix generate
kashif Sep 23, 2022
9076757
output_hidden_states and use_cache are optional
kashif Sep 23, 2022
409a88a
past key_values returned too
kashif Sep 23, 2022
ed6b62a
initialize weights of distribution_output module
kashif Sep 23, 2022
1ef37e4
fixed more tests
kashif Sep 23, 2022
b6b6be3
update test_forward_signature
kashif Sep 26, 2022
0e7d3ed
fix return_dict outputs
kashif Sep 26, 2022
eb3ef8a
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 26, 2022
ab05cf5
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 26, 2022
60a63bf
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 26, 2022
e8663f0
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 26, 2022
4b46121
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 26, 2022
3ab0bfd
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 26, 2022
ac3e8d8
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 26, 2022
9d9ed9a
removed commented out tests
kashif Sep 27, 2022
d631896
added neg. bin and normal output
kashif Sep 27, 2022
1618126
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 27, 2022
70cfbce
move to one line
kashif Sep 27, 2022
678f297
Add docstrings
NielsRogge Sep 27, 2022
4d6bd2b
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 27, 2022
7174391
add try except for assert and raise
kashif Sep 27, 2022
58d368b
try and raise exception
kashif Sep 27, 2022
319011b
fix the documentation formatting
kashif Sep 27, 2022
ecbf682
fix assert call
kashif Sep 27, 2022
74fd969
fix docstring formatting
kashif Sep 27, 2022
21d2ba9
removed input_ids from DOCSTRING
kashif Sep 27, 2022
d995f4b
Update input docstring
NielsRogge Sep 27, 2022
9d8914f
Improve variable names
NielsRogge Sep 27, 2022
e995a3f
Merge branch 'time-series' into time-series
NielsRogge Sep 28, 2022
bbf49a3
Update order of inputs
NielsRogge Sep 28, 2022
1358e79
Merge branch 'time-series' of github.com:NielsRogge/transformers into…
NielsRogge Sep 28, 2022
f44aad5
Improve configuration
NielsRogge Sep 28, 2022
4e1d1f6
Improve variable names
NielsRogge Sep 28, 2022
6ceed7c
Improve docs
NielsRogge Sep 28, 2022
b7ce766
Remove key_length from tests
NielsRogge Sep 28, 2022
bac12b9
Add extra docs
NielsRogge Sep 28, 2022
62194a5
Merge pull request #2 from NielsRogge/time-series
kashif Sep 28, 2022
de88260
initial unittests
kashif Sep 28, 2022
4f858f3
added test_inference_no_head test
kashif Sep 28, 2022
3d05dad
added test_inference_head
kashif Sep 28, 2022
93ff659
add test_seq_to_seq_generation
kashif Sep 28, 2022
2339a53
make style
kashif Sep 28, 2022
4c05158
one line
kashif Sep 28, 2022
0df5859
assert mean prediction
kashif Sep 28, 2022
d31c0de
removed comments
kashif Sep 28, 2022
0827fbe
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 29, 2022
c8b56cc
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 29, 2022
a2043c1
fix order of args
kashif Sep 29, 2022
bc5aaa9
Merge remote-tracking branch 'upstream/main' into time-series
kashif Sep 29, 2022
a895801
make past_observed_mask optional as well
kashif Sep 30, 2022
6840145
added Amazon license header
kashif Sep 30, 2022
2ec916e
updated utils with new fieldnames
kashif Sep 30, 2022
ffb319f
make style
kashif Sep 30, 2022
6649b29
cleanup
kashif Sep 30, 2022
006a802
undo position of past_observed_mask
kashif Sep 30, 2022
fcd48cd
fix import
kashif Sep 30, 2022
4ebc370
typo
kashif Sep 30, 2022
2ab53bc
more typo
kashif Sep 30, 2022
78cd493
rename example files
kashif Sep 30, 2022
4a7f8c9
remove example for now
kashif Sep 30, 2022
cb66357
Update docs/source/en/_toctree.yml
kashif Sep 30, 2022
4fe63fe
Update src/transformers/models/time_series_transformer/configuration_…
kashif Sep 30, 2022
b28a2f9
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 30, 2022
8a3e430
Update src/transformers/models/time_series_transformer/modeling_time_…
kashif Sep 30, 2022
f7b0ab5
Update modeling_time_series_transformer.py
kashif Sep 30, 2022
4fe8164
fixed typo
kashif Sep 30, 2022
48a2be4
fix typo and grammer
kashif Sep 30, 2022
43dd269
fix style
kashif Sep 30, 2022
77 changes: 77 additions & 0 deletions docs/source/en/model_doc/time_series_transformer.mdx
@@ -0,0 +1,77 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# TimeSeriesTransformer

## Overview

The TimeSeriesTransformer model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>. <INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](<https://huggingface.co/<INSERT YOUR HF USERNAME HERE>). The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).

## TimeSeriesTransformerConfig

[[autodoc]] TimeSeriesTransformerConfig


## TimeSeriesTransformerTokenizer

[[autodoc]] TimeSeriesTransformerTokenizer
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary


## TimeSeriesTransformerTokenizerFast

[[autodoc]] TimeSeriesTransformerTokenizerFast


## TimeSeriesTransformerModel

[[autodoc]] TimeSeriesTransformerModel
- forward


## TimeSeriesTransformerForConditionalGeneration

[[autodoc]] TimeSeriesTransformerForConditionalGeneration
- forward


## TimeSeriesTransformerForSequenceClassification

[[autodoc]] TimeSeriesTransformerForSequenceClassification
- forward


## TimeSeriesTransformerForQuestionAnswering

[[autodoc]] TimeSeriesTransformerForQuestionAnswering
- forward


## TimeSeriesTransformerForCausalLM

[[autodoc]] TimeSeriesTransformerForCausalLM
- forward


1 change: 1 addition & 0 deletions setup.cfg
Expand Up @@ -15,6 +15,7 @@ known_third_party =
fire
fugashi
git
gluonts
h5py
matplotlib
nltk
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Expand Up @@ -114,6 +114,7 @@
"ftfy",
"fugashi>=1.0",
"GitPython<3.1.19",
"gluonts>=0.10.0",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.1.0,<1.0",
"importlib_metadata",
Expand Down Expand Up @@ -277,6 +278,8 @@ def run(self):
extras["vision"] = deps_list("Pillow")
extras["timm"] = deps_list("timm")
extras["codecarbon"] = deps_list("codecarbon")
extras["gluonts"] = deps_list("gluonts")


Contributor review comment: (nit) normally this file shouldn't be touched by the PR

extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (
Expand Down Expand Up @@ -319,6 +322,7 @@ def run(self):
+ extras["timm"]
+ extras["codecarbon"]
+ extras["accelerate"]
+ extras["gluonts"]
)

# Might need to add doc-builder and some specific deps in the future
Expand All @@ -343,6 +347,7 @@ def run(self):
+ extras["sklearn"]
+ extras["modelcreation"]
+ extras["onnxruntime"]
+ extras["gluonts"]
)
extras["dev-tensorflow"] = (
extras["testing"]
Expand Down
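The setup.py hunks above register `gluonts` in a pinned-dependency table, expose it as an installable extra, and fold it into the composed `dev` extras. A minimal sketch of that pattern (the `deps` entries here are illustrative, and `deps_list` is simplified from the real helper):

```python
# Pinned-deps table: one canonical version specifier per package.
deps = {
    "gluonts": "gluonts>=0.10.0",
    "torch": "torch>=1.0",
}

def deps_list(*pkgs):
    # Look up each package's pinned specifier in the deps table,
    # so every extra reuses the same pins.
    return [deps[p] for p in pkgs]

extras = {}
extras["gluonts"] = deps_list("gluonts")
extras["torch"] = deps_list("torch")
# Composed extra: development pulls in the individual extras.
extras["dev"] = extras["torch"] + extras["gluonts"]
```

This is why the PR touches several places at once: the table (`dependency_versions_table.py` / setup.py `deps`), the standalone `extras["gluonts"]`, and each composed extra (`dev`, `dev-torch`) that should include it.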
39 changes: 37 additions & 2 deletions src/transformers/__init__.py
Expand Up @@ -308,6 +308,11 @@
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
"models.tapex": ["TapexTokenizer"],
"models.time_series_transformer": [
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig",
"TimeSeriesTransformerTokenizer",
],
"models.trajectory_transformer": [
"TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TrajectoryTransformerConfig",
Expand Down Expand Up @@ -505,6 +510,7 @@
]
else:
# Fast tokenizers structure
_import_structure["models.time_series_transformer"].append("TimeSeriesTransformerTokenizerFast")
_import_structure["models.albert"].append("AlbertTokenizerFast")
_import_structure["models.bart"].append("BartTokenizerFast")
_import_structure["models.barthez"].append("BarthezTokenizerFast")
Expand Down Expand Up @@ -750,6 +756,19 @@
_import_structure["modeling_utils"] = ["PreTrainedModel"]

# PyTorch models structure

_import_structure["models.time_series_transformer"].extend(
[
"TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TimeSeriesTransformerForCausalLM",
"TimeSeriesTransformerForConditionalGeneration",
"TimeSeriesTransformerForQuestionAnswering",
"TimeSeriesTransformerForSequenceClassification",
"TimeSeriesTransformerForPrediction",
"TimeSeriesTransformerModel",
"TimeSeriesTransformerPreTrainedModel",
]
)
_import_structure["models.albert"].extend(
[
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -3013,6 +3032,11 @@
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
from .models.tapex import TapexTokenizer
from .models.time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig,
TimeSeriesTransformerTokenizer,
)
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TrajectoryTransformerConfig,
Expand Down Expand Up @@ -3235,6 +3259,7 @@
from .models.splinter import SplinterTokenizerFast
from .models.squeezebert import SqueezeBertTokenizerFast
from .models.t5 import T5TokenizerFast
from .models.time_series_transformer import TimeSeriesTransformerTokenizerFast
from .models.xglm import XGLMTokenizerFast
from .models.xlm_roberta import XLMRobertaTokenizerFast
from .models.xlnet import XLNetTokenizerFast
Expand Down Expand Up @@ -3385,8 +3410,6 @@
)
from .generation_utils import top_k_top_p_filtering
from .modeling_utils import PreTrainedModel

# PyTorch model imports
from .models.albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AlbertForMaskedLM,
Expand Down Expand Up @@ -4198,6 +4221,18 @@
T5PreTrainedModel,
load_tf_weights_in_t5,
)

# PyTorch model imports
from .models.time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimeSeriesTransformerForCausalLM,
TimeSeriesTransformerForConditionalGeneration,
TimeSeriesTransformerForQuestionAnswering,
TimeSeriesTransformerForSequenceClassification,
TimeSeriesTransformerForPrediction,
TimeSeriesTransformerModel,
TimeSeriesTransformerPreTrainedModel,
)
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TrajectoryTransformerModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
Expand Up @@ -20,6 +20,7 @@
"ftfy": "ftfy",
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"gluonts": "gluonts>=0.10.0",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Expand Up @@ -128,6 +128,7 @@
t5,
tapas,
tapex,
time_series_transformer,
trajectory_transformer,
transfo_xl,
trocr,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Expand Up @@ -29,6 +29,7 @@
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("time_series_transformer", "TimeSeriesTransformerConfig"),
("albert", "AlbertConfig"),
("bart", "BartConfig"),
("beit", "BeitConfig"),
Expand Down Expand Up @@ -153,6 +154,7 @@
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here)
("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
Expand Down Expand Up @@ -262,6 +264,7 @@
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("time_series_transformer", "TimeSeriesTransformer"),
("albert", "ALBERT"),
("bart", "BART"),
("barthez", "BARThez"),
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Expand Up @@ -28,6 +28,7 @@
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("time_series_transformer", "TimeSeriesTransformerModel"),
("albert", "AlbertModel"),
("bart", "BartModel"),
("beit", "BeitModel"),
Expand Down Expand Up @@ -197,6 +198,8 @@
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping

("time_series_transformer", "TimeSeriesTransformerForConditionalGeneration"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForMaskedLM"),
Expand Down Expand Up @@ -259,6 +262,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("time_series_transformer", "TimeSeriesTransformerForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
Expand Down Expand Up @@ -434,6 +438,8 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping

("time_series_transformer", "TimeSeriesTransformerForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
Expand Down Expand Up @@ -465,6 +471,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("time_series_transformer", "TimeSeriesTransformerForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("bert", "BertForSequenceClassification"),
Expand Down Expand Up @@ -521,6 +528,7 @@
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("time_series_transformer", "TimeSeriesTransformerForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
Expand Down
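The `modeling_auto.py` hunks above insert `time_series_transformer` entries into the `OrderedDict` name mappings that the Auto classes consult. A simplified sketch of that lookup pattern (`resolve_model_class` is a hypothetical helper, not the actual Auto-class machinery, which additionally lazy-imports the class object):

```python
from collections import OrderedDict

# Auto-mapping pattern: a config's model_type keys into a class name.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("albert", "AlbertModel"),
        ("bart", "BartModel"),
    ]
)

def resolve_model_class(model_type):
    # AutoModel-style lookup keyed on config.model_type
    try:
        return MODEL_MAPPING_NAMES[model_type]
    except KeyError:
        raise ValueError(f"Unrecognized model type: {model_type}")
```

One table per task head (base model, causal LM, seq2seq LM, sequence classification, question answering) is why the diff adds the same key to several mappings.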
85 changes: 85 additions & 0 deletions src/transformers/models/time_series_transformer/__init__.py
@@ -0,0 +1,85 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...utils import _LazyModule, OptionalDependencyNotAvailable, is_tokenizers_available
from ...utils import is_torch_available


_import_structure = {
"configuration_time_series_transformer": [
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig",
],
"tokenization_time_series_transformer": ["TimeSeriesTransformerTokenizer"],
}

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_time_series_transformer_fast"] = ["TimeSeriesTransformerTokenizerFast"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_time_series_transformer"] = [
"TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TimeSeriesTransformerForConditionalGeneration",
"TimeSeriesTransformerForQuestionAnswering",
"TimeSeriesTransformerForSequenceClassification",
"TimeSeriesTransformerForCausalLM",
"TimeSeriesTransformerForPrediction",
"TimeSeriesTransformerModel",
"TimeSeriesTransformerPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig,
)

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimeSeriesTransformerForConditionalGeneration,
TimeSeriesTransformerForCausalLM,
TimeSeriesTransformerForQuestionAnswering,
TimeSeriesTransformerForSequenceClassification,
TimeSeriesTransformerForPrediction,
TimeSeriesTransformerModel,
TimeSeriesTransformerPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
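The package `__init__.py` above ends by swapping itself for a `_LazyModule`, so heavy submodules (e.g. the torch-dependent modeling file) are only imported when one of their attributes is first accessed. A stripped-down sketch of that idea (this `LazyModule` is a simplified illustration, not the actual Hugging Face `_LazyModule` implementation):

```python
import importlib
import types

class LazyModule(types.ModuleType):
    """Sketch of the lazy-import pattern: each exported attribute is
    mapped to the submodule that defines it, and the submodule is only
    imported on first access of that attribute."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # invert {submodule: [attrs]} into {attr: submodule}
        self._attr_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }
        self.__all__ = list(self._attr_to_module)

    def __getattr__(self, attr):
        # only called when normal attribute lookup fails
        if attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(self._attr_to_module[attr])
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so later accesses skip __getattr__
        return value
```

Using the standard library as a stand-in submodule, `LazyModule("demo", {"math": ["sqrt"]})` imports `math` only when `.sqrt` is first touched; the `TYPE_CHECKING` branch in the real file exists so static type checkers still see the eager imports.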