Skip to content

Commit

Permalink
Merge branch 'master' into neuralsde
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanKo96 committed Nov 18, 2023
2 parents 880e855 + c708c4c commit 0bbdd3e
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 63 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/codestyle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
Expand Down
84 changes: 59 additions & 25 deletions .github/workflows/os-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,72 @@ jobs:
strategy:
fail-fast: true
max-parallel: 15

matrix:
os: [ubuntu-18.04, ubuntu-20.04, ubuntu-22.04, macos-latest]
python-version: [3.7, 3.8, 3.9]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
torch-version: ["1.8.1", "1.9.1", "1.10.0", "1.11.0", "1.12.0", "1.13.1", "2.0.0"]
exclude:
# temporary exclusion until py3.8 deps work on Windows
- python-version: 3.7
os: windows-2019
- python-version: 3.8
os: windows-2019
# python >= 3.10 does not support pytorch < 1.11.0
- torch-version: "1.8.1"
python-version: "3.10"
- torch-version: "1.9.1"
python-version: "3.10"
- torch-version: "1.10.0"
python-version: "3.10"
# python >= 3.11 does not support pytorch < 1.13.0
- torch-version: "1.8.1"
python-version: "3.11"
- torch-version: "1.9.1"
python-version: "3.11"
- torch-version: "1.10.0"
python-version: "3.11"
- torch-version: "1.11.0"
python-version: "3.11"
- torch-version: "1.12.0"
python-version: "3.11"
- torch-version: "1.13.1"
python-version: "3.11"

defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v2
- name: Check out repository
uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies

- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true

- name: Load cached venv
id: cached-pip-wheels
uses: actions/cache@v3
with:
path: ~/.cache
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.torch-version }}-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies # hack for 🐛: don't let poetry try installing Torch https://github.com/pytorch/pytorch/issues/88049
run: |
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -
source $HOME/.poetry/env
python -m pip install --upgrade --user pip
pip install pytest pytest-cov
poetry lock
poetry build
poetry install
poetry run pip install 'setuptools==59.5.0'
# pinning setuptools is a temporary fix to: pytorch-lightning 1.5.10 requires setuptools==59.5.0
# supposedly poetry allows pinning ver. of setuptools in pyproject.ml files but it is not working atm https://github.com/python-poetry/poetry/issues/4511

- name: Run tests
pip install pytest pytest-cov papermill poethepoet>=0.10.0
pip install torch==${{ matrix.torch-version }} pytorch-lightning scikit-learn torchsde torchcde>=0.2.3 scipy matplotlib ipykernel ipywidgets
poetry install --only-root
poetry run pip install setuptools
- name: List dependencies
run: |
pip list
- name: Run pytest checks
run: |
source $HOME/.poetry/env
source $VENV
poetry run coverage run --source=torchdyn -m pytest
- name: Report coverage
uses: codecov/codecov-action@v3.1.1
10 changes: 4 additions & 6 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,24 @@ on:
types:
- created
- edited

jobs:
build:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: 3.8

