[ENH] pytorch forecasting adapter with Global Forecasting API #6228
base: main
Conversation
Just a general comment: I would propose to split this into multiple PRs, which would make it easier to review. One PR for the pytorch-forecasting adapter (first PR), and a second PR that introduces the global forecasting.
just commenting - very interesting!
I suppose a separate class will give us the opportunity to develop this capability. Ultimately, we may decide to merge it into BaseForecaster, or not.
Yeah, splitting this into 2 PRs would make the workflow clearer. But I am taking pytorch-forecasting as an experiment to try the global forecasting API design, and ultimately DL models will need the global forecasting API anyway, so would it be more convenient to have a single PR?
Good question, @Xinyu-Wu-0000 - as long as this is experimental, it's up to you. It's always good to have one or two examples working for new API development, so it makes sense to have examples in the PR. Although there might be substantial challenges in isolation coming from pytorch-forecasting which do not have to do with the global forecasting extension (you already list many of the serious ones, e.g., loader, prediction object), so I wonder if there is a simpler example to develop around. Either way, that's not a strict requirement, as long as you are working in an exploratory sense.
Maybe NeuralForecast could be a simpler example, as it has already been interfaced and all models from NeuralForecast are capable of global forecasting, but several PRs are currently working on NeuralForecast. I chose pytorch-forecasting to minimize the impact on the existing code base, since extending the global forecasting API will be quite a big change.
I just made it work for an example from pytorch-forecasting. It is the first tutorial in the pytorch-forecasting documentation. By the way, are we going to have a release with partial global forecasting API support? Something like version 0.30, with only NeuralForecast models and pytorch-forecasting models having the global forecasting API.
Yes, I think that's a valid upgrade plan, e.g., release only some forecasters first, and then later merge base classes if everything is robust. In theory it could even be 0.29.0, because we're not impacting existing classes with your plan.
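For context, the global forecasting API under discussion amounts to letting `predict` condition on a new `y` after fitting on panel data. A minimal sketch of the intended usage; `PytorchForecastingAdapter` is a hypothetical name assumed here for illustration, not an existing class:

```python
import numpy as np
import pandas as pd

# hypothetical adapter name, for illustration only
from sktime.forecasting.pytorchforecasting import PytorchForecastingAdapter

# panel data: 10 series of length 50, in sktime's pd-multiindex format
index = pd.MultiIndex.from_product(
    [range(10), pd.period_range("2020-01", periods=50, freq="M")],
    names=["series", "time"],
)
y_train = pd.DataFrame({"target": np.random.rand(500)}, index=index)

forecaster = PytorchForecastingAdapter()
forecaster.fit(y_train, fh=[1, 2, 3])

# the global forecasting part: condition the fitted model on other series
y_new = y_train.loc[[0, 1]]  # stand-in for series not seen during fit
y_pred = forecaster.predict(fh=[1, 2, 3], y=y_new)
```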
Another weird issue about pytorch-forecasting models: if `max_prediction_length` is set to 1, fitting fails with an error. As discussed in the mentor meeting, a possible workaround would be passing `log_val_interval=-1`. Details of the error in colab notebook: https://colab.research.google.com/drive/16ft4Prqe1pEmLHgz1CcFnTZf4kA9kFX5?usp=sharing

```python
import numpy as np
import pandas as pd
from lightning.pytorch import Trainer
from pytorch_forecasting import (
DeepAR,
NBeats,
NHiTS,
TemporalFusionTransformer,
TimeSeriesDataSet,
)
# model to test
model_class = NHiTS
model_class = TemporalFusionTransformer
model_class = DeepAR
#model_class = NBeats
# set the max_prediction_length parameter
max_prediction_length = 1
n_timeseries = 100
time_points = 100
data = pd.DataFrame(
data={
"target": np.random.rand(time_points * n_timeseries),
"time_varying_known_real_1": np.random.rand(time_points * n_timeseries),
"time_idx": np.tile(np.arange(time_points), n_timeseries),
"group_id": np.repeat(np.arange(n_timeseries), time_points),
}
)
print(data)
training_dataset = TimeSeriesDataSet(
data=data,
time_idx="time_idx",
target="target",
group_ids=["group_id"],
time_varying_unknown_reals=["target"],
time_varying_known_reals=(
["time_varying_known_real_1"] if model_class != NBeats else []
),
max_prediction_length=max_prediction_length,
)
validation_dataset = TimeSeriesDataSet.from_dataset(
training_dataset, data, stop_randomization=True, predict=True
)
training_data_loader = training_dataset.to_dataloader(train=True)
validation_data_loader = validation_dataset.to_dataloader(train=False)
forecaster = model_class.from_dataset(training_dataset, log_val_interval=1)
pytorch_trainer = Trainer(
accelerator="cpu",
max_epochs=3,
min_epochs=2,
limit_train_batches=10,
)
pytorch_trainer.fit(
forecaster,
train_dataloaders=training_data_loader,
val_dataloaders=validation_data_loader,
)
```
Apologies if I am overlooking something: the colab shows that the error is coming from the plotting step.
@yarnabrina Yes, you are right, the error comes from plotting during fitting. It might be a bug in pytorch-forecasting, and I have opened an issue on their repository: jdb78/pytorch-forecasting#1571. It is different for NHiTS though; maybe NHiTS doesn't support max_prediction_length=1.
The bug happens when the model tries to log the validation step; therefore, setting log_val_interval to -1 could avoid the error in CI. However, users will still encounter the error if they try to log validation with max_prediction_length=1.
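Applied to the reproduction script above, the workaround amounts to something like this sketch:

```python
# workaround sketch: disable validation logging so the plotting path is skipped
forecaster = model_class.from_dataset(training_dataset, log_val_interval=-1)
```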
resolve conflicts in pyproject.toml
We should not be plotting while we are fitting. The upstream libraries may choose to do so, but it would be a bad design choice to do so by default. It should be turned off by default, and hopefully that will resolve (or has already resolved?) the error.
The plotting happens in the validation logging step, therefore I turn it off in the
@benHeid @fkiraly @yarnabrina I think it's ready for a review now. There are 2 new forecasters
Added a quick fix for a bug discovered by @benHeid during the mentoring meeting.
kindly ignore the
@Xinyu-Wu-0000 thank you for the PR, I think it is very nice. However, there are a few things we need to discuss, and I am also not completely done with my review. But I hope that these comments help you to continue working on this feature.
```python
gf = self.get_tag(
    "capability:global_forecasting", tag_value_default=False, raise_error=False
)
if gf is not True and y is not None:
```
Alternative: `not gf and y is not None`
```python
self._trainer_params = _none_check(self.trainer_params, {})
import lightning.pytorch as pl

traner_instance = pl.Trainer(**self._trainer_params)
```
`trainer_instance`
```python
self._trainer_params = _none_check(self.trainer_params, {})
import lightning.pytorch as pl

traner_instance = pl.Trainer(**self._trainer_params)
```
Why can you access `self._y_trainer_params`? In the `__init__`, only `self.y_trainer_params` is set.
```python
    reference to self
    """
    self._dataset_params = _none_check(self.dataset_params, {})
    self._max_prediction_length = fh.to_relative(self.cutoff)[-1]
```
Alternative: `max(fh.to_relative(...))`, as spelled out below.
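Spelled out, assuming the same `fh` and cutoff as in the snippet above, the suggestion would presumably read:

```python
# order-independent alternative: take the largest relative horizon step
self._max_prediction_length = max(fh.to_relative(self.cutoff))
```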
```python
_y, self._convert_to_series = _series_to_frame(y)
_X, _ = _series_to_frame(X)
# store the target column names and index names (probably [None])
# will be renamed !
```
What do you mean by this comment? It is confusing me.
```python
y_test = y_train.iloc[:-max_prediction_length]
y_pred = estimator_instance.predict(fh, y=y_test)

# TODO
```
Here is a todo
```python
)
y_pred = estimator_instance.predict(fh, y=y_test)

# TODO
```
Here is a todo
```diff
@@ -892,3 +898,186 @@ def test_fit_predict(self, estimator_instance, n_columns):
    cutoff = get_cutoff(y_train, return_index=True)
    _assert_correct_pred_time_index(y_pred.index, cutoff, fh)
    _assert_correct_columns(y_pred, y_train)


class TestAllGlobalForecasters(TestAllObjects):
```
Please correct me if I am wrong: if I have seen it correctly, you pass a copy of y_train to the predict methods. I am wondering if we can make this more general, e.g. (see the sketch below the list):
- renaming the columns of y_test so that they differ from y_train, and checking that the current column names are returned;
- providing a different y_test than y_train to the test, and checking that the results of both differ.
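A sketch of the second suggestion, with fixture names (`y_train`, `y_other`) assumed for this test module:

```python
def test_global_predict_conditions_on_y(self, estimator_instance, y_train, y_other):
    """Sketch: forecasts for different y inputs should differ."""
    fh = [1, 2, 3]
    estimator_instance.fit(y_train, fh=fh)

    # predict once on the training series, once on a different series
    y_pred_train = estimator_instance.predict(fh, y=y_train)
    y_pred_other = estimator_instance.predict(fh, y=y_other)

    # a truly global forecaster conditions on y, so the results should differ
    assert not y_pred_train.equals(y_pred_other)
```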
@fkiraly maybe you have a comment on that
```python
.. [2] https://pytorch-forecasting.readthedocs.io/en/stable/api/pytorch_forecasting.data.timeseries.TimeSeriesDataSet.html
"""  # noqa: E501

_tags = {
```
I suppose some of these models support probabilistic forecasting. Please check this. It is okay if you add the support later, but in that case we should track it via an issue.
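In sktime that capability is declared via tags, so once checked it could be recorded along these lines (a sketch; which models actually set it to True is exactly what needs verifying):

```python
_tags = {
    "capability:global_forecasting": True,
    # to be flipped to True for models that support probabilistic output
    "capability:pred_int": False,
}
```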
```diff
@@ -2540,6 +2540,132 @@ def _get_columns(self, method="predict", **kwargs):
BaseForecaster._init_dynamic_doc()


class BaseGlobalForecaster(BaseForecaster):
```
I suppose the BaseGlobalForecaster should in general also be able to perform probabilistic forecasts. It is okay not to add this directly in this PR, I suppose, but that requires tracking via an additional issue.
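If that is added later, the global `y` argument would presumably also have to thread through the probabilistic entry points, e.g. (signature sketch only, not part of this PR):

```python
# hypothetical extension: probabilistic predict that also accepts new series
def predict_interval(self, fh=None, X=None, coverage=0.90, y=None):
    ...
```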
@fkiraly any comments on this?
Reference Issues/PRs
Related: #4651, #4923
Main Topic
A pytorch forecasting adapter with Global Forecasting API and several algorithms for design validation.
Details
I'm developing a pytorch forecasting adapter with the Global Forecasting API. To ensure a well-designed implementation, I'd like to discuss some design aspects.
New Base Class for Minimal Impact
A new base class, `GlobalBaseForecaster`, has been added to minimize the impact on existing forecasters and simplify testing. As discussed in #4651, the plan is to manage the Global Forecasting API via tags only. However, a phased approach might be beneficial. If a tag-based approach is confirmed, we can merge `GlobalBaseForecaster` back into `BaseForecaster` after design validation.
Data Type Conversion Challenges
Data type conversion presents a challenge because pytorch-forecasting expects `TimeSeriesDataSet` as input. While a `TimeSeriesDataSet` can be created from a `pandas.DataFrame`, it requires numerous parameters. Determining where to pass these parameters is a key question.
Placing them in `fit` would introduce inconsistency with the existing API. If we put them in `__init__`, it would be very counterintuitive to define how the data conversion works while initializing the algorithm.
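For concreteness, the `__init__` variant could look like this sketch; `dataset_params` and the class name are assumptions for discussion, not a settled API:

```python
from typing import Any, Dict, Optional

from pytorch_forecasting import TimeSeriesDataSet


class PytorchForecastingAdapter:  # hypothetical skeleton
    def __init__(self, dataset_params: Optional[Dict[str, Any]] = None):
        # data-conversion parameters are fixed at construction time
        self.dataset_params = dataset_params if dataset_params is not None else {}

    def _fit(self, y, X, fh):
        # forwarded verbatim when the incoming DataFrame is converted
        dataset = TimeSeriesDataSet(data=y, **self.dataset_params)
        ...
```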
A similar issue arises during trainer initialization. Currently, `trainer_params: Dict[str, Any]` is used within `__init__` to obtain trainer initialization parameters. However, the API for passing these parameters to `trainer.fit` is yet to be designed.
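One option, again only a sketch with assumed names, is a second dict dedicated to the fit call:

```python
import lightning.pytorch as pl


def fit_with_trainer(model, train_loader, val_loader, trainer_params, trainer_fit_params):
    # hypothetical split: one dict for Trainer(...), one for trainer.fit(...)
    trainer = pl.Trainer(**(trainer_params or {}))
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        **(trainer_fit_params or {}),  # e.g. {"ckpt_path": ...}
    )
    return trainer
```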
To convert `pytorch_forecasting.models.base_model.Prediction` back to a `pandas.DataFrame`, a custom conversion method is required. Refer to the following issues for more information: jdb78/pytorch-forecasting#734, jdb78/pytorch-forecasting#177.
Train/Validation Strategy
Training a model in pytorch-forecasting requires passing both the training and validation datasets together to the training algorithm. This allows monitoring training progress, adjusting the learning rate, saving the model, or even stopping training early. This differs from the typical sktime approach, where only the training data is passed to fit and the test data is used for validation after training. Any suggestions on how best to address this discrepancy?
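One way to bridge this inside the adapter's fit, sketched with sktime's existing splitter (whether to expose the split fraction as a constructor parameter is an open design question):

```python
from sktime.datasets import load_airline
from sktime.split import temporal_train_test_split

y = load_airline()

# sketch: carve a validation tail off the training data, so pytorch-forecasting
# receives both a training and a validation dataset while the sktime user
# still passes a single y to fit
y_train_inner, y_val_inner = temporal_train_test_split(y, test_size=0.2)
```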
@benHeid @fkiraly Thank you very much for the feedback on my GSoC proposal! Any suggestions on implementation details or the overall design would be greatly appreciated.