Taylor coefficients
To build a probabilistic solver, we need to build a specific state-space model. To build this state-space model, we interact with Taylor coefficients. Here are some examples of how Taylor coefficients play a role in Probdiffeq's solution routines.
In [1]:
Copied!
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
# Select JAX as the diffeqzoo backend so that the IVP examples used below are
# built with JAX. The guard skips selection when a backend has already been
# chosen. NOTE(review): presumably re-selecting a backend is not allowed by
# diffeqzoo; verify against its documentation.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
# Select JAX as the diffeqzoo backend so that the IVP examples used below are
# built with JAX. The guard skips selection when a backend has already been
# chosen. NOTE(review): presumably re-selecting a backend is not allowed by
# diffeqzoo; verify against its documentation.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax
We start by defining an ODE.
In [2]:
Copied!
# Logistic IVP from diffeqzoo: vector field, initial value, time span, and the
# extra arguments that the vector field expects after the state.
f, u0, (t0, t1), f_args = ivps.logistic()


def vf(*y, t):  # noqa: ARG001
    """Return the logistic vector field evaluated at the state(s) ``y``.

    ``t`` is keyword-only and ignored by this implementation. It exists in
    the signature because callers pass it, e.g. ``vf(*y, t=t0)``.
    """
    state_and_params = (*y, *f_args)
    return f(*state_and_params)
# Logistic IVP from diffeqzoo: vector field, initial value, time span, and the
# extra arguments that the vector field expects after the state.
f, u0, (t0, t1), f_args = ivps.logistic()


def vf(*y, t):  # noqa: ARG001
    """Return the logistic vector field evaluated at the state(s) ``y``.

    ``t`` is keyword-only and ignored by this implementation. It exists in
    the signature because callers pass it, e.g. ``vf(*y, t=t0)``.
    """
    state_and_params = (*y, *f_args)
    return f(*state_and_params)
