A1. Taylor coefficients as central data structures
Taylor coefficients are passed through the entire solver pipeline. Any sequence type works: lists, tuples, and named tuples all behave identically. Statistical outputs inherit the same container structure as the initial coefficients: means, standard deviations, and posterior samples all match the type and field names of the Taylor coefficient container.
In [1]:
Copied!
import collections
import collections
In [2]:
Copied!
import jax
import jax.numpy as jnp
import jax
import jax.numpy as jnp
In [3]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq import ivpsolve, probdiffeq
In [4]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# Enabling "jax_debug_nans" makes JAX raise an error when a computation
# produces a NaN, instead of letting NaNs propagate silently.
jax.config.update("jax_debug_nans", True)
# NOTE(review): this repeats the call above (looks like a notebook-export
# duplicate); setting the same flag to the same value again is redundant.
jax.config.update("jax_debug_nans", True)
In [5]:
Copied!
def main():
    """Explore different Taylor coefficients.

    Solves the logistic ODE twice: first with the Taylor coefficients stored
    in a plain list, then with the same coefficients stored in a named tuple.
    Prints the shapes of the results. The final assertions check that the mean,
    the standard deviation, and posterior samples of the solution all use the
    named-tuple container of the initial Taylor coefficients.
    """

    # We start by defining an ODE.
    @probdiffeq.ode
    def vf(y, /, *, t):
        """Evaluate the dynamics of the logistic ODE."""
        del t  # unused argument
        return 2 * y * (1 - y)

    u0 = jnp.asarray(0.1)
    t0, t1 = 0.0, 5.0

    # Wrap the solver with jax.jit once and reuse the wrapper for every
    # container type below. The vector field (argument 0) is marked static.
    # A new container structure still triggers a fresh trace.
    solve_jit = jax.jit(solve, static_argnums=[0])

    # Solve with a list of Taylor coefficients:
    jetexpand = probdiffeq.jetexpand_ode_padded_scan(num=2)
    tcoeffs_list, _ = jetexpand(vf, (u0,), t=t0)
    solution = solve_jit(vf, tcoeffs_list, t0=t0, t1=t1)
    print()
    print("Probabilistic solution:")
    print(jax.tree.map(jnp.shape, solution))

    # solution.u is a (dense) normal distribution over the solution. Its
    # statistics (mean, std) reuse the container type of the initial
    # condition; the assertions at the end of this function check this.
    print()
    print("Solution matches initial condition:")
    print(jax.tree.map(jnp.shape, tcoeffs_list))
    print(jax.tree.map(jnp.shape, solution.u))

    # Anything that behaves like a list works.
    # For example, we can use lists or tuples, but also named tuples.
    CustomTCoeffs = collections.namedtuple(
        "CustomTCoeffs", ["state", "velocity", "acceleration"]
    )
    tcoeffs = CustomTCoeffs(*tcoeffs_list)
    solution = solve_jit(vf, tcoeffs, t0=t0, t1=t1)
    print()
    print("The target is a named tuple:")
    print(jax.tree.map(jnp.shape, tcoeffs))
    print(jax.tree.map(jnp.shape, solution))
    print(jax.tree.map(jnp.shape, solution.u))

    # The same applies to statistical quantities that we can extract from the solution.
    # For example, the standard deviation or samples from the solution object.
    # Each random draw gets its own key: JAX keys must not be reused across
    # random operations, otherwise the draws are statistically correlated.
    key = jax.random.PRNGKey(seed=15)
    key_one, key_many = jax.random.split(key)
    posterior = solution.solution_full
    sample_one = posterior.sample(key_one)
    sample_many = posterior.sample(key_many, shape=(1, 2, 3))
    print()
    print("Samples inherit structure:")
    print(jax.tree.map(jnp.shape, solution.u.mean))
    print(jax.tree.map(jnp.shape, solution.u.std))
    print(jax.tree.map(jnp.shape, sample_one))
    print(jax.tree.map(jnp.shape, sample_many))

    # Checks that keep this notebook honest when it runs in CI.
    assert isinstance(solution.u.mean, CustomTCoeffs)
    assert isinstance(solution.u.std, CustomTCoeffs)
    assert isinstance(sample_one, CustomTCoeffs)
    assert isinstance(sample_many, CustomTCoeffs)
