Train a Neural ODE with Optax
We can use the parameter-estimation functionality (the log-marginal likelihood of a probabilistic ODE solution) to fit a neural ODE to a time-series data set.
In [1]:
Copied!
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.config
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solution, uncalibrated
from probdiffeq.solvers.strategies import smoothers
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.util.doc_util import notebook
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.config
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solution, uncalibrated
from probdiffeq.solvers.strategies import smoothers
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.util.doc_util import notebook
In [2]:
Copied!
# Apply the shared ProbDiffEq documentation plot style and figure sizes.
plt.rcParams.update(notebook.plot_style())
plt.rcParams.update(notebook.plot_sizes())
In [3]:
Copied!
# The DiffEqZoo IVP examples must use the JAX backend.
# Select it only if no backend has been chosen yet in this kernel.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Catch NaN gradients in CI.
# Disable to improve speed.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_platform_name", "cpu")
In [4]:
Copied!
# Select ProbDiffEq's "isotropic" implementation for an ODE whose state has
# shape (1,). This must happen before the solver below is constructed.
impl.select("isotropic", ode_shape=(1,))
To keep the problem nice and small, assume that the data set is a trigonometric function (trigonometric functions are solutions of differential equations).
In [5]:
Copied!
# Synthetic data set: sin(5 * pi * t) sampled at 100 equispaced points in [0, 1].
grid = jnp.linspace(0, 1, num=100)
data = jnp.sin(5 * jnp.pi * grid)

plt.plot(grid, data, ".-", label="Data")
plt.xlabel("t")
plt.ylabel("u(t)")
plt.legend()
plt.show()
In [6]:
Copied!
def build_loss_fn(
    vf,
    initial_values,
    solver,
    *,
    standard_deviation=1e-2,
    grid=None,
    data=None,
    t0=None,
):
    """Build a loss function from an ODE problem.

    Args:
        vf: Vector field, called as ``vf(*initial_values, t=t0, p=parameters)``
            and (inside the solver) as ``vf(*args, **kwargs, p=parameters)``.
        initial_values: Tuple of initial values. Together with one evaluation of
            ``vf`` they form the Taylor coefficients of the initial condition.
        solver: ProbDiffEq solver used for the fixed-grid solve.
        standard_deviation: Observation-noise standard deviation, applied
            uniformly to every data point.
        grid: Time grid that is solved on and at which ``data`` is observed.
            Defaults to the notebook-level ``grid``.
        data: Observations on ``grid``. Assumed 1-D (one value per grid
            point); TODO confirm for multi-dimensional ODEs.
            Defaults to the notebook-level ``data``.
        t0: Time at which ``vf`` is evaluated for the initial Taylor
            coefficient. Defaults to the notebook-level ``t0``.

    Returns:
        A jitted ``loss_fn(parameters)`` that returns the negative
        log-marginal likelihood of ``data``.
    """
    # Fall back to the notebook globals so existing calls keep working, but
    # resolve them here so the dependency is explicit.
    namespace = globals()
    grid = namespace["grid"] if grid is None else grid
    data = namespace["data"] if data is None else data
    t0 = namespace["t0"] if t0 is None else t0

    @jax.jit
    def loss_fn(parameters):
        """Loss function: negative log-marginal likelihood of the data."""
        tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters))
        init = solver.initial_condition(tcoeffs, output_scale=1.0)
        sol = ivpsolve.solve_fixed_grid(
            lambda *a, **kw: vf(*a, **kw, p=parameters),
            init,
            grid=grid,
            solver=solver,
        )

        observation_std = jnp.ones_like(grid) * standard_deviation
        marginal_likelihood = solution.log_marginal_likelihood(
            data[:, None], standard_deviation=observation_std, posterior=sol.posterior
        )
        # Optimizers minimize, so return the negative log-likelihood.
        return -marginal_likelihood

    return loss_fn
In [7]:
Copied!
def build_update_fn(*, optimizer, loss_fn):
    """Build a function for executing a single step in the optimization.

    Args:
        optimizer: An Optax gradient transformation, e.g. ``optax.adam(...)``.
        loss_fn: Scalar loss ``loss_fn(params)`` to be minimized.

    Returns:
        A jitted ``update(params, opt_state) -> (params, opt_state)``.
    """

    @jax.jit
    def update(params, opt_state):
        """Take one gradient step; return the new parameters and optimiser state."""
        grads = jax.grad(loss_fn)(params)
        # Passing ``params`` lets transformations that need them (e.g. the
        # weight decay in ``optax.adamw``) work; ``optax.adam`` ignores them.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    return update
