E0. Learn neural ODEs with diffusion tempering
A neural ODE uses a neural network as the vector field of a differential equation. Training minimises the negative log-marginal-likelihood of observed data under the ODE posterior, computed by a probabilistic ODE solver.
Diffusion tempering starts training with a large output scale and gradually reduces it. The large initial scale widens the posterior early on, which helps escape local optima; as the scale shrinks, the posterior sharpens.
In [1]:
Copied!
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
In [2]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq import ivpsolve, probdiffeq
In [3]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# Enabling "jax_debug_nans" is expected to make JAX raise an error when a
# computation produces a NaN, instead of silently propagating it.
jax.config.update("jax_debug_nans", True)
# NOTE(review): this statement appears twice in the rendered notebook text.
# Setting the same flag to the same value again has no additional effect.
jax.config.update("jax_debug_nans", True)
In [4]:
Copied!
def main(
    num_data=20,
    epochs=1000,
    print_every=100,
    hidden=(20,),
    lr=0.2,
    temper_every=100,
    temper_factor=10.0,
) -> None:
    """Train a neural ODE using diffusion tempering.

    Generates toy data on [0, 1], fits a neural ODE by minimising the
    negative log-marginal-likelihood with Adam, and plots the data together
    with the estimates before and after training.

    Args:
        num_data: Number of grid points (and observations) in [0, 1].
        epochs: Number of optimiser steps. With ``0`` no training happens
            and the "final" estimate equals the initial one.
        print_every: Print the loss every this many epochs. A falsy value
            (``0`` or ``None``) disables printing.
        hidden: Widths of the hidden layers of the MLP vector field.
        lr: Learning rate passed to ``optax.adam``.
        temper_every: Divide the output scale by ``temper_factor`` every
            this many epochs. A falsy value disables tempering.
        temper_factor: Factor by which the output scale is divided at each
            tempering step.
    """
    # Create some data and construct a neural ODE
    grid = jnp.linspace(0, 1, num=num_data)
    data = jnp.sin(2.5 * jnp.pi * grid) * jnp.pi * grid
    std = 1e-1
    output_scale = 1e1
    vf, u0, (t0, _t1), params = vf_neural_ode(hidden=hidden, t0=0.0, t1=1)

    # Create a loss (this is where probabilistic numerics enters!)
    loss = loss_log_marginal_likelihood(vf=vf, t0=t0)

    # Evaluate once to get the mean before optimisation
    loss0, info0 = loss(
        params, u0=u0, grid=grid, data=data, std=std, output_scale=output_scale
    )
    # Seed `info` so the plotting code below also works when `epochs == 0`
    # (the training loop would otherwise never bind it).
    info = {"sol": info0["sol"], "loss": loss0}

    # Construct an optimiser
    optim = optax.adam(lr)
    train_step = train_step_optax(optim, loss=loss)

    # Train the model
    print()
    print("Loss after...")
    state = optim.init(params)
    for i in range(epochs):
        (params, state), info = train_step(
            params,
            state,
            u0=u0,
            grid=grid,
            data=data,
            std=std,
            output_scale=output_scale,
        )

        # Print progressbar
        if print_every and i % print_every == print_every - 1:
            print(f"...{(i + 1)} epochs: loss={info['loss']:.7e}")

        # Diffusion tempering (https://arxiv.org/abs/2402.12231):
        # scale down the output scale periodically to sharpen the posterior.
        if temper_every and i % temper_every == temper_every - 1:
            output_scale /= temper_factor

    # Plot the results
    plt.title(f"Final estimate | Loss: {info['loss']:.2f}")
    plt.plot(grid, data, "x", label="Data", color="C0")
    plt.plot(grid, info0["sol"].u.mean[0], "-", label="Initial estimate", color="C1")
    plt.plot(grid, info["sol"].u.mean[0], "-", label="Final estimate", color="C2")
    plt.legend()
    plt.show()
