
[ENH] pytorch forecasting adapter with Global Forecasting API #6228

Open
wants to merge 68 commits into main

Conversation

Xinyu-Wu-0000
Contributor

Reference Issues/PRs

Related: #4651, #4923

Main Topic

A pytorch-forecasting adapter with the Global Forecasting API, plus several algorithms for design validation.

Details

I'm developing a pytorch-forecasting adapter with the Global Forecasting API. To ensure a well-designed implementation, I'd like to discuss some design aspects.

New Base Class for Minimal Impact

A new base class, GlobalBaseForecaster, has been added to minimize the impact on existing forecasters and simplify testing. As discussed in #4651, the plan is to manage the Global Forecasting API via tags only. However, a phased approach might be beneficial. If a tag-based approach is confirmed, we can merge GlobalBaseForecaster back into BaseForecaster after design validation.
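For illustration, a minimal sketch of what the tag-based variant could look like (GlobalBaseForecaster is the new base class from this PR; the tag name capability:global_forecasting is the one used later in this PR, and the forecaster name is hypothetical):

class MyGlobalForecaster(GlobalBaseForecaster):
    """Hypothetical forecaster opting into the Global Forecasting API."""

    _tags = {
        # checked before predict accepts new, unseen series via its y argument
        "capability:global_forecasting": True,
    }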

Data Type Conversion Challenges

Data type conversion presents a challenge because pytorch-forecasting expects a TimeSeriesDataSet as input. While a TimeSeriesDataSet can be created from a pandas.DataFrame, doing so requires numerous parameters. Determining where to pass these parameters is a key question.

Placing them in fit would introduce an inconsistency with the existing API, while placing them in __init__ is counterintuitive: the user would be defining how the data conversion works while initializing the algorithm.

A similar issue arises during trainer initialization. Currently, trainer_params: Dict[str, Any] is used within __init__ to obtain trainer initialization parameters. However, the API for passing these parameters to trainer.fit is yet to be designed.
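For concreteness, a rough sketch of the __init__-based pattern (dataset_params and trainer_params are this PR's parameter names; the "or {}" defaulting stands in for the PR's _none_check helper):

import lightning.pytorch as pl
from pytorch_forecasting import TimeSeriesDataSet


class _PytorchForecastingAdapter:
    def __init__(self, dataset_params=None, trainer_params=None):
        # both dicts are forwarded verbatim, so any TimeSeriesDataSet or
        # pl.Trainer keyword must be fixed at construction time
        self.dataset_params = dataset_params
        self.trainer_params = trainer_params

    def _fit(self, y, fh):
        dataset_params = self.dataset_params or {}
        trainer_params = self.trainer_params or {}
        training_dataset = TimeSeriesDataSet(data=y, **dataset_params)
        trainer = pl.Trainer(**trainer_params)
        # build the network from the dataset, then call trainer.fit
        ...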

To convert pytorch_forecasting.models.base_model.Prediction back to a pandas.DataFrame, a custom conversion method is required. Refer to the following issues for more information: jdb78/pytorch-forecasting#734, jdb78/pytorch-forecasting#177.
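A possible shape for such a conversion, sketched under the assumption that predict was called with return_index=True, so that the Prediction object carries an output tensor of shape (n_series, horizon) and an index frame with each series' group id and first forecast time_idx (the helper name and column names are illustrative):

import pandas as pd


def _prediction_to_frame(pred, group_col="group_id", time_col="time_idx"):
    values = pred.output.cpu().numpy()  # (n_series, horizon)
    horizon = values.shape[1]
    rows = []
    for i, idx_row in enumerate(pred.index.to_dict("records")):
        # expand each series' forecast into one row per horizon step
        for h in range(horizon):
            rows.append({
                group_col: idx_row[group_col],
                time_col: idx_row[time_col] + h,
                "prediction": values[i, h],
            })
    return pd.DataFrame(rows).set_index([group_col, time_col])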

Train/Validation Strategy

Training a model in PyTorch forecasting necessitates passing both the training and validation datasets together to the training algorithm. This allows for monitoring training progress, adjusting the learning rate, saving the model, or even stopping training prematurely. This differs from the typical sktime approach where only the training data is passed to fit and the test data is used for validation after training. Any suggestions on how to best address this discrepancy?
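One option, sketched here only: let the adapter carve the validation set out of the data passed to fit, reusing the holdout pattern from the upstream tutorials (model, trainer, training_dataset, and data are assumed to already exist inside the adapter's fit):

from pytorch_forecasting import TimeSeriesDataSet


def _fit_with_holdout(model, trainer, training_dataset, data):
    # hold out the last max_prediction_length steps of every series as the
    # validation set, mirroring the upstream tutorial pattern
    validation_dataset = TimeSeriesDataSet.from_dataset(
        training_dataset, data, stop_randomization=True, predict=True
    )
    trainer.fit(
        model,
        train_dataloaders=training_dataset.to_dataloader(train=True),
        val_dataloaders=validation_dataset.to_dataloader(train=False),
    )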

@benHeid @fkiraly Thank you very much for the feedback on my GSoC proposal! Any suggestions on implementation details or the overall design would be greatly appreciated.

@Xinyu-Wu-0000 Xinyu-Wu-0000 changed the title [ENH] Global pytorch-forecasting [ENH] pytorch forecasting adapter with Global Forecasting API Mar 28, 2024
@benHeid
Contributor

benHeid commented Mar 30, 2024

Just a general comment. I would propose to split this into multiple PRs. This would make it easier to review. I would propose a PR for the pytorch-forecasting adapter (first PR) and a second PR that introduces the global forecasting.

Collaborator

@fkiraly fkiraly left a comment


just commenting - very interesting!

I suppose a separate class will give us the opportunity to develop this capability. Ultimately, we may decide to merge it into BaseForecaster, or not.

@Xinyu-Wu-0000
Contributor Author

Just a general comment. I would propose to split this into multiple PRs. This would make it easier to review. I would propose a PR for the pytorch-forecasting adapter (first PR) and a second PR that introduces the global forecasting.

Yeah, splitting this into 2 PRs would make the workflow clearer. But I am using pytorch-forecasting as an experiment to try out the global forecasting API design, and ultimately DL models will need the global forecasting API anyway. Would it be more convenient to have a single PR?

@fkiraly
Collaborator

fkiraly commented Mar 31, 2024

Good question, @Xinyu-Wu-0000 - as long as this is experimental, it's up to you.

It's always good to have one or two examples working for new API development, so it makes sense to have examples in the PR.

However, there might be substantial challenges coming from pytorch-forecasting in isolation that have nothing to do with the global forecasting extension (you already list many of the serious ones, e.g., loader, prediction object), so I wonder if there is a simpler example to develop around.

Either way, that's not a strict requirement, as long as you are working in an exploratory sense.

@fkiraly fkiraly added the module:forecasting and enhancement labels Mar 31, 2024
@Xinyu-Wu-0000
Contributor Author

so I wonder if there is a simpler example to develop around.

Maybe NeuralForecast could be a simpler example, as it has already been interfaced and all models from NeuralForecast are capable of global forecasting, but several PRs are currently working on NeuralForecast. I chose pytorch-forecasting to minimize the impact on the existing code base, as extending the global forecasting API will be quite a big change.

It's always good to have one or two examples working for new API development, so it makes sense ot have examples in the PR.

I just made it work for an example from pytorch-forecasting; it is the first tutorial in the pytorch-forecasting documentation.
The test script for the example I use:
test_script.txt

By the way, are we going to have a release with partial global forecasting API support? Something like version 0.30, with only the NeuralForecast models and pytorch-forecasting models supporting the global forecasting API.

@fkiraly
Collaborator

fkiraly commented Apr 1, 2024

By the way, are we going to have a release with partial global forecasting API support?

Yes, I think that's a valid upgrade plan, e.g., release first only some forecasters, and then later merge base classes if everything is robust.

In theory it could even be 0.29.0, because your plan does not impact existing classes.

@Xinyu-Wu-0000
Contributor Author

Xinyu-Wu-0000 commented May 27, 2024

Another weird issue with pytorch-forecasting models: if max_prediction_length == 1, there is an error with the TFT, NHiTS, and DeepAR models.

As discussed in the mentor meeting, a possible workaround would be passing max_prediction_length=2 and then dropping one prediction point; a rough sketch follows.
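Roughly (a sketch; the helper names and the group_id index level are illustrative):

def _padded_horizon(max_prediction_length):
    # some models error out for a horizon of 1, so request at least 2 steps
    return max(max_prediction_length, 2)


def _truncate_padding(y_pred, max_prediction_length, group_level="group_id"):
    # drop the padded step again: keep only the first max_prediction_length
    # predictions of each series
    return y_pred.groupby(level=group_level).head(max_prediction_length)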

Details of the error in colab notebook: https://colab.research.google.com/drive/16ft4Prqe1pEmLHgz1CcFnTZf4kA9kFX5?usp=sharing

import numpy as np
import pandas as pd
from lightning.pytorch import Trainer
from pytorch_forecasting import (
    DeepAR,
    NBeats,
    NHiTS,
    TemporalFusionTransformer,
    TimeSeriesDataSet,
)

# model to test: leave exactly one assignment uncommented
# model_class = NHiTS
# model_class = TemporalFusionTransformer
model_class = DeepAR
# model_class = NBeats

# a max_prediction_length of 1 triggers the error
max_prediction_length = 1

n_timeseries = 100
time_points = 100
data = pd.DataFrame(
    data={
        "target": np.random.rand(time_points * n_timeseries),
        "time_varying_known_real_1": np.random.rand(time_points * n_timeseries),
        "time_idx": np.tile(np.arange(time_points), n_timeseries),
        "group_id": np.repeat(np.arange(n_timeseries), time_points),
    }
)
print(data)

training_dataset = TimeSeriesDataSet(
    data=data,
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    time_varying_unknown_reals=["target"],
    time_varying_known_reals=(
        ["time_varying_known_real_1"] if model_class != NBeats else []
    ),
    max_prediction_length=max_prediction_length,
)

validation_dataset = TimeSeriesDataSet.from_dataset(
    training_dataset, data, stop_randomization=True, predict=True
)

training_data_loader = training_dataset.to_dataloader(train=True)

validation_data_loader = validation_dataset.to_dataloader(train=False)

forecaster = model_class.from_dataset(training_dataset, log_val_interval=1)

pytorch_trainer = Trainer(
    accelerator="cpu",
    max_epochs=3,
    min_epochs=2,
    limit_train_batches=10,
)

pytorch_trainer.fit(
    forecaster,
    train_dataloaders=training_data_loader,
    val_dataloaders=validation_data_loader,
)

@yarnabrina
Collaborator

Apologies if I am overlooking something: the colab shows that the error is coming from matplotlib. Does it fail because these models do not support max_prediction_length=1, or does it fail only when trying to plot etc.?

@Xinyu-Wu-0000
Contributor Author

@yarnabrina Yes, you are right, the error comes from plotting during fitting. It might be a bug in pytorch-forecasting, and I have opened an issue on their repository: jdb78/pytorch-forecasting#1571.

It is different for NHiTS though. Maybe NHiTS doesn't support max_prediction_length=1.

@fkiraly
Collaborator

fkiraly commented May 31, 2024

Yes, you are right, the error is from plotting during fitting.

We should not be plotting while we are fitting. The upstream libraries may choose to do so, but it would be a bad design choice to do so by default. It should be turned off by default, and hopefully that will resolve (or has already resolved?) the error.

@Xinyu-Wu-0000
Contributor Author

hopefully that will (or has already?) resolved the error

The plotting happens in the validation logging step, so I turned it off in the get_test_params function to pass the CI. However, if users try to log the validation step and set max_prediction_length=1 at the same time, they will still encounter this bug. The best hope is that the upstream library fixes the bug; otherwise we may have to try @benHeid's workaround of passing max_prediction_length=2 and then dropping one prediction point.
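For reference, the switch in question, as a sketch (assuming, per the upstream docs, that a negative log_val_interval disables validation-step logging and plotting):

model = model_class.from_dataset(
    training_dataset,
    log_val_interval=-1,  # negative value: no validation logging, no plots
)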

@Xinyu-Wu-0000
Contributor Author

@benHeid @fkiraly @yarnabrina I think it's ready for a review now. There are two new forecasters, PytorchForecastingNBeats and PytorchForecastingTFT, the adapter class _PytorchForecastingAdapter, the global forecasting API class BaseGlobalForecaster, and a test class TestAllGlobalForecasters. pytest test_all_forecasters.py now runs without errors in my local environment; I hope we can merge it soon and add two more forecasters later.

@Xinyu-Wu-0000
Contributor Author

Added a quick fix for a bug discovered by @benHeid during the mentoring meeting.

@fkiraly
Collaborator

fkiraly commented Jun 3, 2024

Kindly ignore the pykalman-related failures; they are unrelated and will be fixed by #6519.

Contributor

@benHeid benHeid left a comment


@Xinyu-Wu-0000 thank you for the PR. I think it is very nice. However, there are a few things we need to discuss, and I am also not completely done with my review yet. But I hope that these comments help you continue working on this feature.

gf = self.get_tag(
"capability:global_forecasting", tag_value_default=False, raise_error=False
)
if gf is not True and y is not None:

Alternative: not gf and y is not None

self._trainer_params = _none_check(self.trainer_params, {})
import lightning.pytorch as pl

traner_instance = pl.Trainer(**self._trainer_params)

trainer_instance

self._trainer_params = _none_check(self.trainer_params, {})
import lightning.pytorch as pl

traner_instance = pl.Trainer(**self._trainer_params)

Why can you access self._y_trainer_params? In __init__, only self.y_trainer_params is set.

reference to self
"""
self._dataset_params = _none_check(self.dataset_params, {})
self._max_prediction_length = fh.to_relative(self.cutoff)[-1]

Alternative: max(fh.to_relative...)

_y, self._convert_to_series = _series_to_frame(y)
_X, _ = _series_to_frame(X)
# store the target column names and index names (probably [None])
# will be renamed !

What do you mean by this comment? It is confusing me.

y_test = y_train.iloc[:-max_prediction_length]
y_pred = estimator_instance.predict(fh, y=y_test)

# TODO

Here is a todo

)
y_pred = estimator_instance.predict(fh, y=y_test)

# TODO

Here is a todo

@@ -892,3 +898,186 @@ def test_fit_predict(self, estimator_instance, n_columns):
cutoff = get_cutoff(y_train, return_index=True)
_assert_correct_pred_time_index(y_pred.index, cutoff, fh)
_assert_correct_columns(y_pred, y_train)


class TestAllGlobalForecasters(TestAllObjects):

Please correct me if I am wrong: if I have seen it correctly, you passed a copy of y_train to the predict methods. I am wondering if we can make this more general (see the sketch after this list):

  • E.g., rename the columns of y_test so that they differ from y_train, and check that the expected column names are returned.
  • Provide a y_test different from y_train, and check that the results of the two differ.
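Something along these lines, as a sketch (the test name and fixtures are illustrative, and the column assertion depends on which convention is chosen):

def test_global_predict_uses_provided_y(estimator_instance, y_train, y_test, fh):
    estimator_instance.fit(y_train, fh=fh)
    # rename columns so that y_test is distinguishable from y_train
    y_other = y_test.rename(columns=lambda c: f"{c}_renamed")
    y_pred = estimator_instance.predict(fh, y=y_other)
    # check which column names come back (y_other's here; or y_train's,
    # depending on the convention chosen)
    assert list(y_pred.columns) == list(y_other.columns)
    # predicting on different data should give different results
    y_pred_train = estimator_instance.predict(fh, y=y_train)
    assert (y_pred.to_numpy() != y_pred_train.to_numpy()).any()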


@fkiraly maybe you have a comment on that

.. [2] https://pytorch-forecasting.readthedocs.io/en/stable/api/pytorch_forecasting.data.timeseries.TimeSeriesDataSet.html
""" # noqa: E501

_tags = {

I suppose some of these models support probabilistic forecasting. Please check this. It is okay if you add the support later, but in that case we should track it via an issue.

@@ -2540,6 +2540,132 @@ def _get_columns(self, method="predict", **kwargs):
BaseForecaster._init_dynamic_doc()


class BaseGlobalForecaster(BaseForecaster):

I suppose that BaseGlobalForecaster should in general also be able to perform probabilistic forecasts. It is okay not to add this directly in this PR, I suppose, but then it requires tracking via an additional issue.


@fkiraly any comments on this?

Labels
enhancement, module:forecasting
Projects
Status: Under review