def main():
    """Explore different Taylor coefficients.

    Solves the logistic ODE twice: first with the Taylor coefficients stored
    in a plain list, then with the same coefficients stored in a named tuple.
    Prints the shapes of the results. The final assertions check that the mean,
    the standard deviation, and posterior samples of the solution all use the
    named-tuple container of the initial Taylor coefficients.
    """

    # We start by defining an ODE.
    @probdiffeq.ode
    def vf(y, /, *, t):
        """Evaluate the dynamics of the logistic ODE."""
        del t  # unused argument
        return 2 * y * (1 - y)

    u0 = jnp.asarray(0.1)
    t0, t1 = 0.0, 5.0

    # Wrap the solver with jax.jit once and reuse the wrapper for every
    # container type below. The vector field (argument 0) is marked static.
    # A new container structure still triggers a fresh trace.
    solve_jit = jax.jit(solve, static_argnums=[0])

    # Solve with a list of Taylor coefficients:
    jetexpand = probdiffeq.jetexpand_ode_padded_scan(num=2)
    tcoeffs_list, _ = jetexpand(vf, (u0,), t=t0)
    solution = solve_jit(vf, tcoeffs_list, t0=t0, t1=t1)
    print()
    print("Probabilistic solution:")
    print(jax.tree.map(jnp.shape, solution))

    # solution.u is a (dense) normal distribution over the solution. Its
    # statistics (mean, std) reuse the container type of the initial
    # condition; the assertions at the end of this function check this.
    print()
    print("Solution matches initial condition:")
    print(jax.tree.map(jnp.shape, tcoeffs_list))
    print(jax.tree.map(jnp.shape, solution.u))

    # Anything that behaves like a list works.
    # For example, we can use lists or tuples, but also named tuples.
    CustomTCoeffs = collections.namedtuple(
        "CustomTCoeffs", ["state", "velocity", "acceleration"]
    )
    tcoeffs = CustomTCoeffs(*tcoeffs_list)
    solution = solve_jit(vf, tcoeffs, t0=t0, t1=t1)
    print()
    print("The target is a named tuple:")
    print(jax.tree.map(jnp.shape, tcoeffs))
    print(jax.tree.map(jnp.shape, solution))
    print(jax.tree.map(jnp.shape, solution.u))

    # The same applies to statistical quantities that we can extract from the solution.
    # For example, the standard deviation or samples from the solution object.
    # Each random draw gets its own key: JAX keys must not be reused across
    # random operations, otherwise the draws are statistically correlated.
    key = jax.random.PRNGKey(seed=15)
    key_one, key_many = jax.random.split(key)
    posterior = solution.solution_full
    sample_one = posterior.sample(key_one)
    sample_many = posterior.sample(key_many, shape=(1, 2, 3))
    print()
    print("Samples inherit structure:")
    print(jax.tree.map(jnp.shape, solution.u.mean))
    print(jax.tree.map(jnp.shape, solution.u.std))
    print(jax.tree.map(jnp.shape, sample_one))
    print(jax.tree.map(jnp.shape, sample_many))

    # Checks that keep this notebook honest when it runs in CI.
    assert isinstance(solution.u.mean, CustomTCoeffs)
    assert isinstance(solution.u.std, CustomTCoeffs)
    assert isinstance(sample_one, CustomTCoeffs)
    assert isinstance(sample_many, CustomTCoeffs)
In [6]:
Copied!
def solve(vf, tc, *, t0, t1):
    """Solve the ODE adaptively and report the result on a fixed time grid.

    Args:
        vf: The ODE vector field (callers decorate it with ``probdiffeq.ode``).
        tc: Taylor coefficients of the initial condition. Their container type
            carries through to the statistics of the returned solution.
        t0: Start of the time interval.
        t1: End of the time interval.

    Returns:
        Whatever ``ivpsolve.solve_adaptive_save_at`` returns. It is evaluated
        at ten equispaced points in ``[t0, t1]`` with ``atol = rtol = 1e-2``.
    """
    # Dense state-space model, integrated-Wiener prior, and a first-order
    # Taylor (ts0) ODE constraint built from the vector field.
    state_space = probdiffeq.state_space_model_dense()
    wiener_prior = state_space.prior_wiener_integrated(tc)
    ode_constraint = state_space.constraint_ode_ts0(vf)

    mle_solver = probdiffeq.solver_mle(
        strategy=probdiffeq.strategy_smoother_fixedpoint(),
        constraint=ode_constraint,
    )
    save_points = jnp.linspace(t0, t1, endpoint=True, num=10)
    residual_error = probdiffeq.error_residual_std(constraint=ode_constraint)

    adaptive_solve = ivpsolve.solve_adaptive_save_at(
        solver=mle_solver, error=residual_error
    )
    return adaptive_solve(wiener_prior, save_at=save_points, atol=1e-2, rtol=1e-2)
