C0. Choose between prior distributions
Every probabilistic ODE solver is built on a Gauss–Markov prior. Four prior types are compared here side by side: the integrated Wiener process (IWP), the integrated Ornstein–Uhlenbeck process (IOUP), a Matérn 5/2 prior, and an oscillating prior. Each encodes a different assumption about the ODE solution, for example about its smoothness, its tendency to revert towards zero, or its tendency to oscillate.
In [1]:
Copied!
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
In [2]:
Copied!
from probdiffeq import probdiffeq
from probdiffeq.backend import func
from probdiffeq import probdiffeq
from probdiffeq.backend import func
In [3]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# The flag only needs to be set once; repeating the identical update is redundant.
jax.config.update("jax_debug_nans", True)
In [4]:
Copied!
def main():
    """Sample from various prior distributions and plot them side by side.

    Four priors are compared, one per column: an oscillating prior, a
    Matern 5/2 prior, an integrated Ornstein-Uhlenbeck process (IOUP), and an
    integrated Wiener process (IWP). For each prior, three samples are drawn on
    a grid of 100 points over [0, 5]. They are plotted on top of the marginal
    mean +/- 2 standard deviations. The three rows show the state, velocity,
    and acceleration.
    """
    ts = jnp.linspace(0.0, 5.0, num=100, endpoint=True)

    # Each prior is a linear vector field that receives three Taylor
    # coefficients (u, du, ddu) as positional arguments.
    # NOTE(review): presumably this encodes the third-order ODE
    # u''' = f(u, u', u''); confirm against
    # `probdiffeq.ode_autonomous_order_arbitrary`.

    @func.partial(probdiffeq.ode_autonomous_order_arbitrary, num_tcoeffs_in_args=3)
    def vf_matern(u, du, ddu, /):
        """Matern 5/2 prior."""
        ell = 0.5
        # Read as u''' = f(u, u', u''), this is (d/dt + ell)^3 u = 0, i.e. the
        # characteristic polynomial s^3 + 3 ell s^2 + 3 ell^2 s + ell^3.
        return -(ell**3) * u - 3 * ell**2 * du - 3 * ell * ddu

    @func.partial(probdiffeq.ode_autonomous_order_arbitrary, num_tcoeffs_in_args=3)
    def vf_oscillator(_u, du, _ddu, /):
        """Oscillating prior."""
        return -5 * du  # always the second highest coefficient

    @func.partial(probdiffeq.ode_autonomous_order_arbitrary, num_tcoeffs_in_args=3)
    def vf_ioup(_u, _du, ddu, /):
        """IOUP prior."""
        return -5 * ddu  # always the highest coefficient

    @func.partial(probdiffeq.ode_autonomous_order_arbitrary, num_tcoeffs_in_args=3)
    def vf_iwp(u, _du, _ddu, /):
        """IWP prior."""
        return 0.0 * u  # always zeros (with the same shape as u)

    vf_titles = ["Oscillating", "Matern 5/2", "IOUP", "IWP"]
    vf_functions = [vf_oscillator, vf_matern, vf_ioup, vf_iwp]

    # Loop-invariant setup: every prior starts from the same initial
    # distribution, so the columns are directly comparable.
    # Match initial distribution to stationary distribution of Matern.
    mean, loc = [0.0, 0.0, 0.0], [2.5, 0.7, 0.6]
    ssm = probdiffeq.state_space_model_dense()

    _fig, axes = plt.subplots(
        nrows=3, ncols=4, sharex=True, figsize=(8, 5), constrained_layout=True
    )
    for i, (vf_prior, title, ax_col) in enumerate(zip(vf_functions, vf_titles, axes.T)):
        prior = ssm.prior_exponential_diffuse(vf_prior, mean, loc)
        mseq = probdiffeq.MarkovSequence.from_grid(prior, grid=ts, reverse=False)

        # Evaluate samples (a distinct but deterministic key per column)
        key = jax.random.PRNGKey(i)
        samples_prior = mseq.sample(key, shape=(3,))

        # Evaluate marginals
        margs = mseq.evaluate_marginals()
        means = margs.mean
        stds = margs.std

        # Plot samples and marginals in this prior's colour.
        # NOTE(review): presumably the leading axis of samples/means/stds
        # indexes the Taylor coefficients, matching the three rows of
        # `ax_col`; confirm.
        color = f"C{i}"
        ax_col[0].set_title(title, fontsize="medium")
        for smp, m, std, ax in zip(samples_prior, means, stds, ax_col):
            ax.plot(ts, smp.T, color=color, linewidth=1.0)
            ax.fill_between(ts, m - 2 * std, m + 2 * std, color=color, alpha=0.25)

    # Label the rows once, on the left-most column.
    for ax, label in zip(axes[:, 0], ["State", "Velocity", "Acceleration"]):
        ax.set_ylabel(label, fontsize="medium")
    plt.show()
In [5]:
Copied!
if __name__ == "__main__":
    # Run the demo exactly once when executed as a script.
    main()