Understanding the role of Taylor coefficients
A probabilistic solver requires a specific state-space model, and building that model requires Taylor coefficients of the initial condition. Here are some examples of how Taylor coefficients play a role in Probdiffeq's solution routines.
In [1]:
Copied!
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.util.doc_util import notebook
# Apply the documentation's plot style first, then its figure sizes.
for _rc_getter in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(_rc_getter())

# Have diffeqzoo define its IVP examples in JAX, but only if no backend has
# been chosen yet (selecting twice is avoided by the check).
if not backend.has_been_selected:
    backend.select("jax")

# Run all JAX computations on the CPU.
jax.config.update("jax_platform_name", "cpu")
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
from probdiffeq.util.doc_util import notebook
# Match the plotting look of the rest of the documentation.
plt.rcParams.update(notebook.plot_style())
plt.rcParams.update(notebook.plot_sizes())

# diffeqzoo's IVP examples should be built in JAX; respect an existing choice.
needs_backend = not backend.has_been_selected
if needs_backend:
    backend.select("jax")

# Force JAX onto the CPU platform.
jax.config.update("jax_platform_name", "cpu")
We start by defining an ODE.
In [2]:
Copied!
# Rigid-body example from diffeqzoo: right-hand side, initial value,
# time span, and the extra arguments that the right-hand side expects.
f, u0, (t0, t1), f_args = ivps.rigid_body()


def vf(*state, t):  # noqa: ARG001
    """Evaluate the rigid-body vector field at ``state``.

    ``t`` is keyword-only and ignored: ``f`` is called without a time argument.
    """
    return f(*state, *f_args)
f, u0, (t0, t1), f_args = ivps.rigid_body()


def vf(*y, t):  # noqa: ARG001
    """Return ``f(*y, *f_args)``; the time ``t`` is accepted but unused."""
    arguments = (*y, *f_args)
    return f(*arguments)
