Skip to content

Commit

Permalink
Merge pull request #165 from DiffEqML/error_control
Browse files Browse the repository at this point in the history
Error control
  • Loading branch information
Zymrael committed Aug 22, 2022
2 parents 7a1080a + 6906258 commit 310c375
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ build-backend = "poetry.masonry.api"
requires = ["poetry", "wheel", "setuptools-cpp"]

[tool.pytest.ini_options]
log_cli = true
log_cli_level = "CRITICAL"
log_cli_format = "%(message)s"

log_file = "pytest.log"
log_file_level = "DEBUG"
log_file_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_file_date_format = "%Y-%m-%d %H:%M:%S"

filterwarnings = [
"ignore:Call to deprecated create function FieldDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
"ignore:Call to deprecated create function Descriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
Expand Down
14 changes: 12 additions & 2 deletions test/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Test adjoint and perform a rough benchmarking of wall-clock time

import time
from copy import deepcopy
import logging

import pytest
import torch
Expand All @@ -27,9 +30,10 @@
batch_size = 128
torch.manual_seed(1415112413244349)


t_span = torch.linspace(0, 1, 100)

logger = logging.getLogger("out")


# TODO(numerics): log wall-clock times and other torch.grad tests
# TODO(bug): `tsit5` + `adjoint` peak error
Expand All @@ -38,17 +42,21 @@
@pytest.mark.parametrize('stiffness', [0.1, 0.5])
@pytest.mark.parametrize('interpolator', [None])
def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):

f = VanDerPol(stiffness)
x = torch.randn(1024, 2, requires_grad=True)
t0 = time.time()

prob = ODEProblem(f, sensitivity=sensitivity, interpolator=interpolator, solver=solver, atol=1e-4, rtol=1e-4, atol_adjoint=1e-4, rtol_adjoint=1e-4)
t0 = time.time()
t_eval, sol_torchdyn = prob.odeint(x, t_span)
t_end1 = time.time() - t0

t0 = time.time()
sol_torchdiffeq = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
t_end2 = time.time() - t0

logger.info(f"Fwd times: {t_end1:.3f}, {t_end2:.3f}")

true_sol = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-9, rtol=1e-9)

t0 = time.time()
Expand All @@ -59,6 +67,8 @@ def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):
grad2 = torch.autograd.grad(sol_torchdiffeq[-1].sum(), x)[0]
t_end2 = time.time() - t0

logger.info(f"Bwd times: {t_end1:.3f}, {t_end2:.3f}")

grad_true = torch.autograd.grad(true_sol[-1].sum(), x)[0]

err1 = (grad1-grad_true).abs().sum(1)
Expand Down
4 changes: 2 additions & 2 deletions torchdyn/numerics/solvers/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def step(self, f, x, t, dt, k1=None, args=None) -> Tuple:
k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5)
k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6)
x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6)
err = berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7
err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7)
return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7)


Expand All @@ -174,7 +174,7 @@ def step(self, f, x, t, dt, k1=None, args=None) -> Tuple:
k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5)
k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6)
x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6)
err = berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7
err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7)
return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7)


Expand Down

0 comments on commit 310c375

Please sign in to comment.