Taylor coefficients¶
To build a probabilistic solver, we need to build a specific state-space model, and to build that state-space model, we interact with Taylor coefficients. Here are some examples of how Taylor coefficients play a role in Probdiffeq's solution routines.
In [1]:
Copied!
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
# Pick the JAX backend for diffeqzoo, but only if no backend was selected yet.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
# Pick the JAX backend for diffeqzoo, but only if no backend was selected yet.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax
We start by defining an ODE.
In [2]:
Copied!
f, u0, (t0, t1), f_args = ivps.logistic()


def vf(*y, t):  # noqa: ARG001
    """Return the ODE right-hand side ``f`` evaluated at the state(s) ``y``.

    The keyword-only time ``t`` belongs to the signature but does not enter
    the evaluation; the problem parameters ``f_args`` are appended to ``y``.
    """
    arguments = (*y, *f_args)
    return f(*arguments)
f, u0, (t0, t1), f_args = ivps.logistic()


def vf(*y, t):  # noqa: ARG001
    """Return the ODE right-hand side ``f`` evaluated at the state(s) ``y``.

    The keyword-only time ``t`` belongs to the signature but does not enter
    the evaluation; the problem parameters ``f_args`` are appended to ``y``.
    """
    arguments = (*y, *f_args)
    return f(*arguments)
Here is a wrapper around Probdiffeq's solution routine.
In [3]:
Copied!
def solve(tc):
    """Solve the ODE given by ``vf`` on ``[t0, t1]`` with an adaptive solver.

    ``tc`` holds the initial Taylor coefficients and is handed to the
    integrated-Wiener prior constructor. The solution is saved on ten
    equispaced points between ``t0`` and ``t1`` (both included), and the
    result of ``ivpsolve.solve_adaptive_save_at`` is returned unchanged.
    """
    init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")

    # Assemble the solver from its parts; all share the same ``ssm``.
    correction = ivpsolvers.correction_ts0(vf, ssm=ssm)
    fixedpoint = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    mle_solver = ivpsolvers.solver_mle(
        fixedpoint, prior=prior, correction=correction, ssm=ssm
    )
    adaptive_solver = ivpsolvers.adaptive(mle_solver, atol=1e-2, rtol=1e-2, ssm=ssm)

    # Grid of time points at which the solution is stored.
    save_at = jnp.linspace(t0, t1, endpoint=True, num=10)
    return ivpsolve.solve_adaptive_save_at(
        init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
    )
def solve(tc):
    """Solve the ODE given by ``vf`` on ``[t0, t1]`` with an adaptive solver.

    ``tc`` holds the initial Taylor coefficients and is handed to the
    integrated-Wiener prior constructor. The solution is saved on ten
    equispaced points between ``t0`` and ``t1`` (both included), and the
    result of ``ivpsolve.solve_adaptive_save_at`` is returned unchanged.
    """
    init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")

    # Assemble the solver from its parts; all share the same ``ssm``.
    correction = ivpsolvers.correction_ts0(vf, ssm=ssm)
    fixedpoint = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    mle_solver = ivpsolvers.solver_mle(
        fixedpoint, prior=prior, correction=correction, ssm=ssm
    )
    adaptive_solver = ivpsolvers.adaptive(mle_solver, atol=1e-2, rtol=1e-2, ssm=ssm)

    # Grid of time points at which the solution is stored.
    save_at = jnp.linspace(t0, t1, endpoint=True, num=10)
    return ivpsolve.solve_adaptive_save_at(
        init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
    )
