C4. Some solvers are more stable than others
IWP and IOUP priors combined with TS0 or TS1 constraints differ widely in stability. The IOUP prior and the TS1 constraint each improve stability, and their combination requires more than an order of magnitude fewer steps than IWP with TS0 (6 versus 300 grid points in the example below).
Source: Bosch, Hennig, Tronarp (2023), "Probabilistic exponential integrators", NeurIPS 36, 40450-40467.
In [1]:
Copied!
import functools
import functools
In [2]:
Copied!
import jax
import jax.experimental.ode
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
import jax.experimental.ode
import jax.numpy as jnp
import matplotlib.pyplot as plt
In [3]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq import ivpsolve, probdiffeq
In [4]:
Copied!
# Turn on JAX's NaN checking so that a NaN anywhere in the computation makes
# this notebook fail (this is how NaNs get caught in the CI).
jax.config.update("jax_debug_nans", True)
# Turn on JAX's NaN checking so that a NaN anywhere in the computation makes
# this notebook fail (this is how NaNs get caught in the CI).
jax.config.update("jax_debug_nans", True)
In [5]:
Copied!
def main():
    """Solve one linear ODE with four prior/constraint pairings and plot them.

    Every combination of an IWP or IOUP prior with a TS0 or TS1 constraint
    is run on its own fixed grid over the same time interval.  Each pairing
    gets one panel showing the grid-point means as dots, the off-grid mean
    as a line, and a band of one standard deviation around that line.
    """

    @probdiffeq.ode
    def vector_field(u, *, t):
        """Evaluate the linear, time-independent vector field."""
        del t
        return jnp.asarray([-0.5 * u[0] + 20 * u[1], -20 * u[1]])

    initial_state = jnp.asarray([0.0, 1.0])
    t_start, t_stop = 0.0, 3.0
    drift_matrix = jnp.asarray([[-0.5, 20], [0, -20]])

    def apply_drift(s):
        """Multiply the drift matrix onto ``s`` (used by the IOUP prior)."""
        return drift_matrix @ s

    # State-space model over Taylor coefficients
    ssm = probdiffeq.state_space_model_dense()

    # Ingredients shared by all solvers: Taylor coefficients of the initial
    # condition, both priors, the smoother, and both constraints
    jetexpand = probdiffeq.jetexpand_ode_padded_scan(num=3)
    tcoeffs, _ = jetexpand(vector_field, (initial_state,), t=t_start)
    iwp = ssm.prior_wiener_integrated(tcoeffs)
    strategy = probdiffeq.strategy_smoother_fixedinterval()
    ioup = ssm.prior_ornstein_uhlenbeck_integrated(apply_drift, tcoeffs)
    ts0 = ssm.constraint_ode_ts0(vector_field)
    ts1 = ssm.constraint_ode_ts1(vector_field)

    # One panel per (title, prior, constraint, grid size) entry.  The last
    # entry is handed to ``linspace`` as the number of grid points.
    fig, axes = plt.subplots(ncols=4, figsize=(13, 3), constrained_layout=True)
    configs = [
        ("IWP + TS0 (300 steps)", iwp, ts0, 300),
        ("IOUP + TS0 (275 steps)", ioup, ts0, 275),
        ("IWP + TS1 (15 steps)", iwp, ts1, 15),
        ("IOUP + TS1 (6 steps)", ioup, ts1, 6),
    ]

    # The fine plotting grid does not depend on the solver, so build it once.
    # It is offset by 1e-4 from both endpoints of the interval.
    plot_ts = jnp.linspace(t_start + 1e-4, t_stop - 1e-4, num=200, endpoint=True)

    for idx, (config, ax) in enumerate(zip(configs, axes.flatten())):
        label, prior, constraint, num = config
        color = f"C{idx}"

        # Assemble the solver and solve on the coarse fixed grid
        solver = probdiffeq.solver_mle(strategy=strategy, constraint=constraint)
        solve = ivpsolve.solve_fixed_grid(solver=solver)
        grid = jnp.linspace(t_start, t_stop, num=num, endpoint=True)
        solution = jax.jit(solve)(prior, grid=grid)

        # Evaluate the off-grid marginals on the fine plotting grid
        dense = functools.partial(solver.offgrid_marginals, solution=solution)
        marginals = jax.jit(jax.vmap(dense))(plot_ts)

        ax.set_title(label, fontsize="medium")
        for dim in (0, 1):
            # Dots: solver output at the grid points
            ax.plot(
                solution.t,
                solution.u.mean[0][:, dim],
                ".",
                alpha=0.75,
                markerfacecolor=color,
                markeredgecolor="black",
            )

            # Line and band: off-grid mean plus/minus one standard deviation
            mean = marginals.mean[0][:, dim]
            std = marginals.std[0][:, dim]
            ax.plot(plot_ts, mean, alpha=0.5, color=color)
            ax.fill_between(
                plot_ts, mean - std, mean + std, alpha=0.25, color=color
            )

        # Identical axis limits on every panel keep them comparable
        ax.set_xlim((t_start - 0.05, t_stop + 0.05))
        ax.set_ylim((-0.125, 1.125))
        ax.set_xlabel("Time", fontsize="medium")

    axes[0].set_ylabel("State", fontsize="medium")
    fig.align_ylabels()
    plt.show()
