D0. Convergence rates | Lotka-Volterra
In [1]:
Copied!
import statistics
from collections.abc import Callable
import statistics
from collections.abc import Callable
In [2]:
Copied!
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.integrate
import tqdm
import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.integrate
import tqdm
In [3]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq.util import benchmark_util
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq.util import benchmark_util
In [4]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI).
# This switches on JAX's "jax_debug_nans" configuration flag for the session.
jax.config.update("jax_debug_nans", True)
# Fail this notebook on NaN detection (to catch those in the CI).
# This switches on JAX's "jax_debug_nans" configuration flag for the session.
jax.config.update("jax_debug_nans", True)
In [5]:
Copied!
def main() -> None:
    """Run the convergence-rate experiment and show the resulting figure.

    One ProbDiffEq solver is built for every derivative count from 1 to 17.
    Each solver is evaluated over a grid of tolerances, its terminal value
    is compared against a SciPy LSODA reference, and the number of steps is
    recorded. Two log-log panels are drawn: the (lightly smoothed) measured
    error curves and the straight-line fits whose slopes are reported as
    rates in the legend.
    """
    # High order solvers need double precision
    jax.config.update("jax_enable_x64", True)

    # One solver per derivative count, keyed by the label used in the legend
    algorithms = {f"TS1({n})": solver_probdiffeq(n) for n in range(1, 18)}

    # Benchmark ingredients: reference solution, tolerance grid, and the
    # helpers that measure precision and run time.
    # NOTE(review): the meaning of start/stop/step is defined by
    # benchmark_util.setup_tolerances, which is not visible here.
    reference = solver_scipy(method="LSODA")(1e-12)
    tolerances = benchmark_util.setup_tolerances(start=2, stop=8, step=0.5)
    precision_fun = benchmark_util.rmse_relative(reference)
    timeit_fun = benchmark_util.setup_timeit(repeats=1)

    # Work-precision data for every solver (with a progress bar)
    results = {}
    for name, solver in tqdm.tqdm(algorithms.items()):
        tolerances_to_wp = workprec(
            solver, precision_fun=precision_fun, timeit_fun=timeit_fun
        )
        results[name] = tolerances_to_wp(tolerances)

    # Two side-by-side panels that share both axes
    _fig, axes = plt.subplot_mosaic(
        [["values", "trends"]],
        figsize=(8, 3),
        constrained_layout=True,
        dpi=120,
        sharex=True,
        sharey=True,
    )
    ax_values = axes["values"]
    ax_trends = axes["trends"]

    # The colormap and the curve count do not change between iterations
    palette = mpl.colormaps["managua"]
    num_curves = len(results)

    for idx, (name, wp) in enumerate(results.items()):
        color = mpl.colors.to_hex(palette(idx / num_curves))
        num_steps = wp["work_num_steps"]
        errors = wp["precision"]

        # Smoothed measurements and the linear fit in log-log space
        x, y = smooth(num_steps, errors)
        (x_lin, y_lin), (scale, _) = linear_trend(num_steps, errors)

        # Normalise so the smallest x and the largest y of each curve are 1
        ax_values.loglog(x / x.min(), y / y.max(), color=color, label=name)
        ax_trends.loglog(
            x_lin / x_lin.min(),
            y_lin / y_lin.max(),
            color=color,
            label=f"Rate: {scale:.1f}",
        )

    ax_values.set_title("Values (slightly smoothed)")
    ax_trends.set_title("Decay rates (linear fit)")
    ax_values.set_ylabel("RMSE (normalised)")

    for panel in (ax_values, ax_trends):
        panel.grid(which="minor", linestyle="dotted")
        panel.set_xlabel("Num steps (normalised)")
        panel.legend(fontsize="x-small", ncols=2)

    plt.show()
