A0. Get started
Set up the complete solver pipeline step by step: define an ODE, choose a state-space model, expand Taylor coefficients, build a prior, a strategy, a constraint, a solver, and an error estimator, then call the adaptive solver.
In [1]:
Copied!
import jax
import jax.numpy as jnp
import jax
import jax.numpy as jnp
In [2]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq import ivpsolve, probdiffeq
In [3]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# The flag is global JAX configuration, so setting it once is sufficient.
jax.config.update("jax_debug_nans", True)
In [4]:
Copied!
def main():
    """Solve the logistic equation.

    Builds the full probdiffeq pipeline for the ODE ``y' = 2 y (1 - y)`` with
    ``y(0) = 0.1``, then solves it adaptively on ``[0, 5]``. The pipeline is:
    vector field, dense state-space model, Taylor coefficients, integrated
    Wiener prior, filter strategy, ODE constraint, ``solver_mle`` solver and
    ``error_residual_std`` error estimate. The solution is saved at 100
    equispaced points.

    Prints the array shapes (via ``jax.tree.map(jnp.shape, ...)``) of the
    prior and of the returned solution. Returns ``None``.
    """
    # Define a differential equation
    @probdiffeq.ode
    def vf(y, /, *, t):
        """Evaluate the dynamics of the logistic ODE."""
        del t  # unused argument
        return 2 * y * (1 - y)

    u0 = jnp.asarray(0.1)
    t0, t1 = 0.0, 5.0

    # Construct a (dense) state-space model factorisation; the constraint is
    # derived from it further below.
    ssm = probdiffeq.state_space_model_dense()

    # Initialize Taylor coefficients and construct the prior
    jetexpand = probdiffeq.jetexpand_ode_padded_scan(num=2)
    tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
    prior = ssm.prior_wiener_integrated(tcoeffs)

    # Build the rest of the solver: strategy, constraint, solver, error estimate
    strategy = probdiffeq.strategy_filter()
    constraint = ssm.constraint_ode_ts1(vf)
    solver = probdiffeq.solver_mle(strategy=strategy, constraint=constraint)
    error = probdiffeq.error_residual_std(constraint=constraint)
    solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)

    # Solve the ODE, saving at 100 equispaced points in [t0, t1] (inclusive).
    # Try different solution routines.
    save_at = jnp.linspace(t0, t1, num=100, endpoint=True)
    solution = jax.jit(solve)(prior, save_at, atol=1e-3, rtol=1e-3)

    print(f"\ninitial = {jax.tree.map(jnp.shape, prior)}")
    print(f"\nsolution = {jax.tree.map(jnp.shape, solution)}")
In [5]:
Copied!
# Run the example once when executed as a script (not on import).
if __name__ == "__main__":
    main()
initial = DenseWienerIntegrated(init=DenseNormal((3,), (3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f10a4aaf980>)), output_scale=(1, 1)) solution = ProbabilisticSolution(t=(100,), u=DenseNormal((100, 3), (100, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f10a4aaf980>)), solution_full=DenseNormal((100, 3), (100, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f10a4aaf980>)), output_scale=(99,), num_steps=(99,), auxiliary=((99, 2), (99,), (99,)), fun_evals=DenseLatentCond(A=(99, 1, 3), noise=DenseNormal((99, 1), (99, 1, 1), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f10a4df9370>)), to_latent=(99, 3), to_observed=(99, 1)), prior=DenseWienerIntegrated(init=DenseNormal((99, 3), (99, 3, 3), DenseTreeFlatten(unravel=<jax._src.util.HashablePartial object at 0x7f10a4aaf980>)), output_scale=(99, 1, 1)))