Understanding the role of Taylor coefficients
To build a probabilistic solver, we need a specific state-space model, and to build that state-space model, we work with Taylor coefficients. Here are some examples of how Taylor coefficients play a role in Probdiffeq's solution routines.
In [1]:
Copied!
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.util.doc_util import notebook
# Apply the shared documentation plot configuration: style first, then sizes.
for _configure_plots in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(_configure_plots())

# Use JAX for the diffeqzoo example problems, unless some backend has
# already been selected earlier in this session.
if not backend.has_been_selected:
    backend.select("jax")  # the IVP examples must produce JAX arrays

# Run all JAX computations on the CPU.
jax.config.update("jax_platform_name", "cpu")
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.util.doc_util import notebook
# Apply the shared documentation plot configuration: style first, then sizes.
for _configure_plots in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(_configure_plots())

# Use JAX for the diffeqzoo example problems, unless some backend has
# already been selected earlier in this session.
if not backend.has_been_selected:
    backend.select("jax")  # the IVP examples must produce JAX arrays

# Run all JAX computations on the CPU.
jax.config.update("jax_platform_name", "cpu")
We start by defining an ODE.
In [2]:
Copied!
# Load the rigid-body problem. Below, `f` is wrapped as the vector field,
# `u0` seeds the Taylor coefficients, `(t0, t1)` bounds the save-at grid,
# and `f_args` are extra arguments forwarded to `f`.
f, u0, (t0, t1), f_args = ivps.rigid_body()


def vf(*y, t):  # noqa: ARG001
    """Evaluate the vector field.

    The state inputs `y` are followed by the problem arguments `f_args`.
    The time `t` is accepted as a keyword but does not enter the evaluation.
    """
    arguments = (*y, *f_args)
    return f(*arguments)
# Load the rigid-body problem. Below, `f` is wrapped as the vector field,
# `u0` seeds the Taylor coefficients, `(t0, t1)` bounds the save-at grid,
# and `f_args` are extra arguments forwarded to `f`.
f, u0, (t0, t1), f_args = ivps.rigid_body()


def vf(*y, t):  # noqa: ARG001
    """Evaluate the vector field.

    The state inputs `y` are followed by the problem arguments `f_args`.
    The time `t` is accepted as a keyword but does not enter the evaluation.
    """
    arguments = (*y, *f_args)
    return f(*arguments)
