Simulate second-order systems
In [1]:
Copied!
"""Demonstrate how to solve second-order IVPs without transforming them first."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
"""Demonstrate how to solve second-order IVPs without transforming them first."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
In [2]:
Copied!
# Apply the plot style first, then the plot sizes, both provided by
# probdiffeq's doc utilities (``notebook``).
for make_rc in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(make_rc())
# Apply the plot style first, then the plot sizes, both provided by
# probdiffeq's doc utilities (``notebook``).
for make_rc in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(make_rc())
In [3]:
Copied!
# Only choose a diffeqzoo backend if none has been selected so far.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Set JAX's platform name to "cpu".
jax.config.update("jax_platform_name", "cpu")
# Only choose a diffeqzoo backend if none has been selected so far.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Set JAX's platform name to "cpu".
jax.config.update("jax_platform_name", "cpu")
Quick refresher: first-order ODEs
In [4]:
Copied!
# Configure the implementation for an ODE of shape (4,) and load the
# restricted three-body problem in its first-order form.
impl.select("isotropic", ode_shape=(4,))
f, u0, (t0, t1), f_args = ivps.three_body_restricted_first_order()


@jax.jit
def vf_1(y, t):  # noqa: ARG001
    """Return the first-order three-body vector field evaluated at ``y``.

    The time argument ``t`` is accepted but not used in the evaluation.
    """
    return f(y, *f_args)


# Build the solver step by step: prior and TS0 correction are combined into a
# filter strategy, which is passed to solver_mle and then wrapped by
# ivpsolve.adaptive with absolute/relative tolerances of 1e-5.
ibm = ivpsolvers.prior_ibm(num_derivatives=4)
ts0 = ivpsolvers.correction_ts0()
strategy_1st = ivpsolvers.strategy_filter(ibm, ts0)
solver_1st = ivpsolvers.solver_mle(strategy_1st)
adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5)


def vf_1_at_t0(y):
    """Evaluate ``vf_1`` with the time fixed to ``t0``."""
    return vf_1(y, t=t0)


# Taylor coefficients at the initial value; these initialise the solver state.
tcoeffs = taylor.odejet_padded_scan(vf_1_at_t0, (u0,), num=4)
init = solver_1st.initial_condition(tcoeffs, output_scale=1.0)
# Configure the implementation for an ODE of shape (4,) and load the
# restricted three-body problem in its first-order form.
impl.select("isotropic", ode_shape=(4,))
f, u0, (t0, t1), f_args = ivps.three_body_restricted_first_order()


@jax.jit
def vf_1(y, t):  # noqa: ARG001
    """Return the first-order three-body vector field evaluated at ``y``.

    The time argument ``t`` is accepted but not used in the evaluation.
    """
    return f(y, *f_args)


# Build the solver step by step: prior and TS0 correction are combined into a
# filter strategy, which is passed to solver_mle and then wrapped by
# ivpsolve.adaptive with absolute/relative tolerances of 1e-5.
ibm = ivpsolvers.prior_ibm(num_derivatives=4)
ts0 = ivpsolvers.correction_ts0()
strategy_1st = ivpsolvers.strategy_filter(ibm, ts0)
solver_1st = ivpsolvers.solver_mle(strategy_1st)
adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5)


def vf_1_at_t0(y):
    """Evaluate ``vf_1`` with the time fixed to ``t0``."""
    return vf_1(y, t=t0)


# Taylor coefficients at the initial value; these initialise the solver state.
tcoeffs = taylor.odejet_padded_scan(vf_1_at_t0, (u0,), num=4)
init = solver_1st.initial_condition(tcoeffs, output_scale=1.0)
In [5]:
Copied!
# Solve the first-order problem on [t0, t1] with the adaptive first-order solver.
solution = ivpsolve.solve_adaptive_save_every_step(
    vf_1,
    init,
    t0=t0,
    t1=t1,
    dt0=0.1,
    adaptive_solver=adaptive_solver_1st,
)
# Solve the first-order problem on [t0, t1] with the adaptive first-order solver.
solution = ivpsolve.solve_adaptive_save_every_step(
    vf_1,
    init,
    t0=t0,
    t1=t1,
    dt0=0.1,
    adaptive_solver=adaptive_solver_1st,
)
In [6]:
Copied!
# Relative deviation of the final state from the initial value u0.
# NOTE(review): presumably the orbit is periodic, so this doubles as an
# error estimate -- confirm.
final_state = solution.u[-1, ...]
norm = jnp.linalg.norm((final_state - u0) / jnp.abs(1.0 + u0))

fig, ax = plt.subplots()
# Format the title explicitly: passing the raw tuple would render the repr of
# the array (including its dtype) in the figure.
ax.set_title(f"shape={solution.u.shape}, rel. deviation={float(norm):.2e}")
ax.plot(solution.u[:, 0], solution.u[:, 1], marker=".")
# Label the axes with the plotted solution components.
ax.set_xlabel("u[0]")
ax.set_ylabel("u[1]")
plt.show()
# Relative deviation of the final state from the initial value u0.
# NOTE(review): presumably the orbit is periodic, so this doubles as an
# error estimate -- confirm.
final_state = solution.u[-1, ...]
norm = jnp.linalg.norm((final_state - u0) / jnp.abs(1.0 + u0))

