E3. Fit the Robertson DAE to data
The Robertson problem is a stiff differential-algebraic equation (DAE) whose solution components span many orders of magnitude. This example estimates the unknown initial conditions from synthetic observations by minimising the negative log-marginal-likelihood via gradient descent.
In [1]:
Copied!
import functools
import functools
In [2]:
Copied!
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
In [3]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq import ivpsolve, probdiffeq
In [4]:
Copied!
# Abort as soon as a NaN is detected, so that a broken notebook fails in the CI.
jax.config.update("jax_debug_nans", True)
In [5]:
Copied!
# Adaptive step-size selection on a stiff DAE is run in double precision.
jax.config.update("jax_enable_x64", True)
In [6]:
Copied!
# Print arrays with three digits of precision to keep the logs readable.
jnp.set_printoptions(precision=3)
In [7]:
Copied!
class SimplexTransform:
    """Change of coordinates used to make the optimisation problem well-posed.

    Unconstrained latent vectors are mapped to non-negative vectors that
    sum to one, with every coordinate weighted by a fixed ``scale``.
    """

    def __init__(self, scale):
        # One weight per output coordinate, e.g. jnp.array([1., 1e-5, 1e-3]).
        self.scale = jnp.asarray(scale)

    def latent_to_observed(self, u):
        """Map an unconstrained vector (R^2) to a point on the simplex (R^3)."""
        # Pin one extra logit to zero; the output has one more entry than `u`.
        logits = jnp.append(u, 0.0)

        # Softmax, with the maximum subtracted before exponentiating.
        shifted = logits - jnp.max(logits)
        weights = jnp.exp(shifted)
        probabilities = weights / weights.sum()

        # Apply the per-coordinate scale, then renormalise so the sum is one.
        rescaled = probabilities * self.scale
        return rescaled / rescaled.sum()
