
Commit

Merge pull request #166 from DiffEqML/neuralsde
Merging the neuralsde branch into the master branch to add neural SDE support as a new feature.
JuanKo96 committed Nov 26, 2023
2 parents c708c4c + 0bbdd3e commit 1e8979a
Showing 9 changed files with 1,085 additions and 156 deletions.
120 changes: 120 additions & 0 deletions test/test_sdeint.py
@@ -0,0 +1,120 @@
import pytest
from torch import nn
import torch
import torchsde
import numpy as np
from torchdyn.numerics import sdeint
from numpy.testing import assert_almost_equal


@pytest.mark.parametrize("solver", ["euler", "milstein_ito"])
def test_geo_brownian_ito(solver):
    torch.manual_seed(0)
    np.random.seed(0)

    t0, t1 = 0, 1
    size = (1, 1)
    device = "cpu"

    alpha = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    beta = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    x0 = torch.normal(mean=0.0, std=1.1, size=size).to(device)
    t_size = 1000
    ts = torch.linspace(t0, t1, t_size).to(device)

    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=size, device=device, levy_area_approximation="space-time"
    )

    def get_bm_queries(bm, ts):
        bm_increments = torch.stack(
            [bm(t0, t1) for t0, t1 in zip(ts[:-1], ts[1:])], dim=0
        )
        bm_queries = torch.cat(
            (torch.zeros(1, 1, 1).to(device), torch.cumsum(bm_increments, dim=0))
        )
        return bm_queries

    class SDE(nn.Module):
        def __init__(self, alpha, beta):
            super().__init__()
            self.alpha = nn.Parameter(alpha, requires_grad=True)
            self.beta = nn.Parameter(beta, requires_grad=True)
            self.noise_type = "diagonal"
            self.sde_type = "ito"

        def f(self, t, x):
            return self.alpha * x

        def g(self, t, x):
            return self.beta * x

    sde = SDE(alpha, beta).to(device)

    with torch.no_grad():
        _, xs_torchdyn = sdeint(sde, x0, ts, solver=solver, bm=bm)

    bm_queries = get_bm_queries(bm, ts)
    xs_true = x0.cpu() * np.exp(
        (alpha.cpu() - 0.5 * beta.cpu() ** 2) * ts.cpu()
        + beta.cpu() * bm_queries[:, 0, 0].cpu()
    )

    assert_almost_equal(xs_true[0][-1], xs_torchdyn[-1], decimal=2)


@pytest.mark.parametrize("solver", ["eulerHeun", "milstein_stratonovich"])
def test_geo_brownian_stratonovich(solver):
    torch.manual_seed(0)
    np.random.seed(0)

    t0, t1 = 0, 1
    size = (1, 1)
    device = "cpu"

    alpha = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    beta = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    x0 = torch.normal(mean=0.0, std=1.1, size=size).to(device)
    t_size = 1000
    ts = torch.linspace(t0, t1, t_size).to(device)

    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=size, device=device, levy_area_approximation="space-time"
    )

    def get_bm_queries(bm, ts):
        bm_increments = torch.stack(
            [bm(t0, t1) for t0, t1 in zip(ts[:-1], ts[1:])], dim=0
        )
        bm_queries = torch.cat(
            (torch.zeros(1, 1, 1).to(device), torch.cumsum(bm_increments, dim=0))
        )
        return bm_queries

    class SDE(nn.Module):
        def __init__(self, alpha, beta):
            super().__init__()
            self.alpha = nn.Parameter(alpha, requires_grad=True)
            self.beta = nn.Parameter(beta, requires_grad=True)
            self.noise_type = "diagonal"
            self.sde_type = "stratonovich"

        def f(self, t, x):
            return self.alpha * x

        def g(self, t, x):
            return self.beta * x

    sde = SDE(alpha, beta).to(device)

    with torch.no_grad():
        _, xs_torchdyn = sdeint(sde, x0, ts, solver=solver, bm=bm)

    bm_queries = get_bm_queries(bm, ts)
    xs_true = x0.cpu() * np.exp(
        (alpha.cpu() - 0.5 * beta.cpu() ** 2) * ts.cpu()
        + beta.cpu() * bm_queries[:, 0, 0].cpu()
    )

    assert_almost_equal(xs_true[0][-1] - xs_torchdyn[-1], 1, decimal=0)
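
Both tests integrate the geometric Brownian motion dX_t = alpha * X_t dt + beta * X_t dW_t with the new `torchdyn.numerics.sdeint` and compare the terminal value against the closed-form solution X_t = X_0 * exp((alpha - beta^2 / 2) * t + beta * W_t) evaluated on the same Brownian path. The snippet below is a condensed sketch of that check outside the test harness; it is not part of the commit and assumes the solver names and the (time grid, solution) return convention exercised above.

# Sketch only (not part of the diff): geometric Brownian motion through the new sdeint,
# checked against the closed-form solution on the same Brownian path.
import torch
import torchsde
from torch import nn
from torchdyn.numerics import sdeint

class GBM(nn.Module):
    noise_type, sde_type = "diagonal", "ito"

    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def f(self, t, x):  # drift: alpha * x
        return self.alpha * x

    def g(self, t, x):  # diffusion: beta * x
        return self.beta * x

x0 = torch.full((1, 1), 0.5)
ts = torch.linspace(0, 1, 1000)
bm = torchsde.BrownianInterval(
    t0=0, t1=1, size=(1, 1), levy_area_approximation="space-time"
)
_, xs = sdeint(GBM(torch.tensor(0.3), torch.tensor(0.2)), x0, ts, solver="euler", bm=bm)

w1 = bm(0, 1)  # Brownian motion at t = 1, since W(0) = 0
x1_closed_form = x0 * torch.exp((0.3 - 0.5 * 0.2 ** 2) * 1.0 + 0.2 * w1)
print(xs[-1], x1_closed_form)  # should agree to roughly two decimals at this step size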

4 changes: 2 additions & 2 deletions torchdyn/core/__init__.py
@@ -10,12 +10,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from torchdyn.core.defunc import DEFunc
+from torchdyn.core.defunc import DEFunc, SDEFunc
from torchdyn.core.neuralde import NeuralODE, NeuralSDE, MultipleShootingLayer
from torchdyn.core.problems import ODEProblem, SDEProblem, MultipleShootingProblem

# backward-compatibility (pre v0.2.0)
NeuralDE = NeuralODE

-__all__ = ['DEFunc', 'NeuralODE', 'NeuralDE', 'NeuralSDE', 'ODEProblem', 'SDEProblem',
+__all__ = ['DEFunc', 'SDEFunc', 'NeuralODE', 'NeuralDE', 'NeuralSDE', 'ODEProblem', 'SDEProblem',
'MultipleShootingProblem', 'MultipleShootingLayer']
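
With `SDEFunc` added to the public exports, the SDE wrapper can be imported alongside the existing ODE classes. A minimal sketch (not part of the diff):

# Sketch only: the public names exported by torchdyn.core after this change.
from torchdyn.core import DEFunc, SDEFunc, NeuralODE, NeuralSDE, ODEProblem, SDEProblem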
87 changes: 56 additions & 31 deletions torchdyn/core/defunc.py
@@ -9,32 +9,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import getfullargspec
from typing import Callable, Dict
import torch
from torch import Tensor, cat
import torch.nn as nn


class DEFuncBase(nn.Module):
-    def __init__(self, vector_field:Callable, has_time_arg:bool=True):
+    def __init__(self, vector_field: Callable, has_time_arg: bool = True):
        """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields.
        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function
            has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its `__call__'
                or `forward` method. Defaults to True.
        """
        super().__init__()
-        self.nfe, self.vf, self.has_time_arg = 0., vector_field, has_time_arg
+        self.nfe, self.vf, self.has_time_arg = 0.0, vector_field, has_time_arg

-    def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
        self.nfe += 1
-        if self.has_time_arg: return self.vf(t, x, args=args)
-        else: return self.vf(x)
+        if self.has_time_arg:
+            return self.vf(t, x, args=args)
+        else:
+            return self.vf(x)


class DEFunc(nn.Module):
-    def __init__(self, vector_field:Callable, order:int=1):
+    def __init__(self, vector_field: Callable, order: int = 1):
        """Special vector field wrapper for Neural ODEs.
        Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and forward propagated integral losses.
@@ -51,43 +53,50 @@ def __init__(self, vector_field:Callable, order:int=1):
        (3) in case of higher-order dynamics, adjusts the vector field forward to recursively compute various orders.
        """
        super().__init__()
-        self.vf, self.nfe, = vector_field, 0.
+        self.vf, self.nfe, = vector_field, 0.0
        self.order, self.integral_loss, self.sensitivity = order, None, None
        # identify whether vector field already has time arg

-    def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
        self.nfe += 1
        # set `t` depth-variable to DepthCat modules
        for _, module in self.vf.named_modules():
-            if hasattr(module, 't'):
+            if hasattr(module, "t"):
                module.t = t

        # if-else to handle autograd training with integral loss propagated in x[:, 0]
-        if (self.integral_loss is not None) and self.sensitivity == 'autograd':
+        if (self.integral_loss is not None) and self.sensitivity == "autograd":
            x_dyn = x[:, 1:]
            dlds = self.integral_loss(t, x_dyn)
-            if len(dlds.shape) == 1: dlds = dlds[:, None]
-            if self.order > 1: x_dyn = self.horder_forward(t, x_dyn, args)
-            else: x_dyn = self.vf(t, x_dyn)
+            if len(dlds.shape) == 1:
+                dlds = dlds[:, None]
+            if self.order > 1:
+                x_dyn = self.horder_forward(t, x_dyn, args)
+            else:
+                x_dyn = self.vf(t, x_dyn)
            return cat([dlds, x_dyn], 1).to(x_dyn)

        # regular forward
        else:
-            if self.order > 1: x = self.higher_order_forward(t, x)
-            else: x = self.vf(t, x, args=args)
+            if self.order > 1:
+                x = self.higher_order_forward(t, x)
+            else:
+                x = self.vf(t, x, args=args)
            return x

-    def higher_order_forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def higher_order_forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
        x_new = []
        size_order = x.size(1) // self.order
        for i in range(1, self.order):
-            x_new.append(x[:, size_order*i : size_order*(i+1)])
+            x_new.append(x[:, size_order * i : size_order * (i + 1)])
        x_new.append(self.vf(t, x))
        return cat(x_new, dim=1).to(x)
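
For context on `higher_order_forward` above: the state is split into `order` equal chunks, the first chunk is dropped, the remaining chunks shift down one derivative order, and the wrapped vector field supplies the highest-order derivative. A small sketch (the damped-oscillator field below is hypothetical, not part of this diff):

# Sketch only: second-order dynamics through DEFunc.higher_order_forward.
# The wrapped field receives the full [q, v] state and returns the highest-order
# derivative (here, the acceleration of a damped oscillator).
import torch
from torch import nn
from torchdyn.core import DEFunc

class Oscillator(nn.Module):
    def forward(self, t, x):  # x = [q, v], each of width 1
        q, v = x[:, :1], x[:, 1:]
        return -q - 0.1 * v  # dv/dt

defunc = DEFunc(Oscillator(), order=2)
x = torch.randn(8, 2)  # batch of [q, v] states
dx = defunc(torch.tensor(0.0), x)  # -> [dq/dt, dv/dt] = [v, -q - 0.1 * v]
print(dx.shape, defunc.nfe)  # torch.Size([8, 2]), one vector-field evaluation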


class SDEFunc(nn.Module):
-    def __init__(self, f:Callable, g:Callable, order:int=1):
+    def __init__(
+        self, f: Callable, g: Callable, order: int = 1, noise_type=None, sde_type=None
+    ):
        """"Special vector field wrapper for Neural SDEs.

@@ -99,19 +108,35 @@ def __init__(self, f:Callable, g:Callable, order:int=1):
        self.order, self.intloss, self.sensitivity = order, None, None
        self.f_func, self.g_func = f, g
        self.nfe = 0
+        self.noise_type = noise_type
+        self.sde_type = sde_type

-    def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
-        pass
+    def forward(self, t: Tensor, x: Tensor) -> Tensor:
+        raise NotImplementedError("Hopefully soon...")

-    def f(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def f(self, t: Tensor, x: Tensor) -> Tensor:
        self.nfe += 1
-        for _, module in self.f_func.named_modules():
-            if hasattr(module, 't'):
-                module.t = t
-        return self.f_func(x, args)
+        if issubclass(type(self.f_func), nn.Module):
+            if "t" not in getfullargspec(self.f_func.forward).args:
+                return self.f_func(x)
+            else:
+                return self.f_func(t, x)
+        else:
+            if "t" not in getfullargspec(self.f_func).args:
+                return self.f_func(x)
+            else:
+                return self.f_func(t, x)

-    def g(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
-        for _, module in self.g_func.named_modules():
-            if hasattr(module, 't'):
-                module.t = t
-        return self.g_func(x, args)
+    def g(self, t: Tensor, x: Tensor) -> Tensor:
+        self.nfe += 1
+        if issubclass(type(self.g_func), nn.Module):
+
+            if "t" not in getfullargspec(self.g_func.forward).args:
+                return self.g_func(x)
+            else:
+                return self.g_func(t, x)
+        else:
+            if "t" not in getfullargspec(self.g_func).args:
+                return self.g_func(x)
+            else:
+                return self.g_func(t, x)
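
The rewritten `f` and `g` inspect the wrapped callables with `getfullargspec` and pass `t` only when the drift or diffusion actually accepts it, so both autonomous and time-dependent fields work. A small sketch of that dispatch (the drift/diffusion definitions below are illustrative, not part of the commit):

# Sketch only: wrapping autonomous and time-dependent callables with the new SDEFunc.
import torch
from torch import nn
from torchdyn.core import SDEFunc

class Drift(nn.Module):  # no `t` in forward -> SDEFunc.f calls f_func(x)
    def forward(self, x):
        return -x

def diffusion(t, x):  # takes `t` -> SDEFunc.g calls g_func(t, x)
    return 0.1 * (1 + t) * x

sde_func = SDEFunc(f=Drift(), g=diffusion, noise_type="diagonal", sde_type="ito")
x = torch.randn(4, 2)
print(sde_func.f(torch.tensor(0.0), x).shape)  # drift evaluated without `t`
print(sde_func.g(torch.tensor(0.5), x).shape)  # diffusion receives `t`
print(sde_func.nfe)  # both f and g bump the evaluation counter -> 2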
