Training a Neural ODE with Optax¶
We can use the parameter estimation functionality to fit a neural ODE to a time series data set.
In [1]:
Copied!
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq.util.doc_util import notebook
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq.util.doc_util import notebook
In [2]:
Copied!
# Apply the notebook's shared matplotlib style and figure-size defaults.
plt.rcParams.update(notebook.plot_style())
plt.rcParams.update(notebook.plot_sizes())
# (duplicate of the cell above — artifact of the docs export)
plt.rcParams.update(notebook.plot_style())
plt.rcParams.update(notebook.plot_sizes())
In [3]:
Copied!
# Select diffeqzoo's JAX backend exactly once per session.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Catch NaN gradients in CI
# Disable to improve speed
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_platform_name", "cpu")
# Duplicate of the preceding configuration cell (docs-export artifact);
# both statements are idempotent, so re-running them is harmless.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Catch NaN gradients in CI
# Disable to improve speed
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_platform_name", "cpu")
To keep the problem nice and small, assume that the data set is a trigonometric function (which solves a differential equation).
In [4]:
Copied!
# Synthetic training data: a sine wave sampled on a uniform grid in [0, 1].
grid = jnp.linspace(0, 1, num=100)
data = jnp.sin(5 * jnp.pi * grid)
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
# (duplicate of the cell above — artifact of the docs export)
grid = jnp.linspace(0, 1, num=100)
data = jnp.sin(5 * jnp.pi * grid)
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
In [5]:
Copied!
def build_loss_fn(vf, initial_values, solver, *, standard_deviation=1e-2):
    """Build a loss function from an ODE problem.

    The returned (jitted) function maps vector-field parameters to the
    negative log-marginal likelihood of the observed time series.
    NOTE(review): closes over the module-level ``t0``, ``grid``, and
    ``data`` — confirm these exist before calling.
    """

    @jax.jit
    def loss_fn(parameters):
        """Loss function: log-marginal likelihood of the data."""
        # Initial Taylor coefficients for the current parameter values.
        taylor = (*initial_values, vf(*initial_values, t=t0, p=parameters))
        prior, ssm = ivpsolvers.prior_ibm(taylor, ssm_fact="isotropic")
        correction = ivpsolvers.correction_ts0(ssm=ssm)
        smoother = ivpsolvers.strategy_smoother(ssm=ssm)
        solver_local = ivpsolvers.solver(
            smoother, prior=prior, correction=correction, ssm=ssm
        )

        # Bind the parameters into the vector field for the solver call.
        def vf_p(*args, **kwargs):
            return vf(*args, **kwargs, p=parameters)

        solution = ivpsolve.solve_fixed_grid(
            vf_p,
            solver_local.initial_condition(),
            grid=grid,
            solver=solver,
            ssm=ssm,
        )

        # Negative log-marginal likelihood under per-point Gaussian noise.
        stds = standard_deviation * jnp.ones_like(grid)
        lml = stats.log_marginal_likelihood(
            data[:, None],
            standard_deviation=stds,
            posterior=solution.posterior,
            ssm=solution.ssm,
        )
        return -lml

    return loss_fn
# Duplicate definition (docs-export artifact); identical to the one above.
def build_loss_fn(vf, initial_values, solver, *, standard_deviation=1e-2):
    """Build a loss function from an ODE problem.

    The returned (jitted) function maps vector-field parameters to the
    negative log-marginal likelihood of the observed time series.
    NOTE(review): closes over the module-level ``t0``, ``grid``, and
    ``data`` — confirm these exist before calling.
    """

    @jax.jit
    def loss_fn(parameters):
        """Loss function: log-marginal likelihood of the data."""
        # Initial Taylor coefficients for the current parameter values.
        taylor = (*initial_values, vf(*initial_values, t=t0, p=parameters))
        prior, ssm = ivpsolvers.prior_ibm(taylor, ssm_fact="isotropic")
        correction = ivpsolvers.correction_ts0(ssm=ssm)
        smoother = ivpsolvers.strategy_smoother(ssm=ssm)
        solver_local = ivpsolvers.solver(
            smoother, prior=prior, correction=correction, ssm=ssm
        )

        # Bind the parameters into the vector field for the solver call.
        def vf_p(*args, **kwargs):
            return vf(*args, **kwargs, p=parameters)

        solution = ivpsolve.solve_fixed_grid(
            vf_p,
            solver_local.initial_condition(),
            grid=grid,
            solver=solver,
            ssm=ssm,
        )

        # Negative log-marginal likelihood under per-point Gaussian noise.
        stds = standard_deviation * jnp.ones_like(grid)
        lml = stats.log_marginal_likelihood(
            data[:, None],
            standard_deviation=stds,
            posterior=solution.posterior,
            ssm=solution.ssm,
        )
        return -lml

    return loss_fn
In [6]:
Copied!
def build_update_fn(*, optimizer, loss_fn):
    """Build a function for executing a single step in the optimization.

    Returns a jitted ``update(params, opt_state) -> (params, opt_state)``.
    """

    @jax.jit
    def update(params, opt_state):
        """Update the optimiser state."""
        # The loss value is computed alongside the gradient but not returned.
        _loss, gradient = jax.value_and_grad(loss_fn)(params)
        step, opt_state = optimizer.update(gradient, opt_state)
        new_params = optax.apply_updates(params, step)
        return new_params, opt_state

    return update
