In [1]:
Copied!
"""Use Equinox's while loop to compute gradients of `simulate_terminal_values`."""
import equinox
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.backend import control_flow
"""Use Equinox's while loop to compute gradients of `simulate_terminal_values`."""
import equinox
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.backend import control_flow
Overwrite the while-loop (via a context manager):
In [2]:
Copied!
def while_loop_func(*a, max_steps=100, **kw):
    """Evaluate a bounded (and therefore reverse-mode differentiable) while loop.

    `jax.lax.while_loop` does not support reverse-mode differentiation
    (see the error caught further below). Equinox's bounded while loop does,
    so we substitute it for the default loop used by probdiffeq.

    Parameters
    ----------
    *a, **kw
        Forwarded unchanged to `equinox.internal.while_loop`
        (i.e. the usual `cond_fun, body_fun, init_val` of a while loop).
    max_steps
        Upper bound on the number of loop iterations. Defaults to 100.
        NOTE(review): the bounded loop presumably stops unconditionally
        after `max_steps` iterations, so an adaptive solve that needs more
        steps could end before reaching `t1` without an error — confirm,
        and increase this for longer intervals or stiffer problems.
    """
    return equinox.internal.while_loop(*a, **kw, kind="bounded", max_steps=max_steps)


# Context manager that makes probdiffeq use `while_loop_func` instead of its
# default while loop for code executed inside the `with` block.
context_compute_gradient = control_flow.context_overwrite_while_loop(while_loop_func)
The rest is similar to the "easy example" in the quickstart, except that we simulate adaptively and compute both the value and the gradient (which is impossible without the specialised while-loop implementation).
In [3]:
Copied!
def solution_routine():
    """Construct a parameter-to-solution function and an initial value.

    The problem is the logistic ODE  y' = 0.5 * y * (1 - y)  on [0, 1]
    with y(0) = 0.1, solved with an adaptive probabilistic solver:
    an integrated-Wiener prior (isotropic factorisation), a zeroth-order
    Taylor-series (TS0) correction and a fixed-point strategy.

    Returns
    -------
    simulate
        A function mapping an initial solver state to a scalar
        (the squared norm of the terminal value `sol.u[0]`).
    init
        The initial solver state produced by `prior_wiener_integrated`,
        i.e. a valid argument for `simulate`.
    """

    def vf(y, *, t):  # noqa: ARG001
        """Evaluate the (time-independent) logistic vector field."""
        return 0.5 * y * (1 - y)

    t0, t1 = 0.0, 1.0
    u0 = jnp.asarray([0.1])

    # Taylor coefficients of the initial value; `num=1` presumably adds one
    # derivative beyond u0 — TODO confirm against probdiffeq's `taylor` docs.
    tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
    init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")

    ts0 = ivpsolvers.correction_ts0(vf, ode_order=1, ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
    adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)

    def simulate(init_val):
        """Evaluate the parameter-to-solution function.

        The adaptive solve runs a while loop internally, so reverse-mode
        differentiation of this function only works when the solver was
        built inside `context_compute_gradient`.
        """
        sol = ivpsolve.solve_adaptive_terminal_values(
            init_val, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
        )
        # Any scalar function of the IVP solution would do
        return jnp.dot(sol.u[0], sol.u[0])

    return simulate, init
In [4]:
Copied!
# Without the context manager, probdiffeq uses the default while loop,
# which cannot be reverse-mode differentiated; we expect a ValueError here.
try:
    solve, x = solution_routine()
    solution, gradient = jax.value_and_grad(solve)(x)
except ValueError as err:
    print(f"Caught error:\n\t {err}")
else:
    # Make an unexpected success visible instead of silently printing nothing.
    print("No error was raised: reverse-mode differentiation succeeded.")
Caught error: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
In [5]:
Copied!
with context_compute_gradient:
    # Construct the solution routine inside the context, so that the adaptive
    # solve is traced with the bounded (differentiable) while loop.
    solve, x = solution_routine()
    # Compute value and gradient; the gradient has the same pytree structure
    # as the initial value `x` (here a `Normal` with mean and cholesky).
    solution, gradient = jax.value_and_grad(solve)(x)
    print(solution)
    print(gradient)
0.023939388
Normal(mean=Array([[0.4424412 ],
       [0.01854868]], dtype=float32), cholesky=Array([[0., 0.],
       [0., 0.]], dtype=float32))