def main():
    """Solve one linear ODE with four prior/constraint pairings and plot them.

    Every combination of an IWP or IOUP prior with a TS0 or TS1 constraint
    is run on its own fixed grid over the same time interval.  Each pairing
    gets one panel showing the grid-point means as dots, the off-grid mean
    as a line, and a band of one standard deviation around that line.
    """

    @probdiffeq.ode
    def vector_field(u, *, t):
        """Evaluate the linear, time-independent vector field."""
        del t
        return jnp.asarray([-0.5 * u[0] + 20 * u[1], -20 * u[1]])

    initial_state = jnp.asarray([0.0, 1.0])
    t_start, t_stop = 0.0, 3.0
    drift_matrix = jnp.asarray([[-0.5, 20], [0, -20]])

    def apply_drift(s):
        """Multiply the drift matrix onto ``s`` (used by the IOUP prior)."""
        return drift_matrix @ s

    # State-space model over Taylor coefficients
    ssm = probdiffeq.state_space_model_dense()

    # Ingredients shared by all solvers: Taylor coefficients of the initial
    # condition, both priors, the smoother, and both constraints
    jetexpand = probdiffeq.jetexpand_ode_padded_scan(num=3)
    tcoeffs, _ = jetexpand(vector_field, (initial_state,), t=t_start)
    iwp = ssm.prior_wiener_integrated(tcoeffs)
    strategy = probdiffeq.strategy_smoother_fixedinterval()
    ioup = ssm.prior_ornstein_uhlenbeck_integrated(apply_drift, tcoeffs)
    ts0 = ssm.constraint_ode_ts0(vector_field)
    ts1 = ssm.constraint_ode_ts1(vector_field)

    # One panel per (title, prior, constraint, grid size) entry.  The last
    # entry is handed to ``linspace`` as the number of grid points.
    fig, axes = plt.subplots(ncols=4, figsize=(13, 3), constrained_layout=True)
    configs = [
        ("IWP + TS0 (300 steps)", iwp, ts0, 300),
        ("IOUP + TS0 (275 steps)", ioup, ts0, 275),
        ("IWP + TS1 (15 steps)", iwp, ts1, 15),
        ("IOUP + TS1 (6 steps)", ioup, ts1, 6),
    ]

    # The fine plotting grid does not depend on the solver, so build it once.
    # It is offset by 1e-4 from both endpoints of the interval.
    plot_ts = jnp.linspace(t_start + 1e-4, t_stop - 1e-4, num=200, endpoint=True)

    for idx, (config, ax) in enumerate(zip(configs, axes.flatten())):
        label, prior, constraint, num = config
        color = f"C{idx}"

        # Assemble the solver and solve on the coarse fixed grid
        solver = probdiffeq.solver_mle(strategy=strategy, constraint=constraint)
        solve = ivpsolve.solve_fixed_grid(solver=solver)
        grid = jnp.linspace(t_start, t_stop, num=num, endpoint=True)
        solution = jax.jit(solve)(prior, grid=grid)

        # Evaluate the off-grid marginals on the fine plotting grid
        dense = functools.partial(solver.offgrid_marginals, solution=solution)
        marginals = jax.jit(jax.vmap(dense))(plot_ts)

        ax.set_title(label, fontsize="medium")
        for dim in (0, 1):
            # Dots: solver output at the grid points
            ax.plot(
                solution.t,
                solution.u.mean[0][:, dim],
                ".",
                alpha=0.75,
                markerfacecolor=color,
                markeredgecolor="black",
            )

            # Line and band: off-grid mean plus/minus one standard deviation
            mean = marginals.mean[0][:, dim]
            std = marginals.std[0][:, dim]
            ax.plot(plot_ts, mean, alpha=0.5, color=color)
            ax.fill_between(
                plot_ts, mean - std, mean + std, alpha=0.25, color=color
            )

        # Identical axis limits on every panel keep them comparable
        ax.set_xlim((t_start - 0.05, t_stop + 0.05))
        ax.set_ylim((-0.125, 1.125))
        ax.set_xlabel("Time", fontsize="medium")

    axes[0].set_ylabel("State", fontsize="medium")
    fig.align_ylabels()
    plt.show()
In [6]:
Copied!
# Produce the figure only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# Produce the figure only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()