# Duplicate definition (docs-export artifact); identical to the one above.
def build_update_fn(*, optimizer, loss_fn):
    """Build a function for executing a single step in the optimization.

    Returns a jitted ``update(params, opt_state) -> (params, opt_state)``.
    """

    @jax.jit
    def update(params, opt_state):
        """Update the optimiser state."""
        # The loss value is computed alongside the gradient but not returned.
        _loss, gradient = jax.value_and_grad(loss_fn)(params)
        step, opt_state = optimizer.update(gradient, opt_state)
        new_params = optax.apply_updates(params, step)
        return new_params, opt_state

    return update
Construct an MLP with tanh activation¶
Let's start with the example given in the implicit layers tutorial. The vector field is provided by DiffEqZoo.
In [7]:
Copied!
# Sample a small MLP vector field (and its parameters) from diffeqzoo.
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))


@jax.jit
def vf(y, *, t, p):
    """Evaluate the MLP."""
    return f(y, t, *p)


# Assemble a probabilistic solver around a first-order IBM prior.
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
# Duplicate of the preceding cell (docs-export artifact).
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))


@jax.jit
def vf(y, *, t, p):
    """Evaluate the MLP."""
    return f(y, t, *p)


# Assemble a probabilistic solver around a first-order IBM prior.
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
In [8]:
Copied!
# Solve with the untrained parameters and compare against the data.
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm
)

plt.plot(sol.t, sol.u[0], ".-", label="Initial estimate")
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
# Duplicate of the preceding cell (docs-export artifact).
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm
)

plt.plot(sol.t, sol.u[0], ".-", label="Initial estimate")
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
In [9]:
Copied!
# Build the loss, the Adam optimiser, and the jitted update step.
loss_fn = build_loss_fn(vf=vf, initial_values=(u0,), solver=solver_ts0)
optim = optax.adam(learning_rate=2e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn)
# (duplicate of the cell above — artifact of the docs export)
loss_fn = build_loss_fn(vf=vf, initial_values=(u0,), solver=solver_ts0)
optim = optax.adam(learning_rate=2e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn)
In [10]:
Copied!
# Optimise the parameters, reporting the loss every chunk_size**2 steps
# (chunk_size**3 steps in total).
p = f_args
state = optim.init(p)

chunk_size = 25
for i in range(chunk_size):
    for _ in range(chunk_size**2):
        p, state = update_fn(p, state)

    print(
        "Negative log-marginal-likelihood after "
        f"{(i + 1) * chunk_size**2}/{chunk_size**3} steps:",
        loss_fn(p),
    )
# Duplicate of the preceding training cell (docs-export artifact).
p = f_args
state = optim.init(p)

chunk_size = 25
for i in range(chunk_size):
    for _ in range(chunk_size**2):
        p, state = update_fn(p, state)

    print(
        "Negative log-marginal-likelihood after "
        f"{(i + 1) * chunk_size**2}/{chunk_size**3} steps:",
        loss_fn(p),
    )
Negative log-marginal-likelihood after 625/15625 steps: 2183.3174
Negative log-marginal-likelihood after 1250/15625 steps: 2139.2583
Negative log-marginal-likelihood after 1875/15625 steps: 2119.1238
Negative log-marginal-likelihood after 2500/15625 steps: 2090.3723
Negative log-marginal-likelihood after 3125/15625 steps: 2055.1594
Negative log-marginal-likelihood after 3750/15625 steps: 2055.1846
Negative log-marginal-likelihood after 4375/15625 steps: 2044.1001
Negative log-marginal-likelihood after 5000/15625 steps: 1969.2017
Negative log-marginal-likelihood after 5625/15625 steps: 1911.1566
Negative log-marginal-likelihood after 6250/15625 steps: 1787.1368
Negative log-marginal-likelihood after 6875/15625 steps: 1582.8367
Negative log-marginal-likelihood after 7500/15625 steps: 1284.9406
Negative log-marginal-likelihood after 8125/15625 steps: 1040.6012
Negative log-marginal-likelihood after 8750/15625 steps: 935.1293
Negative log-marginal-likelihood after 9375/15625 steps: 784.8787
Negative log-marginal-likelihood after 10000/15625 steps: 614.761
Negative log-marginal-likelihood after 10625/15625 steps: 358.15936
Negative log-marginal-likelihood after 11250/15625 steps: 144.10838
Negative log-marginal-likelihood after 11875/15625 steps: 89.537964
Negative log-marginal-likelihood after 12500/15625 steps: 59.200634
Negative log-marginal-likelihood after 13125/15625 steps: 46.43938
Negative log-marginal-likelihood after 13750/15625 steps: 40.257816
Negative log-marginal-likelihood after 14375/15625 steps: 35.257412
Negative log-marginal-likelihood after 15000/15625 steps: 26.825794
Negative log-marginal-likelihood after 15625/15625 steps: 11.478086
In [11]:
Copied!
# Plot the data, the solution under the trained parameters ("Final guess"),
# and the solution under the untrained parameters ("Initial guess").
plt.plot(sol.t, data, "-", linewidth=5, alpha=0.5, label="Data")

# Trained parameters: solve with p (note the Taylor coefficients are still
# built from f_args, matching the original notebook).
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=p), init, grid=grid, solver=solver_ts0, ssm=ssm
)
plt.plot(sol.t, sol.u[0], ".-", label="Final guess")

# Untrained parameters: same construction, solved with f_args.
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm
)
plt.plot(sol.t, sol.u[0], ".-", label="Initial guess")

plt.legend()
plt.show()
# Duplicate of the preceding plotting cell (docs-export artifact).
plt.plot(sol.t, data, "-", linewidth=5, alpha=0.5, label="Data")

tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=p), init, grid=grid, solver=solver_ts0, ssm=ssm
)
plt.plot(sol.t, sol.u[0], ".-", label="Final guess")

tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm
)
plt.plot(sol.t, sol.u[0], ".-", label="Initial guess")

plt.legend()
plt.show()