In [8]:
Copied!
def main(
    t0=1e-6, t1=1e5, num_data=20, tol=1e-5, std_log=-1.0, seed=1, epochs=100
) -> None:
    """Run the script.

    Generates a synthetic trajectory of the Robertson DAE from a randomly
    drawn latent initial condition, then recovers that initial condition
    from a second random guess by running ``epochs`` steps of SGD on the
    loss built by ``loss_data_fit``. Finally, the estimate is checked
    against the truth and both trajectories are plotted.

    Args:
        t0: First save point of the time grid.
        t1: Last save point of the time grid.
        num_data: Number of save points (spaced linearly on a log scale).
        tol: Tolerance handed to the solver factory.
        std_log: Log of the observation-noise factor passed to the loss.
        seed: Seed for the true latent parameters; ``seed + 1`` seeds the guess.
        epochs: Number of gradient-descent steps.

    Raises:
        AssertionError: If the final estimate is not close to the truth
            (``atol=1e-4``, ``rtol=1e-4``). This is what fails the CI run.
    """

    @functools.partial(probdiffeq.residual_jet_lift, lift_by=2)
    @probdiffeq.residual_velocity
    def differential(u, du, /, *, t):
        # Differential part: the first two derivative components
        # must match the vector field. Time does not enter.
        del t
        return du[:2] - dynamics(u)

    def dynamics(y):
        # Reaction rates of the Robertson problem.
        k1, k2, k3 = 0.04, 3e7, 1e4
        f0 = -k1 * y[0] + k3 * y[1] * y[2]
        f1 = k1 * y[0] - k2 * y[1] ** 2 - k3 * y[1] * y[2]
        return jnp.stack([f0, f1])

    @functools.partial(probdiffeq.residual_jet_lift, lift_by=3)
    @probdiffeq.residual_position
    def algebraic(u, *, t):
        # Algebraic part: the three components sum to one.
        del t
        return u[0] + u[1] + u[2] - 1

    residual = probdiffeq.residual_from_stack(differential, algebraic)

    def while_loop(cond, body, init):
        """Evaluate a bounded while loop."""
        return eqx.internal.while_loop(cond, body, init, kind="bounded", max_steps=256)

    # This base scale is critical to Robertson, because
    # the solutions live on vastly different scales
    # (but don't vary much within these scales).
    output_scale = jnp.asarray([0.8, 2e-05, 0.2])
    trafo = SimplexTransform(output_scale)

    # Linear spacing on a log-scale
    save_at = 2.0 ** jnp.linspace(jnp.log2(t0), jnp.log2(t1), num=num_data)
    solve = solver(residual, tol=tol, while_loop=while_loop, trafo=trafo)

    # True condition: latent parameters drawn from U(0, 1)
    key = jax.random.PRNGKey(seed)
    p_true = jax.random.uniform(key, shape=(2,))

    # Initial guess: also drawn from U(0, 1), but with a different key
    key = jax.random.PRNGKey(seed + 1)
    p_guess = jax.random.uniform(key, shape=(2,))

    # The true initial condition never changes; compute it once for
    # the progress log and for the final check.
    y_true = trafo.latent_to_observed(p_true)

    # Create data
    solution_true = solve(p_true, save_at=save_at, output_scale=output_scale)
    inputs = solution_true.t
    labels = solution_true.u.mean[0]

    # Build a loss
    loss = loss_data_fit(solve, inputs=inputs, labels=labels)
    value_and_grad = jax.jit(jax.value_and_grad(loss, has_aux=True))

    # Initialise the optimiser
    optim = optax.sgd(0.05)
    opt_state = optim.init(p_guess)

    # Report loss and gradient at the initial guess
    (value, _), grad = value_and_grad(
        p_guess, std_log=std_log, output_scale=output_scale
    )
    print("Value:", value)
    print("Gradient:", grad)

    for epoch in range(epochs):
        # Compute the gradient
        (value, _), grad = value_and_grad(
            p_guess, std_log=std_log, output_scale=output_scale
        )

        # Optimiser step
        updates, opt_state = optim.update(grad, opt_state)
        p_guess = optax.apply_updates(p_guess, updates)

        # Display the progress
        if epoch % 10 == 0:
            y_guess = trafo.latent_to_observed(p_guess)
            print(
                f"Epoch={epoch:4d} /{epochs:4d}, value={value:3.3e}, estim={y_guess}, true={y_true}"
            )

    # For the CI: fail the notebook if the estimates are off
    y_guess = trafo.latent_to_observed(p_guess)
    assert jnp.allclose(y_guess, y_true, atol=1e-4, rtol=1e-4)

    # Compare estimated and true trajectories
    solution_guess = solve(p_guess, save_at=save_at, output_scale=output_scale)
    mean_true = solution_true.u.mean[0]
    mean_guess = solution_guess.u.mean[0]

    fig, ax = plt.subplots(ncols=2, figsize=(8, 3), constrained_layout=True)

    ax[0].set_title("Robertson trajectory", fontsize="medium")
    ax[0].set_xlabel("Time $t$", fontsize="medium")
    ax[0].set_ylabel("State", fontsize="medium")
    for k in range(3):
        ax[0].semilogx(
            save_at,
            mean_true[:, k],
            color=f"C{k}",
            label=f"True $y_{k + 1}$",
        )
        ax[0].semilogx(save_at, mean_guess[:, k], color=f"C{k}", linestyle="dashed")
    ax[0].legend(fontsize="x-small")

    ax[1].set_title("Absolute error", fontsize="medium")
    ax[1].set_xlabel("Time $t$", fontsize="medium")
    ax[1].set_ylabel("Error", fontsize="medium")
    for k in range(3):
        err = jnp.abs(mean_true[:, k] - mean_guess[:, k])
        # Small offset keeps exact zeros visible on the log axis.
        ax[1].loglog(save_at, err + 1e-16, color=f"C{k}")

    fig.align_ylabels()
    plt.show()