Here is a wrapper around Probdiffeq's solution routine.
In [3]:
Copied!
def solve(tc):
    """Solve the ODE, using the Taylor coefficients ``tc`` to build the prior.

    Relies on the module-level ``vf``, ``t0`` and ``t1``.

    Args:
        tc: Taylor coefficients of the initial value. Any list-like container
            (list, tuple, namedtuple) works; its structure is mirrored in the
            returned solution.

    Returns:
        The result of ``ivpsolve.solve_adaptive_save_at``, saved at 10 equally
        spaced points in ``[t0, t1]`` (both endpoints included).
    """
    # Prior and state-space-model implementation, with a dense factorisation.
    prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense")

    correction = ivpsolvers.correction_ts0(ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    mle_solver = ivpsolvers.solver_mle(
        strategy, prior=prior, correction=correction, ssm=ssm
    )
    initial_state = mle_solver.initial_condition()

    save_at = jnp.linspace(t0, t1, endpoint=True, num=10)
    # Step-size adaptation with loose tolerances (absolute and relative 1e-2).
    adaptive_solver = ivpsolvers.adaptive(mle_solver, atol=1e-2, rtol=1e-2, ssm=ssm)

    solution = ivpsolve.solve_adaptive_save_at(
        vf,
        initial_state,
        save_at=save_at,
        adaptive_solver=adaptive_solver,
        dt0=0.1,
        ssm=ssm,
    )
    return solution
def solve(tc):
    """Run the adaptive probabilistic solver with a prior built from ``tc``.

    The solution is saved on 10 equispaced points spanning ``[t0, t1]``.
    """
    prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense")
    ts0 = ivpsolvers.correction_ts0(ssm=ssm)
    fixedpoint = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_mle(fixedpoint, prior=prior, correction=ts0, ssm=ssm)
    init = solver.initial_condition()

    grid = jnp.linspace(t0, t1, endpoint=True, num=10)
    adaptive = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
    return ivpsolve.solve_adaptive_save_at(
        vf, init, save_at=grid, adaptive_solver=adaptive, dt0=0.1, ssm=ssm
    )
It's time to solve some ODEs:
In [4]:
Copied!
def vf_at_t0(*y):
    """Evaluate ``vf`` with the time argument fixed to ``t0``."""
    return vf(*y, t=t0)


# Taylor coefficients at t0, computed from the initial value u0.
tcoeffs = taylor.odejet_padded_scan(vf_at_t0, [u0], num=2)

solution = solve(tcoeffs)
solution_shapes = jax.tree.map(jnp.shape, solution)
print(solution_shapes)
tcoeffs = taylor.odejet_padded_scan(
    lambda *y: vf(*y, t=t0),  # vf viewed as a function of the state only
    [u0],
    num=2,
)
solution = solve(tcoeffs)
print(jax.tree.map(jnp.shape, solution))
IVPSolution(t=(10,), u=[(10, 3), (10, 3), (10, 3)], u_std=[(10, 3), (10, 3), (10, 3)], output_scale=(9,), marginals=Normal(mean=(10, 9), cholesky=(10, 9, 9)), posterior=MarkovSeq(init=Normal(mean=(10, 9), cholesky=(10, 9, 9)), conditional=Conditional(matmul=(9, 9, 9), noise=Normal(mean=(9, 9), cholesky=(9, 9, 9)))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f7db979cf50>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f7dd5ffd310>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f7db97e0680>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f7dc6d22420>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f7db97e32c0>, transform=<probdiffeq.impl._conditional.DenseTransform object at 0x7f7db97be600>))
The container type of `solution.u` matches that of the Taylor coefficients used as the initial condition (here, a list); each entry gains a leading time axis.
In [5]:
Copied!
# Compare the structure and leaf shapes of the input and the solution.
for _tree in (tcoeffs, solution.u):
    print(jax.tree.map(jnp.shape, _tree))
tcoeffs_shapes = jax.tree.map(jnp.shape, tcoeffs)
solution_u_shapes = jax.tree.map(jnp.shape, solution.u)
print(tcoeffs_shapes)
print(solution_u_shapes)
[(3,), (3,), (3,)] [(10, 3), (10, 3), (10, 3)]
Any list-like container works. For example, we can use lists or tuples, but also named tuples.
In [6]:
Copied!
# Give the three Taylor coefficients names via a namedtuple.
Taylor = collections.namedtuple("Taylor", ["state", "velocity", "acceleration"])
tcoeffs = Taylor._make(tcoeffs)
solution = solve(tcoeffs)

# Print the shapes of the input, the full solution, and the solution mean.
for _tree in (tcoeffs, solution, solution.u):
    print(jax.tree.map(jnp.shape, _tree))
Taylor = collections.namedtuple("Taylor", ["state", "velocity", "acceleration"])
tcoeffs = Taylor(*tcoeffs)
solution = solve(tcoeffs)

shape_trees = [
    jax.tree.map(jnp.shape, tree) for tree in (tcoeffs, solution, solution.u)
]
for shape_tree in shape_trees:
    print(shape_tree)
Taylor(state=(3,), velocity=(3,), acceleration=(3,)) IVPSolution(t=(10,), u=Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)), u_std=Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)), output_scale=(9,), marginals=Normal(mean=(10, 9), cholesky=(10, 9, 9)), posterior=MarkovSeq(init=Normal(mean=(10, 9), cholesky=(10, 9, 9)), conditional=Conditional(matmul=(9, 9, 9), noise=Normal(mean=(9, 9), cholesky=(9, 9, 9)))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f7db96a2ba0>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f7db97bd4c0>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f7de0141160>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f7db97bede0>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f7de01428d0>, transform=<probdiffeq.impl._conditional.DenseTransform object at 0x7f7de0142480>)) Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3))
The same structure carries over to statistical quantities extracted from the solution, for example the standard deviation or samples from the posterior:
In [7]:
Copied!
key = jax.random.PRNGKey(seed=15)
# NOTE(review): `markov_select_terminal` presumably restricts the saved
# posterior to the Markov sequence anchored at the final time — confirm.
posterior = stats.markov_select_terminal(solution.posterior)
samples, samples_init = stats.markov_sample(
    key, posterior, reverse=True, ssm=solution.ssm
)

# Mean, standard deviation, and samples all share the Taylor structure.
for _tree in (solution.u, solution.u_std, samples, samples_init):
    print(jax.tree.map(jnp.shape, _tree))
posterior = stats.markov_select_terminal(solution.posterior)
key = jax.random.PRNGKey(seed=15)
sample_output = stats.markov_sample(key, posterior, reverse=True, ssm=solution.ssm)
samples, samples_init = sample_output

shape_summary = [
    jax.tree.map(jnp.shape, tree)
    for tree in (solution.u, solution.u_std, samples, samples_init)
]
for shapes in shape_summary:
    print(shapes)
Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)) Taylor(state=(10, 3), velocity=(10, 3), acceleration=(10, 3)) Taylor(state=(9, 3), velocity=(9, 3), acceleration=(9, 3)) Taylor(state=(3,), velocity=(3,), acceleration=(3,))