def main() -> None:
    """Run the convergence-rate experiment and show the resulting figure.

    One ProbDiffEq solver is built for every derivative count from 1 to 17.
    Each solver is evaluated over a grid of tolerances, its terminal value
    is compared against a SciPy LSODA reference, and the number of steps is
    recorded. Two log-log panels are drawn: the (lightly smoothed) measured
    error curves and the straight-line fits whose slopes are reported as
    rates in the legend.
    """
    # High order solvers need double precision
    jax.config.update("jax_enable_x64", True)

    # One solver per derivative count, keyed by the label used in the legend
    algorithms = {f"TS1({n})": solver_probdiffeq(n) for n in range(1, 18)}

    # Benchmark ingredients: reference solution, tolerance grid, and the
    # helpers that measure precision and run time.
    # NOTE(review): the meaning of start/stop/step is defined by
    # benchmark_util.setup_tolerances, which is not visible here.
    reference = solver_scipy(method="LSODA")(1e-12)
    tolerances = benchmark_util.setup_tolerances(start=2, stop=8, step=0.5)
    precision_fun = benchmark_util.rmse_relative(reference)
    timeit_fun = benchmark_util.setup_timeit(repeats=1)

    # Work-precision data for every solver (with a progress bar)
    results = {}
    for name, solver in tqdm.tqdm(algorithms.items()):
        tolerances_to_wp = workprec(
            solver, precision_fun=precision_fun, timeit_fun=timeit_fun
        )
        results[name] = tolerances_to_wp(tolerances)

    # Two side-by-side panels that share both axes
    _fig, axes = plt.subplot_mosaic(
        [["values", "trends"]],
        figsize=(8, 3),
        constrained_layout=True,
        dpi=120,
        sharex=True,
        sharey=True,
    )
    ax_values = axes["values"]
    ax_trends = axes["trends"]

    # The colormap and the curve count do not change between iterations
    palette = mpl.colormaps["managua"]
    num_curves = len(results)

    for idx, (name, wp) in enumerate(results.items()):
        color = mpl.colors.to_hex(palette(idx / num_curves))
        num_steps = wp["work_num_steps"]
        errors = wp["precision"]

        # Smoothed measurements and the linear fit in log-log space
        x, y = smooth(num_steps, errors)
        (x_lin, y_lin), (scale, _) = linear_trend(num_steps, errors)

        # Normalise so the smallest x and the largest y of each curve are 1
        ax_values.loglog(x / x.min(), y / y.max(), color=color, label=name)
        ax_trends.loglog(
            x_lin / x_lin.min(),
            y_lin / y_lin.max(),
            color=color,
            label=f"Rate: {scale:.1f}",
        )

    ax_values.set_title("Values (slightly smoothed)")
    ax_trends.set_title("Decay rates (linear fit)")
    ax_values.set_ylabel("RMSE (normalised)")

    for panel in (ax_values, ax_trends):
        panel.grid(which="minor", linestyle="dotted")
        panel.set_xlabel("Num steps (normalised)")
        panel.legend(fontsize="x-small", ncols=2)

    plt.show()
In [6]:
Copied!
def solver_scipy(*, method: str) -> Callable:
    """Construct a solver that wraps SciPy's solution routines.

    Parameters
    ----------
    method
        Name of the integration method, forwarded unchanged to
        ``scipy.integrate.solve_ivp`` (for example ``"LSODA"``).

    Returns
    -------
    Callable
        A function ``tol -> y(t1)`` that integrates the Lotka--Volterra
        problem from ``t=0`` to ``t=50`` starting at ``(20, 20)`` and returns
        the state at the final time as a JAX array. ``tol`` is used as the
        relative tolerance; the absolute tolerance is ``1e-3 * tol``.
    """

    def vf_scipy(_t, y):
        """Lotka--Volterra dynamics."""
        dy1 = 0.5 * y[0] - 0.05 * y[0] * y[1]
        dy2 = -0.5 * y[1] + 0.05 * y[0] * y[1]
        return np.asarray([dy1, dy2])

    # Everything handed to SciPy is a NumPy array; only the final result is
    # converted to JAX. (The initial state used to be a JAX array, which SciPy
    # had to convert back on every solve.)
    u0 = np.asarray((20.0, 20.0))
    time_span = np.asarray([0.0, 50.0])

    def param_to_solution(tol):
        solution = scipy.integrate.solve_ivp(
            vf_scipy,
            y0=u0,
            t_span=time_span,
            t_eval=time_span,
            atol=1e-3 * tol,
            rtol=tol,
            method=method,
        )
        # t_eval holds both end points; the last column is the terminal state
        return jnp.asarray(solution.y[..., -1])

    return param_to_solution
