Taylor-series: Pleiades
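The Pleiades problem is a common non-stiff benchmark differential equation. This notebook compares how the compilation time and the evaluation time of several Taylor-series estimators grow with the number of derivatives on this problem.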
In [1]:
Copied!
"""Benchmark all Taylor-series estimators on the Pleiades problem."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from probdiffeq.util.doc_util import notebook

# Pin JAX to the CPU backend for everything executed in this notebook.
jax.config.update("jax_platform_name", "cpu")
"""Benchmark all Taylor-series estimators on the Pleiades problem."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from probdiffeq.util.doc_util import notebook

# Pin JAX to the CPU backend for everything executed in this notebook.
jax.config.update("jax_platform_name", "cpu")
In [2]:
Copied!
def load_results(path="./results.npy"):
    """Load the benchmark results from a ``.npy`` file.

    Parameters
    ----------
    path
        Location of the results file. Defaults to ``"./results.npy"``,
        which is resolved relative to the current working directory.

    Returns
    -------
    The object stored in the file. It is read with ``allow_pickle=True``
    and unwrapped with ``[()]``, so the file presumably holds a single
    pickled object saved as a 0-d array (TODO confirm against the script
    that writes it). ``plot_results`` consumes it as a mapping from
    algorithm label to a dict of timing arrays.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code if the file
    # is untrusted; only load result files produced by this benchmark.
    return jnp.load(path, allow_pickle=True)[()]
def choose_style(label):
    """Return matplotlib line keyword arguments for an algorithm label.

    The label is matched case-insensitively against a fixed list of
    keywords. The first keyword contained in the label decides the colour
    and line style. The label itself is passed through so that the line
    can appear in a legend.

    Raises
    ------
    ValueError
        If the label contains none of the known keywords.
    """
    lowered = label.lower()
    # Order matters: the first matching keyword wins, so e.g. a label that
    # contains both "taylor" and "doubling" resolves to "doubling".
    keyword_styles = (
        ("doubling", "C3", "dotted"),
        ("unroll", "C2", "dashdot"),
        ("taylor", "C0", "solid"),
        ("forward", "C1", "dashed"),
    )
    for keyword, color, linestyle in keyword_styles:
        if keyword in lowered:
            return {"color": color, "linestyle": linestyle, "label": label}
    msg = f"Label {label} unknown."
    raise ValueError(msg)
def plot_results(axis_compile, axis_perform, results):
    """Plot compilation and evaluation times for every benchmarked algorithm.

    Parameters
    ----------
    axis_compile
        Axis that receives one log-scale compilation-time curve per algorithm.
    axis_perform
        Axis that receives one log-scale mean evaluation-time curve per
        algorithm, shaded between mean minus and mean plus one standard
        deviation.
    results
        Mapping from algorithm label to a dict with the keys
        ``"arguments"`` (x-values), ``"work_compile"``, ``"work_mean"``
        and ``"work_std"``.

    Returns
    -------
    The two axes, in the order ``(axis_compile, axis_perform)``.
    """
    for label, wp in results.items():
        style = choose_style(label)
        # The shaded band re-uses the line's colour and line style but not its
        # label; otherwise each algorithm would be listed twice in a legend
        # drawn on ``axis_perform``.
        band_style = {key: value for key, value in style.items() if key != "label"}
        inputs = wp["arguments"]

        work_compile = wp["work_compile"]
        axis_compile.semilogy(inputs, work_compile, **style)

        work_mean, work_std = wp["work_mean"], wp["work_std"]
        range_lower, range_upper = work_mean - work_std, work_mean + work_std
        axis_perform.semilogy(inputs, work_mean, **style)
        axis_perform.fill_between(
            inputs, range_lower, range_upper, alpha=0.3, **band_style
        )
    return axis_compile, axis_perform
def load_results(path="./results.npy"):
    """Load the benchmark results from a ``.npy`` file.

    Parameters
    ----------
    path
        Location of the results file. Defaults to ``"./results.npy"``,
        which is resolved relative to the current working directory.

    Returns
    -------
    The object stored in the file. It is read with ``allow_pickle=True``
    and unwrapped with ``[()]``, so the file presumably holds a single
    pickled object saved as a 0-d array (TODO confirm against the script
    that writes it). ``plot_results`` consumes it as a mapping from
    algorithm label to a dict of timing arrays.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code if the file
    # is untrusted; only load result files produced by this benchmark.
    return jnp.load(path, allow_pickle=True)[()]
def choose_style(label):
    """Return matplotlib line keyword arguments for an algorithm label.

    The label is matched case-insensitively against a fixed list of
    keywords. The first keyword contained in the label decides the colour
    and line style. The label itself is passed through so that the line
    can appear in a legend.

    Raises
    ------
    ValueError
        If the label contains none of the known keywords.
    """
    lowered = label.lower()
    # Order matters: the first matching keyword wins, so e.g. a label that
    # contains both "taylor" and "doubling" resolves to "doubling".
    keyword_styles = (
        ("doubling", "C3", "dotted"),
        ("unroll", "C2", "dashdot"),
        ("taylor", "C0", "solid"),
        ("forward", "C1", "dashed"),
    )
    for keyword, color, linestyle in keyword_styles:
        if keyword in lowered:
            return {"color": color, "linestyle": linestyle, "label": label}
    msg = f"Label {label} unknown."
    raise ValueError(msg)
def plot_results(axis_compile, axis_perform, results):
    """Plot compilation and evaluation times for every benchmarked algorithm.

    Parameters
    ----------
    axis_compile
        Axis that receives one log-scale compilation-time curve per algorithm.
    axis_perform
        Axis that receives one log-scale mean evaluation-time curve per
        algorithm, shaded between mean minus and mean plus one standard
        deviation.
    results
        Mapping from algorithm label to a dict with the keys
        ``"arguments"`` (x-values), ``"work_compile"``, ``"work_mean"``
        and ``"work_std"``.

    Returns
    -------
    The two axes, in the order ``(axis_compile, axis_perform)``.
    """
    for label, wp in results.items():
        style = choose_style(label)
        # The shaded band re-uses the line's colour and line style but not its
        # label; otherwise each algorithm would be listed twice in a legend
        # drawn on ``axis_perform``.
        band_style = {key: value for key, value in style.items() if key != "label"}
        inputs = wp["arguments"]

        work_compile = wp["work_compile"]
        axis_compile.semilogy(inputs, work_compile, **style)

        work_mean, work_std = wp["work_mean"], wp["work_std"]
        range_lower, range_upper = work_mean - work_std, work_mean + work_std
        axis_perform.semilogy(inputs, work_mean, **style)
        axis_perform.fill_between(
            inputs, range_lower, range_upper, alpha=0.3, **band_style
        )
    return axis_compile, axis_perform
In [3]:
Copied!
# Apply the shared documentation plot style first, then the figure sizes.
for make_rc_overrides in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(make_rc_overrides())
# Apply the shared documentation plot style first, then the figure sizes.
for make_rc_overrides in (notebook.plot_style, notebook.plot_sizes):
    plt.rcParams.update(make_rc_overrides())
In [4]:
Copied!
# Left panel: evaluation time; right panel: compilation time.
fig, (axis_perform, axis_compile) = plt.subplots(
    ncols=2, dpi=150, sharex=True, figsize=(8, 3)
)
results = load_results()
axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

axis_compile.set(title="Compilation time", xlabel="Number of Derivatives")
axis_perform.set(
    title="Evaluation time",
    xlabel="Number of Derivatives",
    ylabel="Wall time (sec)",
)
# A single legend (on the compilation panel) identifies every algorithm.
axis_compile.legend()
for axis in (axis_perform, axis_compile):
    axis.grid()
# NOTE(review): hand-picked ticks, presumably matched to the stored timings.
axis_perform.set_yticks((1e-5, 1e-4))
plt.show()
# Left panel: evaluation time; right panel: compilation time.
fig, (axis_perform, axis_compile) = plt.subplots(
    ncols=2, dpi=150, sharex=True, figsize=(8, 3)
)
results = load_results()
axis_compile, axis_perform = plot_results(axis_compile, axis_perform, results)

axis_compile.set(title="Compilation time", xlabel="Number of Derivatives")
axis_perform.set(
    title="Evaluation time",
    xlabel="Number of Derivatives",
    ylabel="Wall time (sec)",
)
# A single legend (on the compilation panel) identifies every algorithm.
axis_compile.legend()
for axis in (axis_perform, axis_compile):
    axis.grid()
# NOTE(review): hand-picked ticks, presumably matched to the stored timings.
axis_perform.set_yticks((1e-5, 1e-4))
plt.show()
In [ ]:
Copied!