C0. Taylor recursions | FitzHugh-Nagumo
In [1]:
Copied!
import functools
from collections.abc import Callable
import functools
from collections.abc import Callable
In [2]:
Copied!
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
In [3]:
Copied!
from probdiffeq import probdiffeq
from probdiffeq.util import benchmark_util
from probdiffeq import probdiffeq
from probdiffeq.util import benchmark_util
In [4]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# A single global update suffices; repeating it has no additional effect.
jax.config.update("jax_debug_nans", True)
In [5]:
Copied!
def main(max_time=0.25, repeats=2) -> None:
    """Benchmark the jet-expansion algorithms and plot their wall times.

    Every algorithm is run through ``benchmark_util.adaptive_benchmark``.
    Compile times and mean evaluation times are then drawn on two
    log-scaled panels that share both axes.

    Args:
        max_time: Time budget forwarded to ``benchmark_util.adaptive_benchmark``.
        repeats: Repetition count forwarded to ``benchmark_util.setup_timeit``.
    """
    # Enable double precision before any estimator below builds its arrays.
    jax.config.update("jax_enable_x64", True)

    algorithms = {
        r"Forward-mode": odejet_via_jvp(),
        r"Taylor-mode (scan)": taylor_mode_scan(),
        r"Taylor-mode (unroll)": taylor_mode_unroll(),
        r"Taylor-mode (doubling)": taylor_mode_doubling(),
    }

    # One timer, shared by all algorithms so that they are measured alike.
    timeit_fun = benchmark_util.setup_timeit(repeats=repeats)

    # Benchmark every algorithm until its time budget is exhausted.
    results = {}
    for label, algo in algorithms.items():
        print("\n")
        print(label)
        results[label] = benchmark_util.adaptive_benchmark(
            algo, timeit_fun=timeit_fun, max_time=max_time
        )

    _fig, (axis_perform, axis_compile) = plt.subplots(
        ncols=2, figsize=(8, 3), dpi=150, sharex=True, sharey=True
    )
    for label, wp in results.items():
        inputs = wp["arguments"]
        work_compile = wp["work_compile"]
        work_mean = wp["work_mean"]

        if "doubling" in label:
            # The doubling runs presumably skip over intermediate values of
            # ``inputs``. Spread each measurement across the skipped range so
            # the curve lines up with the other methods.
            # NOTE(review): assumes ``adaptive_repeat`` repeats entry i
            # ``num_repeats[i]`` times — confirm against benchmark_util.
            num_repeats = jnp.diff(jnp.concatenate((jnp.ones((1,)), inputs)))
            inputs = jnp.arange(1, jnp.amax(inputs))
            work_compile = benchmark_util.adaptive_repeat(work_compile, num_repeats)
            work_mean = benchmark_util.adaptive_repeat(work_mean, num_repeats)

        axis_compile.semilogy(inputs, work_compile, label=label)
        axis_perform.semilogy(inputs, work_mean, label=label)

    axis_compile.set_title("Compilation time")
    axis_perform.set_title("Evaluation time")
    axis_perform.legend(fontsize="small")
    axis_compile.legend(fontsize="small")
    axis_compile.set_xlabel("Number of Derivatives")
    axis_perform.set_xlabel("Number of Derivatives")
    axis_perform.set_ylabel("Wall time (sec)")
    axis_perform.grid(linestyle="dotted")
    axis_compile.grid(linestyle="dotted")
    plt.tight_layout()
    plt.show()
In [6]:
Copied!
def taylor_mode_scan() -> Callable:
    """Build a jitted Taylor-mode estimator based on a padded scan.

    Returns:
        A function ``estimate(num)``. It expands the FitzHugh-Nagumo initial
        value with ``probdiffeq.jetexpand_ode_padded_scan(num=num)`` at
        ``t=0.0`` and returns the resulting coefficients as one array.
        ``num`` is a static jit argument, so every new value of ``num``
        triggers a fresh compilation.
    """
    vf_auto, (u0,) = _fitzhugh_nagumo()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        jetexpand = probdiffeq.jetexpand_ode_padded_scan(num=num)
        tcoeffs, _ = jetexpand(vf_auto, (u0,), t=0.0)
        return jnp.asarray(tcoeffs)

    return estimate
In [7]:
Copied!
def taylor_mode_unroll() -> Callable:
    """Build a jitted Taylor-mode estimator based on an unrolled recursion.

    Returns:
        A function ``estimate(num)``. It expands the FitzHugh-Nagumo initial
        value with ``probdiffeq.jetexpand_ode_unroll(num=num)`` at ``t=0.0``
        and returns the resulting coefficients as one array. ``num`` is a
        static jit argument, so every new value of ``num`` triggers a fresh
        compilation.
    """
    vf_auto, (u0,) = _fitzhugh_nagumo()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        jetexpand = probdiffeq.jetexpand_ode_unroll(num=num)
        tcoeffs, _ = jetexpand(vf_auto, (u0,), t=0.0)
        return jnp.asarray(tcoeffs)

    return estimate