def solver_scipy(*, method: str) -> Callable:
    """Construct a solver that wraps SciPy's solution routines.

    Parameters
    ----------
    method
        Name of the integration method, forwarded unchanged to
        ``scipy.integrate.solve_ivp`` (for example ``"LSODA"``).

    Returns
    -------
    Callable
        A function ``tol -> y(t1)`` that integrates the Lotka--Volterra
        problem from ``t=0`` to ``t=50`` starting at ``(20, 20)`` and returns
        the state at the final time as a JAX array. ``tol`` is used as the
        relative tolerance; the absolute tolerance is ``1e-3 * tol``.
    """

    def vf_scipy(_t, y):
        """Lotka--Volterra dynamics."""
        dy1 = 0.5 * y[0] - 0.05 * y[0] * y[1]
        dy2 = -0.5 * y[1] + 0.05 * y[0] * y[1]
        return np.asarray([dy1, dy2])

    # Everything handed to SciPy is a NumPy array; only the final result is
    # converted to JAX. (The initial state used to be a JAX array, which SciPy
    # had to convert back on every solve.)
    u0 = np.asarray((20.0, 20.0))
    time_span = np.asarray([0.0, 50.0])

    def param_to_solution(tol):
        solution = scipy.integrate.solve_ivp(
            vf_scipy,
            y0=u0,
            t_span=time_span,
            t_eval=time_span,
            atol=1e-3 * tol,
            rtol=tol,
            method=method,
        )
        # t_eval holds both end points; the last column is the terminal state
        return jnp.asarray(solution.y[..., -1])

    return param_to_solution
In [7]:
Copied!
def solver_probdiffeq(num_derivatives: int) -> Callable:
    """Build a jitted ProbDiffEq solver for the Lotka--Volterra problem.

    Parameters
    ----------
    num_derivatives
        Passed as ``num`` to ``probdiffeq.jetexpand_ode_padded_scan``; it
        determines how many Taylor coefficients initialise the prior.

    Returns
    -------
    Callable
        A ``jax.jit``-compiled function ``tol -> (terminal_value, num_steps)``.
        ``tol`` is used as the relative tolerance; the absolute tolerance is
        ``1e-2 * tol``.
    """

    @probdiffeq.ode
    def vf_probdiffeq(y, /, *, t):  # noqa: ARG001
        """Lotka--Volterra dynamics."""
        dy1 = 0.5 * y[0] - 0.05 * y[0] * y[1]
        dy2 = -0.5 * y[1] + 0.05 * y[0] * y[1]
        return jnp.asarray([dy1, dy2])

    # Initial state and integration interval
    u0 = jnp.asarray((20.0, 20.0))
    t0 = 0.0
    t1 = 50.0

    @jax.jit
    def param_to_solution(tol):
        # Taylor expansion happens inside the jitted function so that it is
        # compiled together with the rest of the solve.
        taylor_expand = probdiffeq.jetexpand_ode_padded_scan(num=num_derivatives)
        taylor_coeffs, _ = taylor_expand(vf_probdiffeq, (u0,), t=t0)

        # Prior, constraint, and filtering strategy
        model = probdiffeq.state_space_model_dense()
        prior = model.prior_wiener_integrated(taylor_coeffs)
        constraint = model.constraint_ode_ts1(vf_probdiffeq)
        filtering = probdiffeq.strategy_filter()

        # Solver, error estimate, and step-size control for the adaptive loop
        ode_solver = probdiffeq.solver(strategy=filtering, constraint=constraint)
        error_estimate = probdiffeq.error_residual_std(constraint=constraint)
        step_control = ivpsolve.control_proportional_integral()
        solve_fn = ivpsolve.solve_adaptive_terminal_values(
            solver=ode_solver, error=error_estimate, control=step_control
        )

        # Initial step size, then integrate from t0 to t1
        initial_step = ivpsolve.dt0(vf_probdiffeq, (u0,), t=t0)
        result = solve_fn(
            prior, t0=t0, t1=t1, dt0=initial_step, atol=1e-2 * tol, rtol=tol
        )

        # Terminal value together with the number of steps taken
        return result.u.mean[0], result.num_steps

    return param_to_solution