In [9]:
Copied!
def loss_data_fit(solve, *, inputs, labels):
    """Build a loss function that scores how well a solve explains the data.

    The returned function takes the latent initial condition ``y0``,
    ``std_log`` and ``output_scale``. It returns the negative
    log-marginal-likelihood of ``labels`` together with the solution
    object as auxiliary output.
    """

    def loss(y0, std_log, output_scale):
        # Solve on the time grid of the data.
        solution = solve(y0, save_at=inputs, output_scale=output_scale)

        # Observation noise: exp(std_log) times the output scale,
        # repeated along the leading (time) axis of `inputs`.
        noise_std = jnp.exp(std_log) * output_scale
        noise_std_per_time = jnp.ones_like(inputs)[:, None] * noise_std[None, ...]

        log_marginal_likelihood = probdiffeq.loss_lml_timeseries()
        lml = log_marginal_likelihood(
            labels, std=noise_std_per_time, posterior=solution.solution_full
        )
        # Negate: the optimiser minimises.
        return -lml, solution

    return loss
In [10]:
Copied!
def solver(residual, tol, while_loop, trafo):
    """Create a reverse-mode differentiable probabilistic solver.

    Returns a jitted ``solve(p_latent, save_at, output_scale)`` that maps
    latent parameters to an initial condition via ``trafo`` and then
    solves the problem described by ``residual`` adaptively, saving at
    the points in ``save_at``. ``tol`` is used both for the nonlinear
    least-squares routine and as ``atol``/``rtol`` of the adaptive solve.
    """

    @jax.jit
    def solve(p_latent, save_at, output_scale):
        initial_condition = trafo.latent_to_observed(p_latent)
        t0 = save_at[0]

        # The same Gauss-Newton routine initialises the Taylor
        # coefficients and drives the linearisation point below.
        gauss_newton = probdiffeq.lstsq_constrained_gauss_newton(
            maxiter=10, tol=tol, while_loop=while_loop
        )
        expand = probdiffeq.jetexpand_residual(num=3, nlstsq=gauss_newton)
        tcoeffs, _ = expand(residual, [initial_condition], t=t0)

        ssm = probdiffeq.state_space_model_dense()
        prior = ssm.prior_wiener_integrated(tcoeffs, output_scale=output_scale)

        # We build a Jet constraint. Iteration is key, because DAEs are very stiff.
        linearise_at = probdiffeq.taylor_point_maximum_a_posteriori(gauss_newton)
        constraint = ssm.constraint_residual(residual, taylor_point=linearise_at)

        smoother = probdiffeq.strategy_smoother_fixedpoint()
        dynamic_solver = probdiffeq.solver_dynamic(
            strategy=smoother, constraint=constraint
        )

        # The state-error-estimate doesn't care about the dimension
        # of the DAE, which is exactly what we need here
        error_estimate = probdiffeq.error_state_std(constraint=constraint)

        adaptive_solve = ivpsolve.solve_adaptive_save_at(
            solver=dynamic_solver, error=error_estimate, while_loop=while_loop
        )
        return adaptive_solve(prior, save_at=save_at, atol=tol, rtol=tol)

    return solve
In [11]:
Copied!
# Script entry point: run the example with its default settings.
if __name__ == "__main__":
    main()
Value: -12.507131566846144 Gradient: [ 0.84 -0.715]
Epoch= 0 / 100, value=-1.251e+01, estim=[8.542e-01 1.752e-05 1.458e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 10 / 100, value=-1.289e+01, estim=[8.159e-01 2.722e-05 1.841e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 20 / 100, value=-1.289e+01, estim=[8.161e-01 2.826e-05 1.839e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 30 / 100, value=-1.290e+01, estim=[8.172e-01 2.846e-05 1.828e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 40 / 100, value=-1.290e+01, estim=[8.178e-01 2.854e-05 1.822e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 50 / 100, value=-1.290e+01, estim=[8.180e-01 2.858e-05 1.820e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 60 / 100, value=-1.290e+01, estim=[8.181e-01 2.860e-05 1.819e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 70 / 100, value=-1.290e+01, estim=[8.182e-01 2.861e-05 1.818e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 80 / 100, value=-1.290e+01, estim=[8.182e-01 2.861e-05 1.818e-01], true=[8.182e-01 2.861e-05 1.818e-01]
Epoch= 90 / 100, value=-1.290e+01, estim=[8.182e-01 2.861e-05 1.818e-01], true=[8.182e-01 2.861e-05 1.818e-01]