In [5]:
Copied!
def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
    """Assemble a scalar neural-ODE problem.

    The vector field is an MLP that receives the stacked pair
    ``(state, time)`` and produces a single number.

    Args:
        hidden: Widths of the hidden layers of the MLP.
        t0: Start of the time span (returned unchanged).
        t1: End of the time span (returned unchanged).

    Returns:
        A tuple ``(vf, (u0,), (t0, t1), params)``: the jitted vector field,
        a one-element tuple holding the scalar initial value ``0.0``, the
        time span, and the initial MLP parameters.
    """
    # Two inputs (state and time), one output (the time derivative).
    params, network = model_mlp(hidden=hidden, shape_in=(2,), shape_out=(1,))

    @jax.jit
    def vf(y, /, *, t, p):
        """Evaluate the network at state ``y``, time ``t``, parameters ``p``."""
        # Promote both scalars to length-1 vectors and stack them.
        inputs = jnp.concatenate([y[None], t[None]])
        derivative = network(p, inputs)
        # Collapse the length-1 network output back to a scalar.
        return derivative.reshape(())

    initial_value = jnp.asarray(0.0)
    return vf, (initial_value,), (t0, t1), params
In [6]:
Copied!
def model_mlp(
    *,
    hidden: tuple,
    shape_in: tuple = (),
    shape_out: tuple = (),
    activation=jnp.tanh,
    key=None,
):
    """Construct an MLP.

    Args:
        hidden: Widths of the hidden layers; may be empty, in which case the
            network is a single affine map.
        shape_in: Input shape, either ``()`` (scalar) or ``(n,)``.
        shape_out: Output shape, either ``()`` (scalar) or ``(m,)``.
        activation: Elementwise nonlinearity applied after every hidden layer
            (not after the output layer).
        key: Optional JAX PRNG key used to draw the initial parameters.
            Defaults to ``jax.random.PRNGKey(1)``.

    Returns:
        A tuple ``(params, fwd)``. ``params`` is a list of ``(W, b)`` pairs,
        one per layer, drawn from a standard normal distribution.
        ``fwd(params, x)`` evaluates the network at ``x``.

    Raises:
        AssertionError: If ``shape_in`` or ``shape_out`` has more than one
            dimension.
    """
    # Import the submodule explicitly instead of relying on `jax.flatten_util`
    # being reachable as an attribute of `jax`, which is only the case if
    # some other module has already imported it.
    from jax.flatten_util import ravel_pytree

    assert len(shape_in) <= 1, "shape_in must be () or (n,)"
    assert len(shape_out) <= 1, "shape_out must be () or (m,)"

    # Build a template pytree whose only purpose is to fix the layer shapes;
    # its (uninitialised) values are never used.
    shape_prev = tuple(shape_in)
    weights = []
    for h in hidden:
        weights.append((jnp.empty((h, *shape_prev)), jnp.empty((h,))))
        shape_prev = (h,)
    weights.append((jnp.empty((*shape_out, *shape_prev)), jnp.empty(shape_out)))

    # Flatten once to learn the total parameter count and how to restore
    # the per-layer structure from a flat vector.
    p_flat, unravel = ravel_pytree(weights)

    def fwd(params, x):
        """Apply all hidden layers with activation, then the affine output."""
        for A, b in params[:-1]:
            x = activation(jnp.dot(A, x) + b)
        A, b = params[-1]
        return jnp.dot(A, x) + b

    if key is None:
        key = jax.random.PRNGKey(1)
    p_init = jax.random.normal(key, shape=p_flat.shape, dtype=p_flat.dtype)
    return unravel(p_init), fwd