def solver_probdiffeq(num_derivatives: int) -> Callable:
    """Build a jitted ProbDiffEq solver for the Lotka--Volterra problem.

    Parameters
    ----------
    num_derivatives
        Passed as ``num`` to ``probdiffeq.jetexpand_ode_padded_scan``; it
        determines how many Taylor coefficients initialise the prior.

    Returns
    -------
    Callable
        A ``jax.jit``-compiled function ``tol -> (terminal_value, num_steps)``.
        ``tol`` is used as the relative tolerance; the absolute tolerance is
        ``1e-2 * tol``.
    """

    @probdiffeq.ode
    def vf_probdiffeq(y, /, *, t):  # noqa: ARG001
        """Lotka--Volterra dynamics."""
        dy1 = 0.5 * y[0] - 0.05 * y[0] * y[1]
        dy2 = -0.5 * y[1] + 0.05 * y[0] * y[1]
        return jnp.asarray([dy1, dy2])

    # Initial state and integration interval
    u0 = jnp.asarray((20.0, 20.0))
    t0 = 0.0
    t1 = 50.0

    @jax.jit
    def param_to_solution(tol):
        # Taylor expansion happens inside the jitted function so that it is
        # compiled together with the rest of the solve.
        taylor_expand = probdiffeq.jetexpand_ode_padded_scan(num=num_derivatives)
        taylor_coeffs, _ = taylor_expand(vf_probdiffeq, (u0,), t=t0)

        # Prior, constraint, and filtering strategy
        model = probdiffeq.state_space_model_dense()
        prior = model.prior_wiener_integrated(taylor_coeffs)
        constraint = model.constraint_ode_ts1(vf_probdiffeq)
        filtering = probdiffeq.strategy_filter()

        # Solver, error estimate, and step-size control for the adaptive loop
        ode_solver = probdiffeq.solver(strategy=filtering, constraint=constraint)
        error_estimate = probdiffeq.error_residual_std(constraint=constraint)
        step_control = ivpsolve.control_proportional_integral()
        solve_fn = ivpsolve.solve_adaptive_terminal_values(
            solver=ode_solver, error=error_estimate, control=step_control
        )

        # Initial step size, then integrate from t0 to t1
        initial_step = ivpsolve.dt0(vf_probdiffeq, (u0,), t=t0)
        result = solve_fn(
            prior, t0=t0, t1=t1, dt0=initial_step, atol=1e-2 * tol, rtol=tol
        )

        # Terminal value together with the number of steps taken
        return result.u.mean[0], result.num_steps

    return param_to_solution
