E2. Estimate parameters (via Optax)
We generate synthetic data,
compute the marginal likelihood of these data under the probabilistic ODE posterior
(something deterministic solvers cannot provide),
and optimise the parameters with Optax.
In [1]:
Copied!
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
In [2]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq import ivpsolve, probdiffeq
In [3]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# A single update suffices; repeating it has no additional effect.
jax.config.update("jax_debug_nans", True)
In [4]:
Copied!
# Lotka-Volterra predator-prey model.
# Initial state: component 0 is the prey, component 1 the predators
# (see the vector field `vf` in `main`).
u0 = jnp.asarray([20.0, 20.0])
# Time interval of the simulation.
t0, t1 = 0.0, 20.0
# Rate constants (a, b, c, d) of the vector field in `main`.
rate_constants = jnp.asarray([0.5, 0.05, 0.5, 0.05])
In [5]:
Copied!
def main():
    """Learn the Lotka-Volterra rate constants with Optax.

    Generates synthetic data from perturbed rate constants. Then fits the
    parameters by minimising the loss built by `loss_marginal_likelihood`
    (the negative log-marginal-likelihood of the data under the ODE
    posterior). Finally plots the data, the initial guess and the optimised
    trajectory in the phase plane.
    """

    # Define the problem
    def vf(y, t, *, p):  # noqa: ARG001
        """Evaluate the Lotka-Volterra vector field.

        y[0] grows at rate a and is consumed at rate b * y[1] (prey);
        y[1] decays at rate c and grows at rate d * y[0] (predators).
        """
        a, b, c, d = p[0], p[1], p[2], p[3]
        return jnp.asarray([a * y[0] - b * y[0] * y[1], -c * y[1] + d * y[0] * y[1]])

    grid = jnp.linspace(t0, t1, endpoint=True, num=50)
    solve = solver(vf, u0, grid=grid)

    # Create a dataset from slightly perturbed "true" parameters.
    parameter_true = rate_constants + 0.05
    parameter_guess = rate_constants
    solution_true = solve(parameter_true)
    # Assumed: `u.mean[0]` holds the state at each grid point, i.e. shape
    # (len(grid), 2) -- it is indexed that way in the plot below.
    data = solution_true.u.mean[0]

    # We make an initial guess, but it does not lead to a good data fit:
    initial = solve(parameter_guess)

    # Use probdiffeq to form the loss function:
    loss = loss_marginal_likelihood(solve=solve, data=data)
    value_and_grad = jax.jit(jax.value_and_grad(loss))

    # The loss is differentiable; here we use reverse mode (jax.value_and_grad).
    print("Value and gradient:")
    print(value_and_grad(parameter_guess))

    # Enter Optax:
    print()
    print("Training:")
    num_rounds, steps_per_round = 20, 20
    optim = optax.adam(learning_rate=1e-2)
    update = build_update(optimizer=optim, value_and_grad=value_and_grad)
    p = parameter_guess
    state = optim.init(p)
    for i in range(num_rounds):
        for _ in range(steps_per_round):
            p, state = update(p, state)
        print(f"After {(i + 1) * steps_per_round} iterations:", p)

    # The solution looks much better:
    final = solve(p)

    fig, ax = plt.subplots(figsize=(5, 3), dpi=100, constrained_layout=True)
    ax.set_title("Learning a Predator-Prey model", fontsize="medium")
    # Component 0 (x-axis) is the prey and component 1 (y-axis) the
    # predators, as defined by the vector field above.
    ax.set_xlabel("Prey", fontsize="medium")
    ax.set_ylabel("Predators", fontsize="medium")
    ax.plot(
        data[:, 0], data[:, 1], "X", markersize=8, label="Data", color="k", alpha=0.2
    )
    ax.plot(
        initial.u.mean[0][:, 0],
        initial.u.mean[0][:, 1],
        color="C0",
        label="Initial guess",
        linestyle="dashed",
    )
    ax.plot(final.u.mean[0][:, 0], final.u.mean[0][:, 1], color="C1", label="Optimised")
    ax.legend(fontsize="small")
    fig.align_ylabels()
    plt.show()