Here is a wrapper around Probdiffeq's solution routine.
In [3]:
Copied!
def solve(tc):
    """Solve the ODE.

    The Taylor coefficients `tc` build the `prior_ibm` prior with a dense
    state-space-model factorisation. The solver combines that prior with the
    `correction_ts0` correction, the fixed-point strategy, and `solver_mle`.
    Step sizes are adapted with atol = rtol = 1e-2, starting from dt0 = 0.1.

    Returns the result of `ivpsolve.solve_adaptive_save_at` on a grid of 10
    equispaced points from `t0` to `t1` (both included).
    """
    prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense")

    # Assemble the solver from its components.
    correction = ivpsolvers.correction_ts0(ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_mle(
        strategy, prior=prior, correction=correction, ssm=ssm
    )
    initial_condition = solver.initial_condition()

    save_at = jnp.linspace(t0, t1, num=10, endpoint=True)
    adaptive = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
    return ivpsolve.solve_adaptive_save_at(
        vf,
        initial_condition,
        save_at=save_at,
        adaptive_solver=adaptive,
        dt0=0.1,
        ssm=ssm,
    )
def solve(tc):
    """Solve the ODE.

    The Taylor coefficients `tc` build the `prior_ibm` prior with a dense
    state-space-model factorisation. The solver combines that prior with the
    `correction_ts0` correction, the fixed-point strategy, and `solver_mle`.
    Step sizes are adapted with atol = rtol = 1e-2, starting from dt0 = 0.1.

    Returns the result of `ivpsolve.solve_adaptive_save_at` on a grid of 10
    equispaced points from `t0` to `t1` (both included).
    """
    prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense")

    # Assemble the solver from its components.
    correction = ivpsolvers.correction_ts0(ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_mle(
        strategy, prior=prior, correction=correction, ssm=ssm
    )
    initial_condition = solver.initial_condition()

    save_at = jnp.linspace(t0, t1, num=10, endpoint=True)
    adaptive = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
    return ivpsolve.solve_adaptive_save_at(
        vf,
        initial_condition,
        save_at=save_at,
        adaptive_solver=adaptive,
        dt0=0.1,
        ssm=ssm,
    )
It's time to solve some ODEs:
In [4]:
Copied!
def _vf_at_t0(*y):
    """Evaluate the vector field at the initial time `t0`."""
    return vf(*y, t=t0)


# Taylor coefficients of the solution at t0, computed from `u0` with num=2.
tcoeffs = taylor.odejet_padded_scan(_vf_at_t0, [u0], num=2)
solution = solve(tcoeffs)

# Show the array shape of every leaf in the solution.
solution_shapes = jax.tree.map(jnp.shape, solution)
print(solution_shapes)
def _vf_at_t0(*y):
    """Evaluate the vector field at the initial time `t0`."""
    return vf(*y, t=t0)


# Taylor coefficients of the solution at t0, computed from `u0` with num=2.
tcoeffs = taylor.odejet_padded_scan(_vf_at_t0, [u0], num=2)
solution = solve(tcoeffs)

# Show the array shape of every leaf in the solution.
solution_shapes = jax.tree.map(jnp.shape, solution)
print(solution_shapes)
IVPSolution(t=(10,), u=[(10, 3), (10, 3), (10, 3)], u_std=[(10, 3), (10, 3), (10, 3)], output_scale=(9,), marginals=Normal(mean=(10, 9), cholesky=(10, 9, 9)), posterior=MarkovSeq(init=Normal(mean=(10, 9), cholesky=(10, 9, 9)), conditional=Conditional(matmul=(9, 9, 9), noise=Normal(mean=(9, 9), cholesky=(9, 9, 9)))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f3bc820c110>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f3bc820cc50>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f3bc82a8620>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f3bc82aa600>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f3bc82ab680>, transform=<probdiffeq.impl._conditional.DenseTransform object at 0x7f3bc82abfb0>))
The structure of `solution.u` matches that of the initial condition, that is, the list of Taylor coefficients.
In [5]:
Copied!
# Compare the leaf shapes of the Taylor coefficients with those of `solution.u`.
for _tree in (tcoeffs, solution.u):
    print(jax.tree.map(jnp.shape, _tree))
[(3,), (3,), (3,)] [(10, 3), (10, 3), (10, 3)]
Anything that behaves like a list works. For example, we can use lists or tuples, but also named tuples.
In [6]:
Copied!
# Repackage the Taylor coefficients into a named tuple and solve again.
Taylor = collections.namedtuple("Taylor", "state velocity acceleration")
tcoeffs = Taylor._make(tcoeffs)
solution = solve(tcoeffs)

# The named-tuple structure carries over to the solution.
for _tree in (tcoeffs, solution, solution.u):
    print(jax.tree.map(jnp.shape, _tree))
# Repackage the Taylor coefficients into a named tuple and solve again.
Taylor = collections.namedtuple("Taylor", "state velocity acceleration")
tcoeffs = Taylor._make(tcoeffs)
solution = solve(tcoeffs)

# The named-tuple structure carries over to the solution.
for _tree in (tcoeffs, solution, solution.u):
    print(jax.tree.map(jnp.shape, _tree))
Taylor(state=(3,), velocity=(3,), acceleration=(3,)) IVPSolution(t=(10,), u=Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)), u_std=Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)), output_scale=(9,), marginals=Normal(mean=(10, 9), cholesky=(10, 9, 9)), posterior=MarkovSeq(init=Normal(mean=(10, 9), cholesky=(10, 9, 9)), conditional=Conditional(matmul=(9, 9, 9), noise=Normal(mean=(9, 9), cholesky=(9, 9, 9)))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f3bc023c1d0>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f3bc8124cb0>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f3bc02e08f0>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f3bc02e16a0>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f3bc02283b0>, transform=<probdiffeq.impl._conditional.DenseTransform object at 0x7f3bc0398ef0>)) Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3))
The same applies to statistical quantities that we can extract from the solution, for example, the standard deviation or samples from the solution object:
In [7]:
Copied!
# Draw samples from the solution's posterior. The module-level names `key`,
# `posterior`, `samples`, and `samples_init` are kept for later cells.
key = jax.random.PRNGKey(seed=15)
posterior = stats.markov_select_terminal(solution.posterior)
samples, samples_init = stats.markov_sample(
    key, posterior, reverse=True, ssm=solution.ssm
)

# The solution values, standard deviations, and samples all share the
# container structure of the Taylor coefficients.
for _tree in (solution.u, solution.u_std, samples, samples_init):
    print(jax.tree.map(jnp.shape, _tree))
# Draw samples from the solution's posterior. The module-level names `key`,
# `posterior`, `samples`, and `samples_init` are kept for later cells.
key = jax.random.PRNGKey(seed=15)
posterior = stats.markov_select_terminal(solution.posterior)
samples, samples_init = stats.markov_sample(
    key, posterior, reverse=True, ssm=solution.ssm
)

# The solution values, standard deviations, and samples all share the
# container structure of the Taylor coefficients.
for _tree in (solution.u, solution.u_std, samples, samples_init):
    print(jax.tree.map(jnp.shape, _tree))
Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)) Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)) Taylor(state=(9, 3), velocity=(9, 3), acceleration=(9, 3)) Taylor(state=(3,), velocity=(3,), acceleration=(3,))