In [7]:
Copied!
def loss_log_marginal_likelihood(vf, *, t0):
    """Create a jitted negative-log-marginal-likelihood loss for ``vf``.

    Args:
        vf: Vector field with signature ``vf(y, /, *, t, p)``.
        t0: Initial time at which the vector field is first evaluated.

    Returns:
        A function ``loss(p, *, u0, grid, data, std, output_scale)`` that
        solves the ODE with a probabilistic solver and returns
        ``(-lml, {"sol": solution})``.
    """
    ssm = probdiffeq.state_space_model_dense()
    strategy = probdiffeq.strategy_smoother_fixedpoint()

    def bounded_while_loop(cond, body, init):
        # Bounded loop capped at 8 steps.
        # NOTE(review): presumably chosen so that the adaptive solve can be
        # differentiated in reverse mode -- confirm.
        return eqx.internal.while_loop(cond, body, init, kind="bounded", max_steps=8)

    @jax.jit
    def loss(
        p: jax.Array,
        *,
        u0: tuple,
        grid: jax.Array,
        data: jax.Array,
        std: jax.Array,
        output_scale: jax.Array,
    ):
        """Return the negative log-marginal likelihood of ``data`` and the solution."""
        # Prior: initial value(s) plus the vector field evaluated there.
        initial_derivative = vf(*u0, t=t0, p=p)
        taylor_coefficients = (*u0, initial_derivative)
        prior = ssm.prior_wiener_integrated(
            taylor_coefficients, output_scale=output_scale
        )

        # Close over the current parameters so the solver sees vf(y, t) only.
        @probdiffeq.ode
        def vf_p(y, /, *, t):
            return vf(y, t=t, p=p)

        constraint = ssm.constraint_ode_ts0(vf_p)
        solver = probdiffeq.solver(strategy=strategy, constraint=constraint)
        error_estimate = probdiffeq.error_state_std(constraint=constraint)
        solve = ivpsolve.solve_adaptive_save_at(
            solver=solver, error=error_estimate, while_loop=bounded_while_loop
        )
        solution = solve(prior, save_at=grid, atol=1e-4, rtol=1e-2)

        # Score the data against the posterior, one noise level per grid point.
        lml_fun = probdiffeq.loss_lml_timeseries()
        noise = jnp.ones_like(grid) * std[None]
        lml = lml_fun(data, std=noise, posterior=solution.solution_full)

        aux = {"sol": solution}
        return -lml, aux

    return loss
In [8]:
Copied!
def train_step_optax(optimizer, loss):
    """Implement a training step using Optax.

    Args:
        optimizer: An Optax gradient transformation (e.g. ``optax.adam(lr)``).
        loss: Function ``loss(params, **kwargs) -> (value, info)`` where
            ``info`` is a dict containing at least the key ``"sol"``.

    Returns:
        A jitted function ``update(params, opt_state, **loss_kwargs)`` that
        returns ``((new_params, new_opt_state), {"sol": ..., "loss": ...})``.
        The reported loss and solution belong to the parameters *before*
        the update.
    """
    # Build the differentiated loss once; it does not change between steps.
    value_and_grad = jax.value_and_grad(loss, argnums=0, has_aux=True)

    @jax.jit
    def update(params, opt_state, **loss_kwargs):
        """Take one optimiser step and report the loss and solution."""
        (value, info), grads = value_and_grad(params, **loss_kwargs)
        # Pass the current parameters too: optimisers that depend on them
        # (for instance those with weight decay) need this argument, and
        # the others ignore it.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), {"sol": info["sol"], "loss": value}

    return update
In [9]:
Copied!
# Run the training demo with its default settings when executed as a script.
if __name__ == "__main__":
    main()
# NOTE(review): this guard appears twice in the rendered notebook text; the
# printed output below shows a single training run, so the second copy looks
# like a rendering duplicate -- keep only one when turning this into a script.
if __name__ == "__main__":
    main()
Loss after...
...100 epochs: loss=1.3396350e+01
...200 epochs: loss=4.7973553e+01
...300 epochs: loss=-4.2167085e-01
...400 epochs: loss=-5.0469583e-01
...500 epochs: loss=-2.7027121e-01
...600 epochs: loss=-1.1443175e+00
...700 epochs: loss=-1.2489402e+00
...800 epochs: loss=-8.4951007e-01
...900 epochs: loss=-1.3469598e+00
...1000 epochs: loss=-1.0835025e+00