- name: Build
run: |
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -
source $HOME/.poetry/env
curl -sSL https://install.python-poetry.org | python3 -
poetry lock
poetry build
- name: Publish distribution 📦 to PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
run: |
source $HOME/.poetry/env
poetry config pypi-token.pypi ${{ secrets.pypi_token }}
poetry publish
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Torchdyn is a PyTorch library dedicated to **numerical deep learning**: differen
[![Slack](https://img.shields.io/badge/slack-chat-blue.svg?logo=slack)](https://join.slack.com/t/diffeqml/shared_invite/zt-trwgahq8-zgDqFmwS2gHYX6hsRvwDvg)
[![codecov](https://codecov.io/gh/DiffEqML/torchdyn/branch/master/graph/badge.svg)](https://codecov.io/gh/DiffEqML/torchdyn)
[![Docs](https://img.shields.io/badge/docs-passing-green.svg?)](https://torchdyn.readthedocs.io/)
[![python_sup](https://img.shields.io/badge/python-3.7+-black.svg?)](https://www.python.org/downloads/release/python-370/)
[![python_sup](https://img.shields.io/badge/python-3.8+-black.svg?)](https://www.python.org/downloads/release/python-370/)

</div>

Expand Down
17 changes: 13 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "torchdyn"
version = "1.0.3"
version = "1.0.6"
license = "Apache License, Version 2.0"
description = "A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods."
authors = ["Michael Poli", "Stefano Massaroli", "DiffEqML"]
Expand All @@ -9,11 +9,11 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.7"
torch = "^1.8.1"
python = "^3.8"
torch = ">=1.8.1"
torchsde="*"
torchcde="^0.2.3"
sklearn = "*"
scikit-learn = "*"
pytorch-lightning = "*"
torchvision = "*"
scipy = "*"
Expand All @@ -38,6 +38,15 @@ build-backend = "poetry.masonry.api"
requires = ["poetry", "wheel", "setuptools-cpp"]

[tool.pytest.ini_options]
log_cli = true
log_cli_level = "CRITICAL"
log_cli_format = "%(message)s"

log_file = "pytest.log"
log_file_level = "DEBUG"
log_file_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_file_date_format = "%Y-%m-%d %H:%M:%S"

filterwarnings = [
"ignore:Call to deprecated create function FieldDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
"ignore:Call to deprecated create function Descriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

setup(
name="torchdyn",
version="1.0.3",
version="1.0.6",
author="Michael Poli and Stefano Massaroli",
description="PyTorch package for all things neural differential equations.",
url="https://github.com/DiffEqML/torchdyn",
install_requires=[
"torch>=1.6.0",
"torch>=1.8.1",
"pytorch-lightning>=0.8.4",
"matplotlib",
"scikit-learn",
Expand All @@ -31,4 +31,5 @@
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
],
packages=["torchdyn"],
)
61 changes: 59 additions & 2 deletions test/models/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from packaging.version import parse
import pytest
import torch
import torch.nn as nn
Expand All @@ -19,7 +20,7 @@
from torchdyn.datasets import ToyDataset
from torchdyn.core import NeuralODE
from torchdyn.nn import GalLinear, GalConv2d, DepthCat, Augmenter, DataControl
from torchdyn.numerics import odeint, Euler
from torchdyn.numerics import odeint, odeint_mshooting, Lorenz, Euler

from functools import partial
import copy
Expand Down Expand Up @@ -258,4 +259,60 @@ def forward(self, t, x, u, v, z, args={}):
t_eval, sol2 = odeprob(x0, t_span=torch.linspace(0, 5, 10))

assert (sol1==sol2).all()
grad(sol2.sum(), x0)
grad(sol2.sum(), x0)


@pytest.mark.skipif(parse(torch.__version__) < parse("1.11.0"),
reason="adjoint support added in torch 1.11.0")
def test_complex_ode():
"""Test odeint for complex numbers with a simple complex-valued ODE, corresponding
to Rabi oscillations of quantum two-level system."""
class Rabi(nn.Module):
def __init__(self, omega):
super().__init__()
self.sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
self.omega = omega
return
def forward(self, t, x):
dx = -1.0j * self.omega * self.sx @ x
dx += dx.adjoint()
return dx

# Odeint parameters
omega = torch.randn(1)
rabi = Rabi(omega)
tspan = torch.linspace(0., 2., 10)

# Random initial state
x0 = torch.rand(2, 2, dtype=torch.complex128)
x0 = 0.5 * (x0 + x0.adjoint()) / torch.real(x0.trace())
# Solve the ODE problem
t_eval, sol = odeint(f=rabi, x=x0, t_span=tspan, solver="dopri5", atol=1e-8, rtol=1e-6)

# Expected solution
sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
si = torch.tensor([[1, 0], [0, 1]], dtype=torch.complex128)
U_t = torch.cos(omega * t_eval)[:, None, None] * si
U_t += -1j * torch.sin(omega * t_eval)[:, None, None] * sx
sol_exp = U_t @ x0 @ U_t.adjoint()

# Check result
assert torch.allclose(sol, sol_exp, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize('solver', ['mszero'])
def test_odeint_mshooting(solver):
x0 = torch.randn(8, 3) + 15
t_span = torch.linspace(0, 3, 10)
sys = Lorenz()

odeint_mshooting(sys, x0, t_span, solver=solver, fine_steps=2, maxiter=4)


@pytest.mark.parametrize('solver', ['euler', 'rk4', 'dopri5'])
def test_odeint(solver):
x0 = torch.randn(8, 3) + 15
t_span = torch.linspace(0., 2., 10)
sys = Lorenz()

odeint(sys, x0, t_span, solver=solver)
14 changes: 12 additions & 2 deletions test/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Test adjoint and perform a rough benchmarking of wall-clock time

import time
from copy import deepcopy
import logging

import pytest
import torch
Expand All @@ -27,9 +30,10 @@
batch_size = 128
torch.manual_seed(1415112413244349)


t_span = torch.linspace(0, 1, 100)

logger = logging.getLogger("out")


# TODO(numerics): log wall-clock times and other torch.grad tests
# TODO(bug): `tsit5` + `adjoint` peak error
Expand All @@ -38,17 +42,21 @@
@pytest.mark.parametrize('stiffness', [0.1, 0.5])
@pytest.mark.parametrize('interpolator', [None])
def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):

f = VanDerPol(stiffness)
x = torch.randn(1024, 2, requires_grad=True)
t0 = time.time()

prob = ODEProblem(f, sensitivity=sensitivity, interpolator=interpolator, solver=solver, atol=1e-4, rtol=1e-4, atol_adjoint=1e-4, rtol_adjoint=1e-4)
t0 = time.time()
t_eval, sol_torchdyn = prob.odeint(x, t_span)
t_end1 = time.time() - t0

t0 = time.time()
sol_torchdiffeq = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
t_end2 = time.time() - t0

logger.info(f"Fwd times: {t_end1:.3f}, {t_end2:.3f}")

true_sol = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-9, rtol=1e-9)

t0 = time.time()
Expand All @@ -59,6 +67,8 @@ def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):
grad2 = torch.autograd.grad(sol_torchdiffeq[-1].sum(), x)[0]
t_end2 = time.time() - t0

logger.info(f"Bwd times: {t_end1:.3f}, {t_end2:.3f}")

grad_true = torch.autograd.grad(true_sol[-1].sum(), x)[0]

err1 = (grad1-grad_true).abs().sum(1)
Expand Down
3 changes: 1 addition & 2 deletions torchdyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '1.0'
__version__ = '1.0.6'
__author__ = 'Michael Poli, Stefano Massaroli et al.'

from torch import Tensor
from typing import Tuple

TTuple = Tuple[Tensor, Tensor]

11 changes: 5 additions & 6 deletions torchdyn/numerics/odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Functional API of ODE integration routines, with specialized functions for different options
`odeint` and `odeint_mshooting` prepare and redirect to more specialized routines, detected automatically.
"""
from inspect import getargspec
from typing import List, Tuple, Union, Callable, Dict, Iterable
from warnings import warn

Expand Down Expand Up @@ -65,11 +64,6 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
x, t_span = solver.sync_device_dtype(x, t_span)
stepping_class = solver.stepping_class

# instantiate save_at tensor
if len(save_at) == 0: save_at = t_span
if not isinstance(save_at, torch.Tensor):
save_at = torch.tensor(save_at)

# instantiate the interpolator similar to the solver steps above
if isinstance(solver, Tsitouras45):
if verbose: warn("Running interpolation not yet implemented for `tsit5`")
Expand All @@ -87,6 +81,7 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
if stepping_class == 'fixed':
if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]:
warn("Setting tolerances has no effect on fixed-step methods")
# instantiate save_at tensor
return _fixed_odeint(f_, x, t_span, solver, save_at=save_at, args=args)
elif stepping_class == 'adaptive':
t = t_span[0]
Expand Down Expand Up @@ -415,6 +410,10 @@ def _adaptive_odeint(f, k1, x, dt, t_span, solver, atol=1e-4, rtol=1e-4, args=No

def _fixed_odeint(f, x, t_span, solver, save_at=(), args={}):
"""Solves IVPs with same `t_span`, using fixed-step methods"""
if len(save_at) == 0: save_at = t_span
if not isinstance(save_at, torch.Tensor):
save_at = torch.tensor(save_at)

assert all(torch.isclose(t, save_at).sum() == 1 for t in save_at),\
"each element of save_at [torch.Tensor] must be contained in t_span [torch.Tensor] once and only once"

Expand Down

0 comments on commit 0bbdd3e

Please sign in to comment.