In [8]:
Copied!
def workprec(fun, *, precision_fun: Callable, timeit_fun: Callable) -> Callable:
    """Turn a parameter-to-solution function into parameter-to-workprecision.

    Parameters
    ----------
    fun
        Callable ``arg -> (value, num_steps)`` where ``value`` offers
        ``block_until_ready()`` (for example a JAX array).
    precision_fun
        Callable mapping the computed ``value`` to an error measure.
    timeit_fun
        Callable that receives a zero-argument function and returns a
        sequence of measured run times.

    Returns
    -------
    Callable
        A function mapping a list of arguments to a dict of arrays with keys
        ``"work_mean"``, ``"work_min"``, ``"work_num_steps"``, ``"work_std"``
        and ``"precision"``, one entry per argument. ``"work_std"`` stays
        empty when ``timeit_fun`` returns a single sample, because a standard
        deviation needs at least two.
    """

    def parameter_list_to_workprecision(list_of_args, /):
        works_num_steps = []
        works_min = []
        works_mean = []
        works_std = []
        precisions = []

        for arg in list_of_args:
            # A single untimed evaluation gives the value, the step count, and
            # keeps any first-call overhead (e.g. compilation) out of the timings.
            # The result is reused instead of solving the problem a second time.
            value, num_steps = fun(arg)
            precision = precision_fun(value.block_until_ready())

            # The lambda is consumed inside this iteration, so closing over
            # the loop variable is safe here.
            times = timeit_fun(lambda: fun(arg)[0].block_until_ready())  # noqa: B023

            precisions.append(precision)
            works_num_steps.append(num_steps)
            works_min.append(min(times))
            works_mean.append(statistics.mean(times))
            if len(times) > 1:
                works_std.append(statistics.stdev(times))

        return {
            "work_mean": jnp.asarray(works_mean),
            "work_min": jnp.asarray(works_min),
            "work_num_steps": jnp.asarray(works_num_steps),
            "work_std": jnp.asarray(works_std),
            "precision": jnp.asarray(precisions),
        }

    return parameter_list_to_workprecision
def workprec(fun, *, precision_fun: Callable, timeit_fun: Callable) -> Callable:
    """Turn a parameter-to-solution function into parameter-to-workprecision.

    Parameters
    ----------
    fun
        Callable ``arg -> (value, num_steps)`` where ``value`` offers
        ``block_until_ready()`` (for example a JAX array).
    precision_fun
        Callable mapping the computed ``value`` to an error measure.
    timeit_fun
        Callable that receives a zero-argument function and returns a
        sequence of measured run times.

    Returns
    -------
    Callable
        A function mapping a list of arguments to a dict of arrays with keys
        ``"work_mean"``, ``"work_min"``, ``"work_num_steps"``, ``"work_std"``
        and ``"precision"``, one entry per argument. ``"work_std"`` stays
        empty when ``timeit_fun`` returns a single sample, because a standard
        deviation needs at least two.
    """

    def parameter_list_to_workprecision(list_of_args, /):
        works_num_steps = []
        works_min = []
        works_mean = []
        works_std = []
        precisions = []

        for arg in list_of_args:
            # A single untimed evaluation gives the value, the step count, and
            # keeps any first-call overhead (e.g. compilation) out of the timings.
            # The result is reused instead of solving the problem a second time.
            value, num_steps = fun(arg)
            precision = precision_fun(value.block_until_ready())

            # The lambda is consumed inside this iteration, so closing over
            # the loop variable is safe here.
            times = timeit_fun(lambda: fun(arg)[0].block_until_ready())  # noqa: B023

            precisions.append(precision)
            works_num_steps.append(num_steps)
            works_min.append(min(times))
            works_mean.append(statistics.mean(times))
            if len(times) > 1:
                works_std.append(statistics.stdev(times))

        return {
            "work_mean": jnp.asarray(works_mean),
            "work_min": jnp.asarray(works_min),
            "work_num_steps": jnp.asarray(works_num_steps),
            "work_std": jnp.asarray(works_std),
            "precision": jnp.asarray(precisions),
        }

    return parameter_list_to_workprecision
In [9]:
Copied!
def smooth(x, y, window=2):
    """Apply a moving average to ``x`` and ``y`` to improve visualisation.

    Both inputs are convolved separately with a uniform kernel of length
    ``window`` (every weight equals ``1 / window``) in ``"valid"`` mode, so
    only positions where the kernel fully overlaps the data are kept.

    Returns the pair ``(x_smoothed, y_smoothed)``.
    """
    weights = jnp.ones((window,)) / window

    def moving_average(values):
        return jnp.convolve(values, weights, mode="valid")

    return moving_average(x), moving_average(y)