Construct an MLP with tanh activation
Let's start with the example given in the implicit layers tutorial. The vector field (a small MLP) is provided by DiffEqZoo.
In [8]:
Copied!
# Neural-ODE problem from DiffEqZoo: MLP vector field, initial value,
# time span, and initial MLP parameters. Input width 2 presumably means
# (state, time); the output width 1 matches ode_shape=(1,).
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))


@jax.jit
def vf(y, *, t, p):
    """Evaluate the MLP vector field at state ``y``, time ``t``, parameters ``p``."""
    return f(y, t, *p)


# Make a solver: IBM prior with one derivative, TS0 correction,
# smoother strategy, no output-scale calibration.
ibm = priors.ibm_adaptive(num_derivatives=1)
ts0 = corrections.ts0()
strategy = smoothers.smoother_adaptive(ibm, ts0)
solver_ts0 = uncalibrated.solver(strategy)
In [9]:
Copied!
# Solve with the untrained parameters to show where the optimization starts.
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
init = solver_ts0.initial_condition(tcoeffs, output_scale=1.0)
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0
)

plt.plot(sol.t, sol.u, ".-", label="Initial estimate")
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
In [10]:
Copied!
# Objective: negative log-marginal likelihood of the data under the TS0 solver.
loss_fn = build_loss_fn(vf=vf, initial_values=(u0,), solver=solver_ts0)

learning_rate = 2e-2
optim = optax.adam(learning_rate=learning_rate)
update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn)
In [11]:
Copied!
# Train for num_chunks * steps_per_chunk = 25 * 625 = 15625 steps,
# reporting the loss after every chunk.
num_chunks = 25
steps_per_chunk = num_chunks**2
num_steps = num_chunks * steps_per_chunk

p = f_args
state = optim.init(p)
for i in range(num_chunks):
    for _ in range(steps_per_chunk):
        p, state = update_fn(p, state)
    print(
        "Negative log-marginal-likelihood after "
        f"{(i + 1) * steps_per_chunk}/{num_steps} steps:",
        loss_fn(p),
    )
Negative log-marginal-likelihood after 625/15625 steps: 1778.9758 Negative log-marginal-likelihood after 1250/15625 steps: 1484.3114 Negative log-marginal-likelihood after 1875/15625 steps: 1365.3481 Negative log-marginal-likelihood after 2500/15625 steps: 1258.8456 Negative log-marginal-likelihood after 3125/15625 steps: 1076.4878 Negative log-marginal-likelihood after 3750/15625 steps: 912.1649 Negative log-marginal-likelihood after 4375/15625 steps: 670.2355 Negative log-marginal-likelihood after 5000/15625 steps: 353.85977 Negative log-marginal-likelihood after 5625/15625 steps: 136.30069 Negative log-marginal-likelihood after 6250/15625 steps: 80.586105 Negative log-marginal-likelihood after 6875/15625 steps: 58.67221 Negative log-marginal-likelihood after 7500/15625 steps: 65.909454 Negative log-marginal-likelihood after 8125/15625 steps: 36.336376 Negative log-marginal-likelihood after 8750/15625 steps: 25.50941 Negative log-marginal-likelihood after 9375/15625 steps: 12.484666 Negative log-marginal-likelihood after 10000/15625 steps: 2.5704055 Negative log-marginal-likelihood after 10625/15625 steps: -0.93296164 Negative log-marginal-likelihood after 11250/15625 steps: -2.107785 Negative log-marginal-likelihood after 11875/15625 steps: -2.8029163 Negative log-marginal-likelihood after 12500/15625 steps: -3.1140344 Negative log-marginal-likelihood after 13125/15625 steps: -3.3220475 Negative log-marginal-likelihood after 13750/15625 steps: -3.4408896 Negative log-marginal-likelihood after 14375/15625 steps: -3.4877534 Negative log-marginal-likelihood after 15000/15625 steps: -3.516136 Negative log-marginal-likelihood after 15625/15625 steps: -3.5458817
In [12]:
Copied!
def solve_with_params(params):
    """Solve the neural ODE on ``grid`` with MLP parameters ``params``."""
    tcoeffs_ = (u0, vf(u0, t=t0, p=params))
    init_ = solver_ts0.initial_condition(tcoeffs_, output_scale=1.0)
    return ivpsolve.solve_fixed_grid(
        lambda *a, **kw: vf(*a, **kw, p=params), init_, grid=grid, solver=solver_ts0
    )


# Plot the data against its own grid rather than a leftover solution's time points.
plt.plot(grid, data, "-", linewidth=5, alpha=0.5, label="Data")

sol_final = solve_with_params(p)
plt.plot(sol_final.t, sol_final.u, ".-", label="Final guess")

sol_initial = solve_with_params(f_args)
plt.plot(sol_initial.t, sol_initial.u, ".-", label="Initial guess")

plt.legend()
plt.show()