Diffusion tempering & NODEs¶
In [1]:
"""Train a neural ODE with ProbDiffEq and Optax using diffusion tempering."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from probdiffeq import ivpsolve, ivpsolvers, stats
def main(num_data=100, epochs=1_000, print_every=100, hidden=(20,), lr=0.2):
    """Train a neural ODE using diffusion tempering."""
    # Create some data and construct a neural ODE
    grid = jnp.linspace(0, 1, num=num_data)
    data = jnp.sin(2.5 * jnp.pi * grid) * jnp.pi * grid
    stdev = 1e-1
    output_scale = 1e2
    vf, u0, (t0, t1), f_args = vf_neural_ode(hidden=hidden, t0=0.0, t1=1.0)
    # Create a loss (this is where probabilistic numerics enters!)
    loss = loss_log_marginal_likelihood(vf=vf, t0=t0)
    loss0, info0 = loss(
        f_args, u0=u0, grid=grid, data=data, stdev=stdev, output_scale=output_scale
    )
    # Plot the data and the initial guess
    plt.title(f"Initial estimate | Loss: {loss0:.2f}")
    plt.plot(grid, data, "x", label="Data", color="C0")
    plt.plot(grid, info0["sol"].u[0], "-", label="Estimate", color="C1")
    plt.legend()
    plt.show()
    # Construct an optimiser
    optim = optax.adam(lr)
    train_step = train_step_optax(optim, loss=loss)
    # Train the model
    state = optim.init(f_args)
    print("Loss after...")
    for i in range(epochs):
        (f_args, state), info = train_step(
            f_args,
            state,
            u0=u0,
            grid=grid,
            data=data,
            stdev=stdev,
            output_scale=output_scale,
        )
        # Print training progress
        if i % print_every == print_every - 1:
            print(f"...{(i + 1)} epochs: loss={info['loss']:.3e}")
        # Diffusion tempering: https://arxiv.org/abs/2402.12231
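        # Start with a large output scale (a forgiving probabilistic-solver prior)
        # and shrink it periodically, so the neural ODE has to explain the data
        # increasingly precisely as training progresses.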
        # Try adjusting this tempering schedule and
        # see how it affects parameter estimation.
        if i % 100 == 0:
            output_scale /= 10.0
    # Plot the results
    plt.title(f"Final estimate | Loss: {info['loss']:.2f}")
    plt.plot(grid, data, "x", label="Data", color="C0")
    plt.plot(grid, info0["sol"].u[0], "-", label="Initial estimate", color="C1")
    plt.plot(grid, info["sol"].u[0], "-", label="Final estimate", color="C2")
    plt.legend()
    plt.show()


def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
    """Build a neural ODE."""
    f_args, mlp = model_mlp(hidden=hidden, shape_in=(2,), shape_out=(1,))
    u0 = jnp.asarray([0.0])
    @jax.jit
    def vf(y, *, t, p):
        """Evaluate the neural ODE vector field."""
        y_and_t = jnp.concatenate([y, t[None]])
        return mlp(p, y_and_t)
    return vf, (u0,), (t0, t1), f_args


def model_mlp(
    *, hidden: tuple, shape_in: tuple = (), shape_out: tuple = (), activation=jnp.tanh
):
    """Construct an MLP."""
    assert len(shape_in) <= 1
    assert len(shape_out) <= 1
    shape_prev = shape_in
    weights = []
    for h in hidden:
        W = jnp.empty((h, *shape_prev))
        b = jnp.empty((h,))
        shape_prev = (h,)
        weights.append((W, b))
    W = jnp.empty((*shape_out, *shape_prev))
    b = jnp.empty(shape_out)
    weights.append((W, b))
    p_flat, unravel = jax.flatten_util.ravel_pytree(weights)
    def fwd(w, x):
        for A, b in w[:-1]:
            x = jnp.dot(A, x) + b
            x = activation(x)
        A, b = w[-1]
        return jnp.dot(A, x) + b
    key = jax.random.PRNGKey(1)
    p_init = jax.random.normal(key, shape=p_flat.shape, dtype=p_flat.dtype)
    return unravel(p_init), fwd


def loss_log_marginal_likelihood(vf, *, t0):
    """Build a loss function from an ODE problem."""
    @jax.jit
    def loss(
        p: jax.Array,
        *,
        u0: tuple,
        grid: jax.Array,
        data: jax.Array,
        stdev: jax.Array,
        output_scale: jax.Array,
    ):
        """Loss function: log-marginal likelihood of the data."""
        # Build a solver
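        # The Taylor coefficients of the initial condition (the state and its
        # first derivative, evaluated via the vector field) set up the
        # integrated-Wiener-process prior.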
        tcoeffs = (*u0, vf(*u0, t=t0, p=p))
        init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
            tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
        )
        ts0 = ivpsolvers.correction_ts0(lambda *a, **kw: vf(*a, **kw, p=p), ssm=ssm)
        strategy = ivpsolvers.strategy_smoother(ssm=ssm)
        solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
        # Solve
        sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver_ts0, ssm=ssm)
        # Evaluate loss
        marginal_likelihood = stats.log_marginal_likelihood(
            data[:, None],
            standard_deviation=jnp.ones_like(grid) * stdev,
            posterior=sol.posterior,
            ssm=sol.ssm,
        )
        return -1 * marginal_likelihood, {"sol": sol}
    return loss


def train_step_optax(optimizer, loss):
    """Implement a training step using Optax."""
    @jax.jit
    def update(params, opt_state, **loss_kwargs):
        """Update the optimiser state."""
        value_and_grad = jax.value_and_grad(loss, argnums=0, has_aux=True)
        (value, info), grads = value_and_grad(params, **loss_kwargs)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        info["loss"] = value
        return (params, opt_state), info
    return update


if __name__ == "__main__":
    main()
"""Train a neural ODE with ProbDiffEq and Optax using diffusion tempering."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from probdiffeq import ivpsolve, ivpsolvers, stats
def main(num_data=100, epochs=1_000, print_every=100, hidden=(20,), lr=0.2):
    """Train a neural ODE using diffusion tempering."""
    # Create some data and construct a neural ODE
    grid = jnp.linspace(0, 1, num=num_data)
    data = jnp.sin(2.5 * jnp.pi * grid) * jnp.pi * grid
    stdev = 1e-1
    output_scale = 1e2
    vf, u0, (t0, t1), f_args = vf_neural_ode(hidden=hidden, t0=0.0, t1=1)
    # Create a loss (this is where probabilistic numerics enters!)
    loss = loss_log_marginal_likelihood(vf=vf, t0=t0)
    loss0, info0 = loss(
        f_args, u0=u0, grid=grid, data=data, stdev=stdev, output_scale=output_scale
    )
    # Plot the data and the initial guess
    plt.title(f"Initial estimate | Loss: {loss0:.2f}")
    plt.plot(grid, data, "x", label="Data", color="C0")
    plt.plot(grid, info0["sol"].u[0], "-", label="Estimate", color="C1")
    plt.legend()
    plt.show()
    # Construct an optimiser
    optim = optax.adam(lr)
    train_step = train_step_optax(optim, loss=loss)
    # Train the model
    state = optim.init(f_args)
    print("Loss after...")
    for i in range(epochs):
        (f_args, state), info = train_step(
            f_args,
            state,
            u0=u0,
            grid=grid,
            data=data,
            stdev=stdev,
            output_scale=output_scale,
        )
        # Print progressbar
        if i % print_every == print_every - 1:
            print(f"...{(i + 1)} epochs: loss={info['loss']:.3e}")
        # Diffusion tempering: https://arxiv.org/abs/2402.12231
        # To all users: Adjust this tempering and
        # see how it affects parameter estimation.
        if i % 100 == 0:
            output_scale /= 10.0
    # Plot the results
    plt.title(f"Final estimate | Loss: {info['loss']:.2f}")
    plt.plot(grid, data, "x", label="Data", color="C0")
    plt.plot(grid, info0["sol"].u[0], "-", label="Initial estimate", color="C1")
    plt.plot(grid, info["sol"].u[0], "-", label="Final estimate", color="C2")
    plt.legend()
    plt.show()
def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
    """Build a neural ODE."""
    f_args, mlp = model_mlp(hidden=hidden, shape_in=(2,), shape_out=(1,))
    u0 = jnp.asarray([0.0])
    @jax.jit
    def vf(y, *, t, p):
        """Evaluate the neural ODE vector field."""
        y_and_t = jnp.concatenate([y, t[None]])
        return mlp(p, y_and_t)
    return vf, (u0,), (t0, t1), f_args
def model_mlp(
    *, hidden: tuple, shape_in: tuple = (), shape_out: tuple = (), activation=jnp.tanh
):
    """Construct an MLP."""
    assert len(shape_in) <= 1
    assert len(shape_out) <= 1
    shape_prev = shape_in
    weights = []
    for h in hidden:
        W = jnp.empty((h, *shape_prev))
        b = jnp.empty((h,))
        shape_prev = (h,)
        weights.append((W, b))
    W = jnp.empty((*shape_out, *shape_prev))
    b = jnp.empty(shape_out)
    weights.append((W, b))
    p_flat, unravel = jax.flatten_util.ravel_pytree(weights)
    def fwd(w, x):
        for A, b in w[:-1]:
            x = jnp.dot(A, x) + b
            x = activation(x)
        A, b = w[-1]
        return jnp.dot(A, x) + b
    key = jax.random.PRNGKey(1)
    p_init = jax.random.normal(key, shape=p_flat.shape, dtype=p_flat.dtype)
    return unravel(p_init), fwd
def loss_log_marginal_likelihood(vf, *, t0):
    """Build a loss function from an ODE problem."""
    @jax.jit
    def loss(
        p: jax.Array,
        *,
        u0: tuple,
        grid: jax.Array,
        data: jax.Array,
        stdev: jax.Array,
        output_scale: jax.Array,
    ):
        """Loss function: log-marginal likelihood of the data."""
        # Build a solver
        tcoeffs = (*u0, vf(*u0, t=t0, p=p))
        init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
            tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
        )
        ts0 = ivpsolvers.correction_ts0(lambda *a, **kw: vf(*a, **kw, p=p), ssm=ssm)
        strategy = ivpsolvers.strategy_smoother(ssm=ssm)
        solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
        # Solve
        sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver_ts0, ssm=ssm)
        # Evaluate loss
        marginal_likelihood = stats.log_marginal_likelihood(
            data[:, None],
            standard_deviation=jnp.ones_like(grid) * stdev,
            posterior=sol.posterior,
            ssm=sol.ssm,
        )
        return -1 * marginal_likelihood, {"sol": sol}
    return loss
def train_step_optax(optimizer, loss):
    """Implement a training step using Optax."""
    @jax.jit
    def update(params, opt_state, **loss_kwargs):
        """Update the optimiser state."""
        value_and_grad = jax.value_and_grad(loss, argnums=0, has_aux=True)
        (value, info), grads = value_and_grad(params, **loss_kwargs)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        info["loss"] = value
        return (params, opt_state), info
    return update
if __name__ == "__main__":
    main()
Loss after...
...100 epochs: loss=2.420e+01
...200 epochs: loss=1.774e+01
...300 epochs: loss=2.794e+00
...400 epochs: loss=4.559e+01
...500 epochs: loss=1.050e+01
...600 epochs: loss=1.653e-01
...700 epochs: loss=-1.223e+00
...800 epochs: loss=-1.297e+00
...900 epochs: loss=-1.335e+00
...1000 epochs: loss=-1.248e+00
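
The comments in the training loop invite experimenting with the tempering schedule. As a minimal sketch of one variation (the helper temper and its arguments every, factor, and floor are hypothetical names, not part of the example above), the fixed divide-by-ten step could be replaced by a geometric decay that never pushes the output scale below 1.0:

def temper(output_scale, epoch, *, every=100, factor=10.0, floor=1.0):
    """Shrink the output scale every `every` epochs, but never below `floor`."""
    if epoch % every == 0:
        return max(output_scale / factor, floor)
    return output_scale

Inside the loop, output_scale = temper(output_scale, i) would replace the two-line schedule: a factor close to 1 keeps the prior loose for longer, while a large factor quickly recovers the untempered objective.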