Skip to content

Commit

Permalink
Merge pull request #201 from atong01/torch-2.0
Browse files Browse the repository at this point in the history
Torch 2.0 Compatibility
  • Loading branch information
Zymrael committed Aug 30, 2023
2 parents 92dd9ba + 2b0de91 commit d6772ab
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 8 deletions.
33 changes: 30 additions & 3 deletions .github/workflows/os-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,30 @@ jobs:
max-parallel: 15
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ["3.8", "3.9", "3.10", "3.11"]
torch-version: ["1.8.1", "1.9.1", "1.10.0", "1.11.0", "1.12.0", "1.13.1", "2.0.0"]
exclude:
# python >= 3.10 does not support pytorch < 1.11.0
- torch-version: "1.8.1"
python-version: "3.10"
- torch-version: "1.9.1"
python-version: "3.10"
- torch-version: "1.10.0"
python-version: "3.10"
# python >= 3.11 does not support pytorch < 1.13.0
- torch-version: "1.8.1"
python-version: "3.11"
- torch-version: "1.9.1"
python-version: "3.11"
- torch-version: "1.10.0"
python-version: "3.11"
- torch-version: "1.11.0"
python-version: "3.11"
- torch-version: "1.12.0"
python-version: "3.11"
- torch-version: "1.13.1"
python-version: "3.11"

defaults:
run:
shell: bash
Expand All @@ -33,14 +56,18 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.torch-version }}-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies # hack for 🐛: don't let poetry try installing Torch https://github.com/pytorch/pytorch/issues/88049
run: |
pip install pytest pytest-cov papermill poethepoet>=0.10.0
pip install torch>=1.8.1 torchvision pytorch-lightning scikit-learn torchsde torchcde>=0.2.3 scipy matplotlib ipykernel ipywidgets
pip install torch==${{ matrix.torch-version }} pytorch-lightning scikit-learn torchsde torchcde>=0.2.3 scipy matplotlib ipykernel ipywidgets
poetry install --only-root
poetry run pip install setuptools
- name: List dependencies
run: |
pip list
- name: Run pytest checks
run: |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ packages = [

[tool.poetry.dependencies]
python = "^3.8"
torch = "^1.8.1"
torch = ">=1.8.1"
torchsde="*"
torchcde="^0.2.3"
scikit-learn = "*"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
description="PyTorch package for all things neural differential equations.",
url="https://github.com/DiffEqML/torchdyn",
install_requires=[
"torch>=1.6.0",
"torch>=1.8.1",
"pytorch-lightning>=0.8.4",
"matplotlib",
"scikit-learn",
Expand Down
5 changes: 4 additions & 1 deletion test/models/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from packaging.version import parse
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -261,6 +262,8 @@ def forward(self, t, x, u, v, z, args={}):
grad(sol2.sum(), x0)


@pytest.mark.skipif(parse(torch.__version__) < parse("1.11.0"),
reason="adjoint support added in torch 1.11.0")
def test_complex_ode():
"""Test odeint for complex numbers with a simple complex-valued ODE, corresponding
to Rabi oscillations of quantum two-level system."""
Expand Down Expand Up @@ -312,4 +315,4 @@ def test_odeint(solver):
t_span = torch.linspace(0., 2., 10)
sys = Lorenz()

odeint(sys, x0, t_span, solver=solver)
odeint(sys, x0, t_span, solver=solver)
3 changes: 1 addition & 2 deletions torchdyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '1.0'
__version__ = '1.0.5'
__author__ = 'Michael Poli, Stefano Massaroli et al.'

from torch import Tensor
from typing import Tuple

TTuple = Tuple[Tensor, Tensor]

0 comments on commit d6772ab

Please sign in to comment.