Training a Neural ODE with Optax
We can use ProbDiffEq's parameter-estimation functionality to fit a neural ODE to a time-series data set: compute the log-marginal likelihood of the data under the probabilistic solution and maximise it with Optax's gradient-based optimisers.
In [1]:
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq.util.doc_util import notebook
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq.util.doc_util import notebook
In [2]:
plt.rcParams.update(notebook.plot_style())
plt.rcParams.update(notebook.plot_sizes())
In [3]:
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Catch NaN gradients in CI
# Disable to improve speed
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_platform_name", "cpu")
To keep the problem nice and small, assume that the data set is a trigonometric function (which solves a differential equation).
In [4]:
grid = jnp.linspace(0, 1, num=100)
data = jnp.sin(5 * jnp.pi * grid)
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
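As a quick sanity check of the claim above: the data u(t) = sin(5πt) satisfies the second-order ODE u''(t) = -(5π)² u(t). A minimal sketch (not part of the training recipe) that verifies this with JAX autodiff:

# Sanity check: u(t) = sin(5*pi*t) satisfies u''(t) = -(5*pi)**2 * u(t).
def u(t):
    return jnp.sin(5 * jnp.pi * t)


u_ddot = jax.grad(jax.grad(u))  # second derivative via autodiff

t_check = 0.3
print(u_ddot(t_check), -((5 * jnp.pi) ** 2) * u(t_check))  # the two values agree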
In [5]:
def build_loss_fn(vf, initial_values, solver, *, standard_deviation=1e-2):
    """Build a loss function from an ODE problem."""

    @jax.jit
    def loss_fn(parameters):
        """Loss function: log-marginal likelihood of the data."""
        tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters))
        ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
        ts0 = ivpsolvers.correction_ts0(ssm=ssm)
        strategy = ivpsolvers.strategy_smoother(ssm=ssm)
        solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
        init = solver_ts0.initial_condition()

        sol = ivpsolve.solve_fixed_grid(
            lambda *a, **kw: vf(*a, **kw, p=parameters),
            init,
            grid=grid,
            solver=solver,
            ssm=ssm,
        )

        observation_std = jnp.ones_like(grid) * standard_deviation
        marginal_likelihood = stats.log_marginal_likelihood(
            data[:, None],
            standard_deviation=observation_std,
            posterior=sol.posterior,
            ssm=sol.ssm,
        )
        return -1 * marginal_likelihood

    return loss_fn
In [6]:
def build_update_fn(*, optimizer, loss_fn):
    """Build a function for executing a single step in the optimization."""

    @jax.jit
    def update(params, opt_state):
        """Update the optimiser state."""
        _loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    return update
Construct an MLP with tanh activation
Let's start with the example given in the implicit layers tutorial. The vector field is provided by DiffEqZoo.
In [7]:
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))


@jax.jit
def vf(y, *, t, p):
    """Evaluate the MLP."""
    return f(y, t, *p)


# Make a solver
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
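Here, f_args holds the MLP parameters that we will train. If you want to see their structure, a quick, optional inspection would be:

# Optional: inspect the shapes of the MLP parameters returned by DiffEqZoo.
print(jax.tree_util.tree_map(jnp.shape, f_args))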
In [8]:
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm
)

plt.plot(sol.t, sol.u[0], ".-", label="Initial estimate")
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
In [9]:
loss_fn = build_loss_fn(vf=vf, initial_values=(u0,), solver=solver_ts0)
optim = optax.adam(learning_rate=2e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn)
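Before launching the full optimisation below, a quick (hypothetical) smoke test is to evaluate the loss once at the initial parameters:

# Hypothetical smoke test: negative log-marginal likelihood at the initial parameters.
print(loss_fn(f_args))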
In [10]:
p = f_args
state = optim.init(p)

chunk_size = 25
for i in range(chunk_size):
    for _ in range(chunk_size**2):
        p, state = update_fn(p, state)

    print(
        "Negative log-marginal-likelihood after "
        f"{(i + 1) * chunk_size**2}/{chunk_size**3} steps:",
        loss_fn(p),
    )
Negative log-marginal-likelihood after 625/15625 steps: 2183.317
Negative log-marginal-likelihood after 1250/15625 steps: 2141.8765
Negative log-marginal-likelihood after 1875/15625 steps: 2116.9614
Negative log-marginal-likelihood after 2500/15625 steps: 2092.9226
Negative log-marginal-likelihood after 3125/15625 steps: 2060.6357
Negative log-marginal-likelihood after 3750/15625 steps: 2066.017
Negative log-marginal-likelihood after 4375/15625 steps: 2010.7517
Negative log-marginal-likelihood after 5000/15625 steps: 1971.7511
Negative log-marginal-likelihood after 5625/15625 steps: 1923.9945
Negative log-marginal-likelihood after 6250/15625 steps: 1833.253
Negative log-marginal-likelihood after 6875/15625 steps: 1670.5896
Negative log-marginal-likelihood after 7500/15625 steps: 1468.7997
Negative log-marginal-likelihood after 8125/15625 steps: 1128.946
Negative log-marginal-likelihood after 8750/15625 steps: 954.581
Negative log-marginal-likelihood after 9375/15625 steps: 1986.423
Negative log-marginal-likelihood after 10000/15625 steps: 1681.4988
Negative log-marginal-likelihood after 10625/15625 steps: 891.4211
Negative log-marginal-likelihood after 11250/15625 steps: 876.61035
Negative log-marginal-likelihood after 11875/15625 steps: 845.3154
Negative log-marginal-likelihood after 12500/15625 steps: 1739.3035
Negative log-marginal-likelihood after 13125/15625 steps: 1496.7318
Negative log-marginal-likelihood after 13750/15625 steps: 834.9556
Negative log-marginal-likelihood after 14375/15625 steps: 8293.418
Negative log-marginal-likelihood after 15000/15625 steps: 1465.749
Negative log-marginal-likelihood after 15625/15625 steps: 911.4066
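The printed loss oscillates during the later stages of training. A common remedy, not used in this notebook, is to decay the learning rate over time; a hypothetical variant built with an Optax schedule would look like this:

# Hypothetical alternative: Adam with an exponentially decaying learning rate.
schedule = optax.exponential_decay(init_value=2e-2, transition_steps=1_000, decay_rate=0.5)
optim_decayed = optax.adam(learning_rate=schedule)
update_fn_decayed = build_update_fn(optimizer=optim_decayed, loss_fn=loss_fn)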
In [11]:
plt.plot(sol.t, data, "-", linewidth=5, alpha=0.5, label="Data")

tcoeffs = (u0, vf(u0, t=t0, p=p))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=p), init, grid=grid, solver=solver_ts0, ssm=ssm
)
plt.plot(sol.t, sol.u[0], ".-", label="Final guess")

tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm
)
plt.plot(sol.t, sol.u[0], ".-", label="Initial guess")

plt.legend()
plt.show()