Train a Neural ODE with Optax
We can use the parameter estimation functionality to fit a neural ODE to a time series data set.
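Concretely, the loss we minimise below is the negative log-marginal likelihood of the observations under the probabilistic solver's posterior. Assuming independent Gaussian observation noise with standard deviation $\sigma$ (below, $\sigma = 10^{-2}$), the objective is

$$
\mathcal{L}(\theta) = -\log p\bigl(y_{1:N} \mid \theta\bigr),
\qquad
y_n \mid u_\theta(t_n) \sim \mathcal{N}\bigl(u_\theta(t_n), \sigma^2\bigr),
$$

where $u_\theta$ denotes the ODE solution driven by the neural vector field with parameters $\theta$.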
In [1]:
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
"""Train a neural ODE with ProbDiffEq and Optax."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
In [2]:
plt.rcParams.update(notebook.plot_style())
plt.rcParams.update(notebook.plot_sizes())
In [3]:
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Catch NaN gradients in CI
# Disable to improve speed
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_platform_name", "cpu")
In [4]:
impl.select("isotropic", ode_shape=(1,))
To keep the problem nice and small, assume that the data set is a trigonometric function (which solves a simple differential equation).
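For instance, the curve below, $u(t) = \sin(5\pi t)$, solves the linear second-order problem

$$
\ddot u(t) = -(5\pi)^2\, u(t), \qquad u(0) = 0, \quad \dot u(0) = 5\pi,
$$

so a sufficiently flexible vector field should be able to reproduce it.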
In [5]:
grid = jnp.linspace(0, 1, num=100)
data = jnp.sin(5 * jnp.pi * grid)
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
In [6]:
def build_loss_fn(vf, initial_values, solver, *, standard_deviation=1e-2):
    """Build a loss function from an ODE problem."""

    @jax.jit
    def loss_fn(parameters):
        """Loss function: negative log-marginal likelihood of the data."""
        tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters))
        init = solver.initial_condition(tcoeffs, output_scale=1.0)
        sol = ivpsolve.solve_fixed_grid(
            lambda *a, **kw: vf(*a, **kw, p=parameters), init, grid=grid, solver=solver
        )

        observation_std = jnp.ones_like(grid) * standard_deviation
        marginal_likelihood = stats.log_marginal_likelihood(
            data[:, None], standard_deviation=observation_std, posterior=sol.posterior
        )
        return -1 * marginal_likelihood

    return loss_fn
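Since the fixed-grid solve is implemented in JAX, this loss is differentiable end-to-end, which is what the update function in the next cell relies on. As a quick sanity check (a sketch only, assuming the loss_fn and the parameters f_args constructed further below), one could evaluate the loss and its gradient once:

# Hypothetical sanity check: one loss/gradient evaluation
value, grads = jax.value_and_grad(loss_fn)(f_args)
print(value)  # scalar: negative log-marginal likelihood
print(jax.tree_util.tree_map(jnp.shape, grads))  # gradient pytree mirrors the parameter pytree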
In [7]:
def build_update_fn(*, optimizer, loss_fn):
    """Build a function for executing a single step in the optimization."""

    @jax.jit
    def update(params, opt_state):
        """Update the optimiser state."""
        _loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    return update
Construct an MLP with tanh activation
Let's start with the example given in the implicit layers tutorial. The vector field is provided by DiffEqZoo.
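Concretely, layer_sizes=(2, 20, 1) means a small two-layer network that maps the current state and the time to a scalar derivative. Schematically (the exact parametrisation is up to DiffEqZoo, so treat this as a sketch):

$$
f(u, t; \theta) = W_2 \tanh\bigl(W_1 [u, t]^\top + b_1\bigr) + b_2,
\qquad W_1 \in \mathbb{R}^{20 \times 2}, \; W_2 \in \mathbb{R}^{1 \times 20}.
$$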
In [8]:
f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1))


@jax.jit
def vf(y, *, t, p):
    """Evaluate the MLP."""
    return f(y, t, *p)


# Make a solver
ibm = ivpsolvers.prior_ibm(num_derivatives=1)
ts0 = ivpsolvers.correction_ts0()
strategy = ivpsolvers.strategy_smoother(ibm, ts0)
solver_ts0 = ivpsolvers.solver(strategy)
In [9]:
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
init = solver_ts0.initial_condition(tcoeffs, output_scale=1.0)
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0
)

plt.plot(sol.t, sol.u, ".-", label="Initial estimate")
plt.plot(grid, data, ".-", label="Data")
plt.legend()
plt.show()
In [10]:
loss_fn = build_loss_fn(vf=vf, initial_values=(u0,), solver=solver_ts0)
optim = optax.adam(learning_rate=2e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn)
In [11]:
p = f_args
state = optim.init(p)

chunk_size = 25
for i in range(chunk_size):
    for _ in range(chunk_size**2):
        p, state = update_fn(p, state)

    print(
        "Negative log-marginal-likelihood after "
        f"{(i+1)*chunk_size**2}/{chunk_size**3} steps:",
        loss_fn(p),
    )
Negative log-marginal-likelihood after 625/15625 steps: 1796.1017
Negative log-marginal-likelihood after 1250/15625 steps: 1481.6743
Negative log-marginal-likelihood after 1875/15625 steps: 1365.6925
Negative log-marginal-likelihood after 2500/15625 steps: 1284.5906
Negative log-marginal-likelihood after 3125/15625 steps: 1105.0892
Negative log-marginal-likelihood after 3750/15625 steps: 939.03406
Negative log-marginal-likelihood after 4375/15625 steps: 718.1144
Negative log-marginal-likelihood after 5000/15625 steps: 407.6635
Negative log-marginal-likelihood after 5625/15625 steps: 148.90996
Negative log-marginal-likelihood after 6250/15625 steps: 81.4354
Negative log-marginal-likelihood after 6875/15625 steps: 1523.6747
Negative log-marginal-likelihood after 7500/15625 steps: 57.227238
Negative log-marginal-likelihood after 8125/15625 steps: 43.76754
Negative log-marginal-likelihood after 8750/15625 steps: 33.84983
Negative log-marginal-likelihood after 9375/15625 steps: 25.840542
Negative log-marginal-likelihood after 10000/15625 steps: 11.386798
Negative log-marginal-likelihood after 10625/15625 steps: 2.0756648
Negative log-marginal-likelihood after 11250/15625 steps: -1.5871391
Negative log-marginal-likelihood after 11875/15625 steps: -2.573725
Negative log-marginal-likelihood after 12500/15625 steps: -3.028935
Negative log-marginal-likelihood after 13125/15625 steps: -3.2776465
Negative log-marginal-likelihood after 13750/15625 steps: -3.418831
Negative log-marginal-likelihood after 14375/15625 steps: -3.474486
Negative log-marginal-likelihood after 15000/15625 steps: -3.5027845
Negative log-marginal-likelihood after 15625/15625 steps: -3.5293877
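The loss decreases over most of the run, but the spike after 6875 steps shows that the optimisation is not monotone. If this becomes a problem, a common remedy is to clip gradients before the Adam update, for example via optax.chain; a sketch (not used in this tutorial):

optim = optax.chain(
    optax.clip_by_global_norm(1.0),  # cap the global gradient norm at 1.0
    optax.adam(learning_rate=2e-2),
)

The resulting optimizer is a drop-in replacement for optax.adam above.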
In [12]:
plt.plot(sol.t, data, "-", linewidth=5, alpha=0.5, label="Data")

tcoeffs = (u0, vf(u0, t=t0, p=p))
init = solver_ts0.initial_condition(tcoeffs, output_scale=1.0)
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=p), init, grid=grid, solver=solver_ts0
)
plt.plot(sol.t, sol.u, ".-", label="Final guess")

tcoeffs = (u0, vf(u0, t=t0, p=f_args))
init = solver_ts0.initial_condition(tcoeffs, output_scale=1.0)
sol = ivpsolve.solve_fixed_grid(
    lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0
)
plt.plot(sol.t, sol.u, ".-", label="Initial guess")

plt.legend()
plt.show()