In [6]:
Copied!
def solver(vf, u0, *, grid, atol=1e-4, rtol=1e-2, output_scale=10.0, max_steps=8):
    """Construct the parameter-to-solution map of an ODE.

    Args:
        vf: Vector field with signature ``vf(y, t, *, p)``.
        u0: Initial value.
        grid: Time points at which the solution is saved; ``grid[0]`` is
            used as the initial time.
        atol: Absolute tolerance passed to the adaptive solver.
        rtol: Relative tolerance passed to the adaptive solver.
        output_scale: Output scale of the integrated-Wiener-process prior.
        max_steps: Passed as ``max_steps`` to the bounded
            ``eqx.internal.while_loop`` used by the solver.

    Returns:
        A function ``solve(p)`` returning the probdiffeq solution for the
        parameters ``p``.
    """
    ssm = probdiffeq.state_space_model_isotropic()
    strategy = probdiffeq.strategy_smoother_fixedpoint()

    def while_loop(cond, body, init):
        # NOTE(review): a "bounded" loop is presumably chosen so that
        # reverse-mode autodiff can pass through the solver loop; confirm.
        return eqx.internal.while_loop(
            cond, body, init, kind="bounded", max_steps=max_steps
        )

    def solve(p):
        """Evaluate the parameter-to-solution map."""
        # Initial Taylor coefficients: the state and its time derivative at t0.
        tcoeffs = (u0, vf(u0, grid[0], p=p))
        iwp = ssm.prior_wiener_integrated(tcoeffs, output_scale=output_scale)

        # The vector field must close over `p`, so it (and everything built
        # from it) is constructed anew for every parameter value.
        @probdiffeq.ode
        def vf_p(y, /, *, t):
            return vf(y, t=t, p=p)

        ts0 = ssm.constraint_ode_ts0(vf_p)
        solver_obj = probdiffeq.solver(strategy=strategy, constraint=ts0)
        error = probdiffeq.error_state_std(constraint=ts0)
        solve_fn = ivpsolve.solve_adaptive_save_at(
            solver=solver_obj, error=error, while_loop=while_loop
        )
        return solve_fn(iwp, save_at=grid, atol=atol, rtol=rtol)

    return solve
In [7]:
Copied!
def loss_marginal_likelihood(*, data, solve, std=1e-1):
    """Create a loss function: the negative log-marginal-likelihood of ``data``.

    Args:
        data: Observations (in ``main`` these are ``solve(p).u.mean[0]``).
        solve: Parameter-to-solution map, e.g. as returned by ``solver``.
        std: Observation-noise standard deviation, used for every entry of
            the solution's time array.

    Returns:
        A jitted function ``loss(params)`` returning ``-lml``.
    """
    loss_lml = probdiffeq.loss_lml_timeseries()

    @jax.jit
    def loss(params, /):
        """Evaluate the data fit as a function of the parameters."""
        sol = solve(params)
        # The same noise standard deviation for every time point.
        std_array = jnp.ones_like(sol.t) * std
        lml = loss_lml(data, std=std_array, posterior=sol.solution_full)
        # Negate so that minimising the loss maximises the marginal likelihood.
        return -lml

    return loss
In [8]:
Copied!
def build_update(*, optimizer, value_and_grad):
    """Build a jitted function that executes a single optimisation step.

    Args:
        optimizer: An Optax gradient transformation (e.g. ``optax.adam(...)``).
        value_and_grad: Function mapping ``params`` to ``(loss, grads)``.

    Returns:
        ``update(params, opt_state) -> (params, opt_state)``.
    """

    @jax.jit
    def update(params, opt_state):
        """Apply one optimiser step; return the new parameters and state."""
        _loss, grads = value_and_grad(params)
        # Pass the current parameters too: some transformations (e.g. ones
        # with weight decay) require them, while optax.adam ignores them.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    return update
In [9]:
Copied!
# Run the experiment exactly once when executed as a script.
if __name__ == "__main__":
    main()
Value and gradient:
(Array(35.45809, dtype=float32), Array([ 5.19902 , 83.05482 , 4.140925, 124.85628 ], dtype=float32))

Training:
After 20 iterations: [0.4778001 0.03612584 0.4701752 0.03569311]
After 40 iterations: [0.44278133 0.04669723 0.38757154 0.04004207]
After 60 iterations: [0.49936032 0.07502096 0.4525252 0.05878462]
After 80 iterations: [0.5196196 0.07860456 0.4732414 0.07294172]
After 100 iterations: [0.5392868 0.08796278 0.49480444 0.08228298]
After 120 iterations: [0.55533725 0.09451369 0.5122611 0.09091674]
After 140 iterations: [0.564423 0.09709816 0.5224094 0.09480488]
After 160 iterations: [0.5672446 0.09945247 0.5261267 0.09685365]
After 180 iterations: [0.5690989 0.10028666 0.5287324 0.09774701]
After 200 iterations: [0.5695098 0.10110597 0.52985746 0.09870844]
After 220 iterations: [0.56955487 0.10108737 0.5305448 0.09905498]
After 240 iterations: [0.5696292 0.10116171 0.53129464 0.09884549]
After 260 iterations: [0.5697413 0.10097475 0.5320343 0.09871338]
After 280 iterations: [0.56948555 0.10091075 0.53239506 0.09879951]
After 300 iterations: [0.56916684 0.10088488 0.53270334 0.09881116]
After 320 iterations: [0.5688465 0.1008419 0.53300077 0.09881564]
After 340 iterations: [0.568517 0.10080431 0.5332811 0.09883735]
After 360 iterations: [0.56819624 0.10078452 0.5335638 0.09885559]
After 380 iterations: [0.5678866 0.10076116 0.5338469 0.09886346]
After 400 iterations: [0.5675717 0.10073939 0.5341121 0.09886739]