def smooth(x, y, window=2):
    """Apply a moving average to ``x`` and ``y`` to improve visualisation.

    Both inputs are convolved separately with a uniform kernel of length
    ``window`` (every weight equals ``1 / window``) in ``"valid"`` mode, so
    only positions where the kernel fully overlaps the data are kept.

    Returns the pair ``(x_smoothed, y_smoothed)``.
    """
    weights = jnp.ones((window,)) / window

    def moving_average(values):
        return jnp.convolve(values, weights, mode="valid")

    return moving_average(x), moving_average(y)
In [10]:
Copied!
def linear_trend(x, y):
    """Fit a straight line to ``log10(y)`` as a function of ``log10(x)``.

    Returns
    -------
    tuple
        ``((x, y_fit), (scale, bias))`` where ``scale`` and ``bias`` are the
        slope and intercept of the degree-one fit in log-log space, ``x`` is
        the input mapped to log space and back, and ``y_fit`` is the fitted
        line mapped back from log space.
    """
    log_x = jnp.log10(x)
    log_y = jnp.log10(y)

    # Degree-one polynomial: highest-order coefficient comes first
    slope, intercept = jnp.polyfit(log_x, log_y, 1)
    log_fit = slope * log_x + intercept

    curve = (10 ** (log_x), 10 ** (log_fit))
    coefficients = (slope, intercept)
    return curve, coefficients
def linear_trend(x, y):
    """Fit a straight line to ``log10(y)`` as a function of ``log10(x)``.

    Returns
    -------
    tuple
        ``((x, y_fit), (scale, bias))`` where ``scale`` and ``bias`` are the
        slope and intercept of the degree-one fit in log-log space, ``x`` is
        the input mapped to log space and back, and ``y_fit`` is the fitted
        line mapped back from log space.
    """
    log_x = jnp.log10(x)
    log_y = jnp.log10(y)

    # Degree-one polynomial: highest-order coefficient comes first
    slope, intercept = jnp.polyfit(log_x, log_y, 1)
    log_fit = slope * log_x + intercept

    curve = (10 ** (log_x), 10 ** (log_fit))
    coefficients = (slope, intercept)
    return curve, coefficients
In [11]:
Copied!
# Run the benchmark and display the convergence-rate figure.
main()
# NOTE(review): this second call looks like a duplicate produced by the
# notebook export rather than an intentional re-run; confirm before keeping it.
main()
0%| | 0/17 [00:00<?, ?it/s]
6%|▌ | 1/17 [00:10<02:53, 10.86s/it]
12%|█▏ | 2/17 [00:15<01:44, 6.95s/it]
18%|█▊ | 3/17 [00:17<01:07, 4.79s/it]
24%|██▎ | 4/17 [00:19<00:48, 3.73s/it]
29%|██▉ | 5/17 [00:21<00:36, 3.06s/it]
35%|███▌ | 6/17 [00:23<00:29, 2.67s/it]
41%|████ | 7/17 [00:25<00:23, 2.40s/it]
47%|████▋ | 8/17 [00:26<00:20, 2.24s/it]
53%|█████▎ | 9/17 [00:29<00:17, 2.21s/it]
59%|█████▉ | 10/17 [00:31<00:15, 2.21s/it]
65%|██████▍ | 11/17 [00:33<00:13, 2.23s/it]
71%|███████ | 12/17 [00:36<00:11, 2.31s/it]
76%|███████▋ | 13/17 [00:39<00:10, 2.54s/it]
82%|████████▏ | 14/17 [00:42<00:07, 2.65s/it]
88%|████████▊ | 15/17 [00:45<00:05, 2.76s/it]
94%|█████████▍| 16/17 [00:48<00:03, 3.02s/it]
100%|██████████| 17/17 [00:52<00:00, 3.40s/it]
100%|██████████| 17/17 [00:52<00:00, 3.11s/it]