def solve(vf, tc, *, t0, t1):
    """Solve the ODE adaptively and report the result on a fixed time grid.

    Args:
        vf: The ODE vector field (callers decorate it with ``probdiffeq.ode``).
        tc: Taylor coefficients of the initial condition. Their container type
            carries through to the statistics of the returned solution.
        t0: Start of the time interval.
        t1: End of the time interval.

    Returns:
        Whatever ``ivpsolve.solve_adaptive_save_at`` returns. It is evaluated
        at ten equispaced points in ``[t0, t1]`` with ``atol = rtol = 1e-2``.
    """
    # Dense state-space model, integrated-Wiener prior, and a first-order
    # Taylor (ts0) ODE constraint built from the vector field.
    state_space = probdiffeq.state_space_model_dense()
    wiener_prior = state_space.prior_wiener_integrated(tc)
    ode_constraint = state_space.constraint_ode_ts0(vf)

    mle_solver = probdiffeq.solver_mle(
        strategy=probdiffeq.strategy_smoother_fixedpoint(),
        constraint=ode_constraint,
    )
    save_points = jnp.linspace(t0, t1, endpoint=True, num=10)
    residual_error = probdiffeq.error_residual_std(constraint=ode_constraint)

    adaptive_solve = ivpsolve.solve_adaptive_save_at(
        solver=mle_solver, error=residual_error
    )
    return adaptive_solve(wiener_prior, save_at=save_points, atol=1e-2, rtol=1e-2)
In [7]:
Copied!
if __name__ == "__main__":
    # Run the demonstration only when this file is executed directly,
    # not when it is imported as a module.
    main()
if __name__ == "__main__":
    # Run the demonstration only when this file is executed directly,
    # not when it is imported as a module.
    main()
Probabilistic solution: ProbabilisticSolution(t=(10,), u=DenseNormal((10, 3), (10, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6b12c3e780>)), solution_full=MarkovSequence(marginal=DenseNormal((10, 3), (10, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6b12c3e780>)), conditional=DenseLatentCond(A=(9, 3, 3), noise=DenseNormal((9, 3), (9, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6b12c3e780>)), to_latent=(9, 3), to_observed=(9, 3)), reverse=True), output_scale=(9,), num_steps=(9,), auxiliary=(None, (9,), (9,)), fun_evals=DenseLatentCond(A=(9, 1, 3), noise=DenseNormal((9, 1), (9, 1, 1), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6ae07a7800>)), to_latent=(9, 3), to_observed=(9, 1)), prior=DenseWienerIntegrated(init=DenseNormal((9, 3), (9, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6b12c3e780>)), output_scale=(9, 1, 1))) Solution matches initial condition: [(), (), ()] DenseNormal((10, 3), (10, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6b12c3e780>))
The target is a named tuple: CustomTCoeffs(state=(), velocity=(), acceleration=()) ProbabilisticSolution(t=(10,), u=DenseNormal((10, 3), (10, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6aa1f3ff80>)), solution_full=MarkovSequence(marginal=DenseNormal((10, 3), (10, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6aa1f3ff80>)), conditional=DenseLatentCond(A=(9, 3, 3), noise=DenseNormal((9, 3), (9, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6aa1f3ff80>)), to_latent=(9, 3), to_observed=(9, 3)), reverse=True), output_scale=(9,), num_steps=(9,), auxiliary=(None, (9,), (9,)), fun_evals=DenseLatentCond(A=(9, 1, 3), noise=DenseNormal((9, 1), (9, 1, 1), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6ab0150530>)), to_latent=(9, 3), to_observed=(9, 1)), prior=DenseWienerIntegrated(init=DenseNormal((9, 3), (9, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6aa1f3ff80>)), output_scale=(9, 1, 1))) DenseNormal((10, 3), (10, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f6aa1f3ff80>))
Samples inherit structure: CustomTCoeffs(state=(10,), velocity=(10,), acceleration=(10,)) CustomTCoeffs(state=(10,), velocity=(10,), acceleration=(10,)) CustomTCoeffs(state=(10,), velocity=(10,), acceleration=(10,)) CustomTCoeffs(state=(1, 2, 3, 10), velocity=(1, 2, 3, 10), acceleration=(1, 2, 3, 10))