In [8]:
Copied!
def taylor_mode_doubling() -> Callable:
    """Build a jitted Taylor-mode estimator based on the doubling recursion.

    Returns:
        A function ``estimate(num)``. It expands the FitzHugh-Nagumo initial
        value with ``probdiffeq.jetexpand_ode_doubling_unroll`` at ``t=0.0``
        and returns the resulting coefficients as one array. Unlike the other
        estimators, ``num`` is passed as ``num_doublings``, i.e. it counts
        doubling steps rather than coefficients. ``num`` is a static jit
        argument, so every new value of ``num`` triggers a fresh compilation.
    """
    vf_auto, (u0,) = _fitzhugh_nagumo()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        jetexpand = probdiffeq.jetexpand_ode_doubling_unroll(num_doublings=num)
        tcoeffs, _ = jetexpand(vf_auto, (u0,), t=0.0)
        return jnp.asarray(tcoeffs)

    return estimate
In [9]:
Copied!
def odejet_via_jvp() -> Callable:
    """Build a jitted forward-mode (nested JVP) estimator.

    Returns:
        A function ``estimate(num)``. It expands the FitzHugh-Nagumo initial
        value with ``probdiffeq.jetexpand_ode_via_jvp(num=num)`` at ``t=0.0``
        and returns the resulting coefficients as one array. ``num`` is a
        static jit argument, so every new value of ``num`` triggers a fresh
        compilation.
    """
    vf_auto, (u0,) = _fitzhugh_nagumo()

    @functools.partial(jax.jit, static_argnames=["num"])
    def estimate(num):
        jetexpand = probdiffeq.jetexpand_ode_via_jvp(num=num)
        tcoeffs, _ = jetexpand(vf_auto, (u0,), t=0.0)
        return jnp.asarray(tcoeffs)

    return estimate
In [10]:
Copied!
def _fitzhugh_nagumo():
    """Return the FitzHugh-Nagumo vector field and its initial value.

    Returns:
        A pair ``(vf, (u0,))``:
        - ``vf`` is the vector field wrapped by ``probdiffeq.ode``, with
          a = 0.2, b = 0.2, c = 3.0.
        - ``u0`` is the initial state ``[-1.0, 1.0]``.
    """
    u0 = jnp.asarray([-1.0, 1.0])

    @probdiffeq.ode
    def vf_probdiffeq(u, *, t):
        """FitzHugh--Nagumo model."""
        del t  # autonomous system: the right-hand side ignores time
        a, b, c = 0.2, 0.2, 3.0
        du1 = c * (u[0] - u[0] ** 3 / 3 + u[1])
        du2 = -(1.0 / c) * (u[0] - a - b * u[1])
        return jnp.asarray([du1, du2])

    return vf_probdiffeq, (u0,)
In [11]:
Copied!
# Run the benchmark once; calling main() a second time would repeat
# every benchmark and draw the figure twice.
main()
Forward-mode
num = 1 | elapsed = 0.00 | max_time = 0.25
num = 2 | elapsed = 0.08 | max_time = 0.25
num = 3 | elapsed = 0.02 | max_time = 0.25
num = 4 | elapsed = 0.04 | max_time = 0.25
num = 5 | elapsed = 0.07 | max_time = 0.25
num = 6 | elapsed = 0.13 | max_time = 0.25
num = 7 | elapsed = 0.34 | max_time = 0.25

Taylor-mode (scan)
num = 1 | elapsed = 0.00 | max_time = 0.25
num = 2 | elapsed = 0.01 | max_time = 0.25
num = 3 | elapsed = 0.05 | max_time = 0.25
num = 4 | elapsed = 0.09 | max_time = 0.25
num = 5 | elapsed = 0.11 | max_time = 0.25
num = 6 | elapsed = 0.14 | max_time = 0.25
num = 7 | elapsed = 0.17 | max_time = 0.25
num = 8 | elapsed = 0.22 | max_time = 0.25
num = 9 | elapsed = 0.25 | max_time = 0.25

Taylor-mode (unroll)
num = 1 | elapsed = 0.00 | max_time = 0.25
num = 2 | elapsed = 0.02 | max_time = 0.25
num = 3 | elapsed = 0.03 | max_time = 0.25
num = 4 | elapsed = 0.04 | max_time = 0.25
num = 5 | elapsed = 0.07 | max_time = 0.25
num = 6 | elapsed = 0.10 | max_time = 0.25
num = 7 | elapsed = 0.15 | max_time = 0.25
num = 8 | elapsed = 0.28 | max_time = 0.25

Taylor-mode (doubling)
num = 1 | elapsed = 0.00 | max_time = 0.25
num = 2 | elapsed = 0.09 | max_time = 0.25
num = 3 | elapsed = 0.44 | max_time = 0.25