fig, ax = plt.subplots()
# Format the title explicitly: passing the raw tuple would render the repr of
# the array (including its dtype) in the figure.
ax.set_title(f"shape={solution.u.shape}, rel. deviation={float(norm):.2e}")
ax.plot(solution.u[:, 0], solution.u[:, 1], marker=".")
# Label the axes with the plotted solution components.
ax.set_xlabel("u[0]")
ax.set_ylabel("u[1]")
plt.show()
The default configuration assumes that the ODE to be solved is of first order. Now, let's play the same game with a second-order ODE.
In [7]:
Copied!
# Configure the implementation for an ODE of shape (2,) and load the
# restricted three-body problem in its second-order form, which comes with an
# initial value u0 and an initial derivative du0.
impl.select("isotropic", ode_shape=(2,))
f, (u0, du0), (t0, t1), f_args = ivps.three_body_restricted()


@jax.jit
def vf_2(y, dy, t):  # noqa: ARG001
    """Return the second-order three-body vector field evaluated at ``(y, dy)``.

    The time argument ``t`` is accepted but not used in the evaluation.
    """
    return f(y, dy, *f_args)


# Same num_derivatives as in the first-order example. The differences are the
# correction (ode_order=2) and the Taylor-coefficient call below, which starts
# from two initial values (u0, du0) and uses num=3 -- presumably so that the
# total number of coefficients stays the same (TODO confirm).
ibm = ivpsolvers.prior_ibm(num_derivatives=4)
ts0 = ivpsolvers.correction_ts0(ode_order=2)
strategy_2nd = ivpsolvers.strategy_filter(ibm, ts0)
solver_2nd = ivpsolvers.solver_mle(strategy_2nd)
adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5)


def vf_2_at_t0(*ys):
    """Evaluate ``vf_2`` with the time fixed to ``t0``."""
    return vf_2(*ys, t=t0)


# Taylor coefficients at the initial values; these initialise the solver state.
tcoeffs = taylor.odejet_padded_scan(vf_2_at_t0, (u0, du0), num=3)
init = solver_2nd.initial_condition(tcoeffs, output_scale=1.0)
# Configure the implementation for an ODE of shape (2,) and load the
# restricted three-body problem in its second-order form, which comes with an
# initial value u0 and an initial derivative du0.
impl.select("isotropic", ode_shape=(2,))
f, (u0, du0), (t0, t1), f_args = ivps.three_body_restricted()


@jax.jit
def vf_2(y, dy, t):  # noqa: ARG001
    """Return the second-order three-body vector field evaluated at ``(y, dy)``.

    The time argument ``t`` is accepted but not used in the evaluation.
    """
    return f(y, dy, *f_args)


# Same num_derivatives as in the first-order example. The differences are the
# correction (ode_order=2) and the Taylor-coefficient call below, which starts
# from two initial values (u0, du0) and uses num=3 -- presumably so that the
# total number of coefficients stays the same (TODO confirm).
ibm = ivpsolvers.prior_ibm(num_derivatives=4)
ts0 = ivpsolvers.correction_ts0(ode_order=2)
strategy_2nd = ivpsolvers.strategy_filter(ibm, ts0)
solver_2nd = ivpsolvers.solver_mle(strategy_2nd)
adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5)


def vf_2_at_t0(*ys):
    """Evaluate ``vf_2`` with the time fixed to ``t0``."""
    return vf_2(*ys, t=t0)


# Taylor coefficients at the initial values; these initialise the solver state.
tcoeffs = taylor.odejet_padded_scan(vf_2_at_t0, (u0, du0), num=3)
init = solver_2nd.initial_condition(tcoeffs, output_scale=1.0)
In [8]:
Copied!
# Solve the second-order problem on [t0, t1] with the adaptive second-order solver.
solution = ivpsolve.solve_adaptive_save_every_step(
    vf_2,
    init,
    t0=t0,
    t1=t1,
    dt0=0.1,
    adaptive_solver=adaptive_solver_2nd,
)
# Solve the second-order problem on [t0, t1] with the adaptive second-order solver.
solution = ivpsolve.solve_adaptive_save_every_step(
    vf_2,
    init,
    t0=t0,
    t1=t1,
    dt0=0.1,
    adaptive_solver=adaptive_solver_2nd,
)
In [9]:
Copied!
# Relative deviation of the final state from the initial value u0.
# NOTE(review): presumably the orbit is periodic, so this doubles as an
# error estimate -- confirm.
final_state = solution.u[-1, ...]
norm = jnp.linalg.norm((final_state - u0) / jnp.abs(1.0 + u0))

fig, ax = plt.subplots()
# Format the title explicitly: passing the raw tuple would render the repr of
# the array (including its dtype) in the figure.
ax.set_title(f"shape={solution.u.shape}, rel. deviation={float(norm):.2e}")
ax.plot(solution.u[:, 0], solution.u[:, 1], marker=".")
# Label the axes with the plotted solution components.
ax.set_xlabel("u[0]")
ax.set_ylabel("u[1]")
plt.show()
# Relative deviation of the final state from the initial value u0.
# NOTE(review): presumably the orbit is periodic, so this doubles as an
# error estimate -- confirm.
final_state = solution.u[-1, ...]
norm = jnp.linalg.norm((final_state - u0) / jnp.abs(1.0 + u0))

fig, ax = plt.subplots()
# Format the title explicitly: passing the raw tuple would render the repr of
# the array (including its dtype) in the figure.
ax.set_title(f"shape={solution.u.shape}, rel. deviation={float(norm):.2e}")
ax.plot(solution.u[:, 0], solution.u[:, 1], marker=".")
# Label the axes with the plotted solution components.
ax.set_xlabel("u[0]")
ax.set_ylabel("u[1]")
plt.show()
The two solutions are visually indistinguishable in the plots. While the runtimes of both solvers are similar, the error of the second-order solver is much lower.
See the benchmarks for more quantitative versions of this statement.