Quickstart
Let's have a look at an easy example.
In [1]:
Copied!
"""Solve the logistic equation."""
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
# Define a differential equation
@jax.jit
def vf(y, *, t):  # noqa: ARG001
    """Evaluate the dynamics of the logistic ODE.

    Returns ``2 * y * (1 - y)``. The keyword-only argument ``t`` is accepted
    but does not enter the computation, hence the ``ARG001`` suppression.
    """
    return 2 * y * (1 - y)
# Time span and initial value of the initial value problem
t0, t1 = 0.0, 5.0
u0 = jnp.asarray([0.1])


# Set up a state-space model
def vf_at_t0(y):
    """Evaluate the vector field with the time argument fixed to ``t0``."""
    return vf(y, t=t0)


taylor_coeffs = taylor.odejet_padded_scan(vf_at_t0, (u0,), num=1)
init, prior, ssm = ivpsolvers.prior_wiener_integrated(taylor_coeffs, ssm_fact="dense")

# Build a solver
correction = ivpsolvers.correction_ts1(vf, ssm=ssm, ode_order=1)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_mle(
    ssm=ssm,
    strategy=strategy,
    prior=prior,
    correction=correction,
)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)

# Solve the ODE
# To all users: Try different solution routines.
solution = ivpsolve.solve_adaptive_save_every_step(
    init,
    t0=t0,
    t1=t1,
    dt0=0.1,
    adaptive_solver=adaptive_solver,
    ssm=ssm,
)

# Look at the solution: print the array shapes of every leaf in each pytree
init_shapes = jax.tree.map(jnp.shape, init)
solution_shapes = jax.tree.map(jnp.shape, solution)
print(f"\ninitial = {init_shapes}")
print(f"\nsolution = {solution_shapes}")
"""Solve the logistic equation."""
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
# Define a differential equation
@jax.jit
def vf(y, *, t):  # noqa: ARG001
    """Evaluate the dynamics of the logistic ODE.

    Returns ``2 * y * (1 - y)``. The keyword-only argument ``t`` is accepted
    but does not enter the computation, hence the ``ARG001`` suppression.
    """
    return 2 * y * (1 - y)
# Time span and initial value of the initial value problem
t0, t1 = 0.0, 5.0
u0 = jnp.asarray([0.1])


# Set up a state-space model
def vf_at_t0(y):
    """Evaluate the vector field with the time argument fixed to ``t0``."""
    return vf(y, t=t0)


taylor_coeffs = taylor.odejet_padded_scan(vf_at_t0, (u0,), num=1)
init, prior, ssm = ivpsolvers.prior_wiener_integrated(taylor_coeffs, ssm_fact="dense")

# Build a solver
correction = ivpsolvers.correction_ts1(vf, ssm=ssm, ode_order=1)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_mle(
    ssm=ssm,
    strategy=strategy,
    prior=prior,
    correction=correction,
)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)

# Solve the ODE
# To all users: Try different solution routines.
solution = ivpsolve.solve_adaptive_save_every_step(
    init,
    t0=t0,
    t1=t1,
    dt0=0.1,
    adaptive_solver=adaptive_solver,
    ssm=ssm,
)

# Look at the solution: print the array shapes of every leaf in each pytree
init_shapes = jax.tree.map(jnp.shape, init)
solution_shapes = jax.tree.map(jnp.shape, solution)
print(f"\ninitial = {init_shapes}")
print(f"\nsolution = {solution_shapes}")
initial = Normal(mean=(2,), cholesky=(2, 2))

solution = IVPSolution(t=(38,), u=[(38, 1), (38, 1)], u_std=[(38, 1), (38, 1)], output_scale=(37,), marginals=Normal(mean=(38, 2), cholesky=(38, 2, 2)), posterior=Normal(mean=(38, 2), cholesky=(38, 2, 2)), num_steps=(37,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x7f1010ec19d0>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x7f1010ec1a00>, stats=<probdiffeq.impl._stats.DenseStats object at 0x7f1010ec1880>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x7f1010ec19a0>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x7f1010ec1970>, num_derivatives=1, unravel=<jax._src.util.HashablePartial object at 0x7f1015bc0260>))