Here is a wrapper around Probdiffeq's solution routine.
In [3]:
Copied!
def solve(tc):
    """Solve the ODE ``vf`` adaptively, starting from Taylor coefficients.

    Args:
        tc: Taylor coefficients of the initial condition. Any list-like
            container works; its type is mirrored in ``solution.u``.

    Returns:
        The result of ``ivpsolve.solve_adaptive_save_at``, saved at ten
        equispaced points between ``t0`` and ``t1`` (both endpoints included).
    """
    # Output grid: 10 points on [t0, t1], endpoints included.
    save_at = jnp.linspace(t0, t1, endpoint=True, num=10)

    # Prior and (dense) state-space implementation, built from the coefficients.
    init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")

    # Solver components: a correction built from ``vf``, a fixed-point
    # strategy, and the ``solver_mle`` solver that combines them with the prior.
    correction = ivpsolvers.correction_ts0(vf, ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_mle(
        strategy, prior=prior, correction=correction, ssm=ssm
    )

    # Adaptive stepping with atol = rtol = 1e-2, starting from dt0 = 0.1.
    adaptive = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
    return ivpsolve.solve_adaptive_save_at(
        init,
        save_at=save_at,
        adaptive_solver=adaptive,
        dt0=0.1,
        ssm=ssm,
    )
def solve(tc):
    """Solve the ODE ``vf`` adaptively, starting from Taylor coefficients.

    Args:
        tc: Taylor coefficients of the initial condition. Any list-like
            container works; its type is mirrored in ``solution.u``.

    Returns:
        The result of ``ivpsolve.solve_adaptive_save_at``, saved at ten
        equispaced points between ``t0`` and ``t1`` (both endpoints included).
    """
    # Output grid: 10 points on [t0, t1], endpoints included.
    save_at = jnp.linspace(t0, t1, endpoint=True, num=10)

    # Prior and (dense) state-space implementation, built from the coefficients.
    init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")

    # Solver components: a correction built from ``vf``, a fixed-point
    # strategy, and the ``solver_mle`` solver that combines them with the prior.
    correction = ivpsolvers.correction_ts0(vf, ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_mle(
        strategy, prior=prior, correction=correction, ssm=ssm
    )

    # Adaptive stepping with atol = rtol = 1e-2, starting from dt0 = 0.1.
    adaptive = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
    return ivpsolve.solve_adaptive_save_at(
        init,
        save_at=save_at,
        adaptive_solver=adaptive,
        dt0=0.1,
        ssm=ssm,
    )
It's time to solve some ODEs:
In [4]:
Copied!
# Taylor coefficients of the initial value u0 (three entries for num=2), with
# the vector field's time argument fixed to t0.
tcoeffs = taylor.odejet_padded_scan(
    lambda *state: vf(*state, t=t0), [u0], num=2
)
solution = solve(tcoeffs)
print(jax.tree_util.tree_map(jnp.shape, solution))
# Taylor coefficients of the initial value u0 (three entries for num=2), with
# the vector field's time argument fixed to t0.
tcoeffs = taylor.odejet_padded_scan(
    lambda *state: vf(*state, t=t0), [u0], num=2
)
solution = solve(tcoeffs)
print(jax.tree_util.tree_map(jnp.shape, solution))
IVPSolution(t=(10,), u=[(10,), (10,), (10,)], u_std=[(10,), (10,), (10,)], output_scale=(9,), marginals=Normal(mean=(10, 3), cholesky=(10, 3, 3)), posterior=MarkovSeq(init=Normal(mean=(10, 3), cholesky=(10, 3, 3)), conditional=LatentCond(A=(9, 3, 3), noise=Normal(mean=(9, 3), cholesky=(9, 3, 3)), to_latent=(9, 3), to_observed=(9, 3))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f1310fd9580>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f1310ff1fa0>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f1315f01a30>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f1310ff3aa0>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f1310219c40>, num_derivatives=2, unravel=<jax._src.util.HashablePartial object at 0x7f1310f73fb0>))
The type of solution.u matches that of the initial condition, i.e., the Taylor coefficients tcoeffs.
In [5]:
Copied!
# Compare the container structure of the input coefficients with solution.u
# (one line each).
print(
    jax.tree_util.tree_map(jnp.shape, tcoeffs),
    jax.tree_util.tree_map(jnp.shape, solution.u),
    sep="\n",
)
# Compare the container structure of the input coefficients with solution.u
# (one line each).
print(
    jax.tree_util.tree_map(jnp.shape, tcoeffs),
    jax.tree_util.tree_map(jnp.shape, solution.u),
    sep="\n",
)
[(), (), ()] [(10,), (10,), (10,)]
Anything that behaves like a list works. For example, we can use lists or tuples, but also named tuples.
In [6]:
Copied!
# Field names given as one whitespace-separated string (same three fields).
Taylor = collections.namedtuple("Taylor", "state velocity acceleration")
# Re-wrap the previously computed coefficients in the named tuple.
tcoeffs = Taylor._make(tcoeffs)
solution = solve(tcoeffs)
print(
    jax.tree_util.tree_map(jnp.shape, tcoeffs),
    jax.tree_util.tree_map(jnp.shape, solution),
    jax.tree_util.tree_map(jnp.shape, solution.u),
    sep="\n",
)
# Field names given as one whitespace-separated string (same three fields).
Taylor = collections.namedtuple("Taylor", "state velocity acceleration")
# Re-wrap the previously computed coefficients in the named tuple.
tcoeffs = Taylor._make(tcoeffs)
solution = solve(tcoeffs)
print(
    jax.tree_util.tree_map(jnp.shape, tcoeffs),
    jax.tree_util.tree_map(jnp.shape, solution),
    jax.tree_util.tree_map(jnp.shape, solution.u),
    sep="\n",
)
Taylor(state=(), velocity=(), acceleration=()) IVPSolution(t=(10,), u=Taylor(state=(10,), velocity=(10,), acceleration=(10,)), u_std=Taylor(state=(10,), velocity=(10,), acceleration=(10,)), output_scale=(9,), marginals=Normal(mean=(10, 3), cholesky=(10, 3, 3)), posterior=MarkovSeq(init=Normal(mean=(10, 3), cholesky=(10, 3, 3)), conditional=LatentCond(A=(9, 3, 3), noise=Normal(mean=(9, 3), cholesky=(9, 3, 3)), to_latent=(9, 3), to_observed=(9, 3))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f13100fe960>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f12f8af0c80>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f13101ef860>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f13100fc6b0>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f12f8b70740>, num_derivatives=2, unravel=<jax._src.util.HashablePartial object at 0x7f12f8bf6690>)) Taylor(state=(10,), velocity=(10,), acceleration=(10,))
The same applies to statistical quantities that we can extract from the solution. For example, the standard deviation or samples from the solution object:
In [7]:
Copied!
# Fixed seed, so the drawn samples are reproducible across runs.
key = jax.random.PRNGKey(15)
# Restrict the posterior with markov_select_terminal, then sample from it;
# the initial-time sample is returned separately as samples_init.
posterior = stats.markov_select_terminal(solution.posterior)
samples, samples_init = stats.markov_sample(
    key, posterior, reverse=True, ssm=solution.ssm
)
# All four follow the container type of the Taylor coefficients.
print(
    jax.tree_util.tree_map(jnp.shape, solution.u),
    jax.tree_util.tree_map(jnp.shape, solution.u_std),
    jax.tree_util.tree_map(jnp.shape, samples),
    jax.tree_util.tree_map(jnp.shape, samples_init),
    sep="\n",
)
# Fixed seed, so the drawn samples are reproducible across runs.
key = jax.random.PRNGKey(15)
# Restrict the posterior with markov_select_terminal, then sample from it;
# the initial-time sample is returned separately as samples_init.
posterior = stats.markov_select_terminal(solution.posterior)
samples, samples_init = stats.markov_sample(
    key, posterior, reverse=True, ssm=solution.ssm
)
# All four follow the container type of the Taylor coefficients.
print(
    jax.tree_util.tree_map(jnp.shape, solution.u),
    jax.tree_util.tree_map(jnp.shape, solution.u_std),
    jax.tree_util.tree_map(jnp.shape, samples),
    jax.tree_util.tree_map(jnp.shape, samples_init),
    sep="\n",
)
Taylor(state=(10,), velocity=(10,), acceleration=(10,)) Taylor(state=(10,), velocity=(10,), acceleration=(10,)) Taylor(state=(9,), velocity=(9,), acceleration=(9,)) Taylor(state=(), velocity=(), acceleration=())