It's time to solve some ODEs:
In [4]:
Copied!
# Build the initial Taylor coefficients from the initial value u0. The lambda
# pins the keyword-only time argument of vf to t0. With num=2 the result has
# three entries here (see the printed shapes in the next cell).
tcoeffs = taylor.odejet_padded_scan(lambda *y: vf(*y, t=t0), [u0], num=2)
solution = solve(tcoeffs)
# Print only the array shape of every leaf of the solution object.
print(jax.tree.map(jnp.shape, solution))
# Build the initial Taylor coefficients from the initial value u0. The lambda
# pins the keyword-only time argument of vf to t0. With num=2 the result has
# three entries here (see the printed shapes in the next cell).
tcoeffs = taylor.odejet_padded_scan(lambda *y: vf(*y, t=t0), [u0], num=2)
solution = solve(tcoeffs)
# Print only the array shape of every leaf of the solution object.
print(jax.tree.map(jnp.shape, solution))
IVPSolution(t=(10,), u=[(10,), (10,), (10,)], u_std=[(10,), (10,), (10,)], output_scale=(9,), marginals=Normal(mean=(10, 3), cholesky=(10, 3, 3)), posterior=MarkovSeq(init=Normal(mean=(10, 3), cholesky=(10, 3, 3)), conditional=LatentCond(A=(9, 3, 3), noise=Normal(mean=(9, 3), cholesky=(9, 3, 3)), to_latent=(9, 3), to_observed=(9, 3))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7fdd35759820>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7fdd359c11f0>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7fdd30086150>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7fdd30a66420>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7fdd30084ad0>, num_derivatives=2, unravel=<jax._src.util.HashablePartial object at 0x7fdd351a0c50>))
The type of solution.u matches that of the initial condition.
In [5]:
Copied!
# Compare container structure and leaf shapes: the initial Taylor coefficients
# first, then the solution values stored in solution.u.
print(jax.tree.map(jnp.shape, tcoeffs))
print(jax.tree.map(jnp.shape, solution.u))
# Compare container structure and leaf shapes: the initial Taylor coefficients
# first, then the solution values stored in solution.u.
print(jax.tree.map(jnp.shape, tcoeffs))
print(jax.tree.map(jnp.shape, solution.u))
[(), (), ()] [(10,), (10,), (10,)]
Anything that behaves like a list works. For example, we can use lists or tuples, but also named tuples.
In [6]:
Copied!
# Repackage the three Taylor coefficients in a named tuple so that each entry
# carries a label, then solve again with this container as the input.
Taylor = collections.namedtuple("Taylor", ["state", "velocity", "acceleration"])
tcoeffs = Taylor(*tcoeffs)
solution = solve(tcoeffs)
# Print leaf shapes of the input, the full solution, and solution.u to show
# which container type each of them uses.
print(jax.tree.map(jnp.shape, tcoeffs))
print(jax.tree.map(jnp.shape, solution))
print(jax.tree.map(jnp.shape, solution.u))
# Repackage the three Taylor coefficients in a named tuple so that each entry
# carries a label, then solve again with this container as the input.
Taylor = collections.namedtuple("Taylor", ["state", "velocity", "acceleration"])
tcoeffs = Taylor(*tcoeffs)
solution = solve(tcoeffs)
# Print leaf shapes of the input, the full solution, and solution.u to show
# which container type each of them uses.
print(jax.tree.map(jnp.shape, tcoeffs))
print(jax.tree.map(jnp.shape, solution))
print(jax.tree.map(jnp.shape, solution.u))
Taylor(state=(), velocity=(), acceleration=()) IVPSolution(t=(10,), u=Taylor(state=(10,), velocity=(10,), acceleration=(10,)), u_std=Taylor(state=(10,), velocity=(10,), acceleration=(10,)), output_scale=(9,), marginals=Normal(mean=(10, 3), cholesky=(10, 3, 3)), posterior=MarkovSeq(init=Normal(mean=(10, 3), cholesky=(10, 3, 3)), conditional=LatentCond(A=(9, 3, 3), noise=Normal(mean=(9, 3), cholesky=(9, 3, 3)), to_latent=(9, 3), to_observed=(9, 3))), num_steps=(9,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7fdd1062f6b0>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7fdd1062d730>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7fdd351a2660>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7fdd8961bc80>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7fdd351a3500>, num_derivatives=2, unravel=<jax._src.util.HashablePartial object at 0x7fdd10399af0>)) Taylor(state=(10,), velocity=(10,), acceleration=(10,))
The same applies to statistical quantities that we can extract from the solution. For example, the standard deviation or samples from the solution object:
In [7]:
Copied!
# Fixed-seed PRNG key used for drawing the sample below.
key = jax.random.PRNGKey(seed=15)
# NOTE(review): presumably this restricts the posterior to its terminal part
# before sampling — confirm against the probdiffeq stats documentation.
posterior = stats.markov_select_terminal(solution.posterior)
# Draw one sample; two outputs are returned (see their printed shapes below).
# NOTE(review): the exact meaning of reverse=True is not visible here — confirm.
samples, samples_init = stats.markov_sample(
key, posterior, reverse=True, ssm=solution.ssm
)
# Print leaf shapes of the solution values, their standard deviations, and
# both sampling outputs to compare their container types.
print(jax.tree.map(jnp.shape, solution.u))
print(jax.tree.map(jnp.shape, solution.u_std))
print(jax.tree.map(jnp.shape, samples))
print(jax.tree.map(jnp.shape, samples_init))
# Fixed-seed PRNG key used for drawing the sample below.
key = jax.random.PRNGKey(seed=15)
# NOTE(review): presumably this restricts the posterior to its terminal part
# before sampling — confirm against the probdiffeq stats documentation.
posterior = stats.markov_select_terminal(solution.posterior)
# Draw one sample; two outputs are returned (see their printed shapes below).
# NOTE(review): the exact meaning of reverse=True is not visible here — confirm.
samples, samples_init = stats.markov_sample(
key, posterior, reverse=True, ssm=solution.ssm
)
# Print leaf shapes of the solution values, their standard deviations, and
# both sampling outputs to compare their container types.
print(jax.tree.map(jnp.shape, solution.u))
print(jax.tree.map(jnp.shape, solution.u_std))
print(jax.tree.map(jnp.shape, samples))
print(jax.tree.map(jnp.shape, samples_init))
Taylor(state=(10,), velocity=(10,), acceleration=(10,)) Taylor(state=(10,), velocity=(10,), acceleration=(10,)) Taylor(state=(9,), velocity=(9,), acceleration=(9,)) Taylor(state=(), velocity=(), acceleration=())