A5. Walltime | Burgers PDE¶
In [1]:
Copied!
from collections.abc import Callable
from collections.abc import Callable
In [2]:
Copied!
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy.integrate
import tqdm
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy.integrate
import tqdm
In [3]:
Copied!
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq.util import benchmark_util
from probdiffeq import ivpsolve, probdiffeq
from probdiffeq.util import benchmark_util
In [4]:
Copied!
# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)
# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)
In [5]:
Copied!
N = 50
"""Number of spatial grid points.
Small enough to include dense methods,
but large enough to penalise their O(d^3) complexity.
"""
N = 50
"""Number of spatial grid points.
Small enough to include dense methods,
but large enough to penalise their O(d^3) complexity.
"""
Out[5]:
'Number of spatial grid points.\n\nSmall enough to include dense methods,\nbut large enough to penalise their O(d^3) complexity.\n'
In [6]:
Copied!
NU = 0.01
"""Diffusion coefficient.
The larger, the stiffer the problem.
"""
NU = 0.01
"""Diffusion coefficient.
The larger, the stiffer the problem.
"""
Out[6]:
'Diffusion coefficient.\n\nThe larger, the stiffer the problem.\n'
In [7]:
Copied!
def main(start=3.0, stop=8.0, step=1.0, repeats=1) -> None:
"""Run the script."""
jax.config.update("jax_enable_x64", True)
# Visualise the dynamics
ts, ys = solve_ivp_once()
x = np.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
_fig, ax = plt.subplots(figsize=(5, 3))
pcm = ax.pcolormesh(x, ts, ys, cmap="coolwarm", vmin=-0.5, vmax=0.5)
_fig.colorbar(pcm, ax=ax)
ax.set_title("Burgers PDE")
ax.set_xlabel("Space")
ax.set_ylabel("Time")
plt.tight_layout()
plt.show()
# Read configuration from command line
tolerances = benchmark_util.setup_tolerances(start=start, stop=stop, step=step)
timeit_fun = benchmark_util.setup_timeit(repeats=repeats)
# Assemble algorithms
algorithms = {
r"TS1($3$, dense)": solver_dense(num_derivatives=3),
r"TS1($3$, blockdiag)": solver_blockdiag(num_derivatives=3),
r"TS1($3$, matfree)": solver_matfree(num_derivatives=3),
}
# Compute a reference solution
reference = solver_scipy(method="LSODA")(1e-13)
precision_fun = benchmark_util.rmse_absolute(reference)
# Compute all work-precision diagrams
results = {}
pbar = tqdm.tqdm(algorithms.items())
for label, algo in pbar:
pbar.set_description(label)
param_to_wp = benchmark_util.workprec(
algo, precision_fun=precision_fun, timeit_fun=timeit_fun
)
results[label] = param_to_wp(tolerances)
_fig, ax = plt.subplots(figsize=(5, 3))
for i, (label, wp) in enumerate(results.items()):
ax.loglog(wp["precision"], wp["work_mean"], ".-", label=label, color=f"C{i}")
ax.set_title("Work-precision diagram")
ax.set_xlabel("Precision (absolute RMSE)")
ax.set_ylabel("Work (avg. wall time)")
ax.grid(linestyle="dotted", which="both")
ax.legend(fontsize="small")
plt.tight_layout()
plt.show()
def main(start=3.0, stop=8.0, step=1.0, repeats=1) -> None:
"""Run the script."""
jax.config.update("jax_enable_x64", True)
# Visualise the dynamics
ts, ys = solve_ivp_once()
x = np.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
_fig, ax = plt.subplots(figsize=(5, 3))
pcm = ax.pcolormesh(x, ts, ys, cmap="coolwarm", vmin=-0.5, vmax=0.5)
_fig.colorbar(pcm, ax=ax)
ax.set_title("Burgers PDE")
ax.set_xlabel("Space")
ax.set_ylabel("Time")
plt.tight_layout()
plt.show()
# Read configuration from command line
tolerances = benchmark_util.setup_tolerances(start=start, stop=stop, step=step)
timeit_fun = benchmark_util.setup_timeit(repeats=repeats)
# Assemble algorithms
algorithms = {
r"TS1($3$, dense)": solver_dense(num_derivatives=3),
r"TS1($3$, blockdiag)": solver_blockdiag(num_derivatives=3),
r"TS1($3$, matfree)": solver_matfree(num_derivatives=3),
}
# Compute a reference solution
reference = solver_scipy(method="LSODA")(1e-13)
precision_fun = benchmark_util.rmse_absolute(reference)
# Compute all work-precision diagrams
results = {}
pbar = tqdm.tqdm(algorithms.items())
for label, algo in pbar:
pbar.set_description(label)
param_to_wp = benchmark_util.workprec(
algo, precision_fun=precision_fun, timeit_fun=timeit_fun
)
results[label] = param_to_wp(tolerances)
_fig, ax = plt.subplots(figsize=(5, 3))
for i, (label, wp) in enumerate(results.items()):
ax.loglog(wp["precision"], wp["work_mean"], ".-", label=label, color=f"C{i}")
ax.set_title("Work-precision diagram")
ax.set_xlabel("Precision (absolute RMSE)")
ax.set_ylabel("Work (avg. wall time)")
ax.grid(linestyle="dotted", which="both")
ax.legend(fontsize="small")
plt.tight_layout()
plt.show()
In [8]:
Copied!
def solve_ivp_once():
"""Compute plotting values for the Burgers PDE."""
def vf_scipy(_t, u):
"""Viscous Burgers equation, zero Dirichlet BC, conservative advection."""
dx = 1.0 / N
u_bc = np.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = np.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
time_span = np.asarray([0.0, 1.0])
t_eval = np.linspace(0.0, 1.0, 200)
tol = 1e-9
solution = scipy.integrate.solve_ivp(
vf_scipy,
y0=u0,
t_span=time_span,
t_eval=t_eval,
atol=1e-3 * tol,
rtol=tol,
method="LSODA",
)
return solution.t, solution.y.T
def solve_ivp_once():
"""Compute plotting values for the Burgers PDE."""
def vf_scipy(_t, u):
"""Viscous Burgers equation, zero Dirichlet BC, conservative advection."""
dx = 1.0 / N
u_bc = np.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = np.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
time_span = np.asarray([0.0, 1.0])
t_eval = np.linspace(0.0, 1.0, 200)
tol = 1e-9
solution = scipy.integrate.solve_ivp(
vf_scipy,
y0=u0,
t_span=time_span,
t_eval=t_eval,
atol=1e-3 * tol,
rtol=tol,
method="LSODA",
)
return solution.t, solution.y.T
In [9]:
Copied!
def solver_blockdiag(*, num_derivatives: int) -> Callable:
"""Construct a solver that wraps ProbDiffEq's block-diagonal routines."""
@probdiffeq.ode
def vf(u, /, *, t): # noqa: ARG001
"""Viscous Burgers equation."""
dx = 1.0 / N
u_bc = jnp.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = jnp.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
t0, t1 = 0.0, 1.0
ssm = probdiffeq.state_space_model_blockdiag()
@jax.jit
def param_to_solution(tol):
jetexpand = probdiffeq.jetexpand_ode_unroll(num=num_derivatives)
tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
iwp = ssm.prior_wiener_integrated(tcoeffs)
ts1 = ssm.constraint_ode_ts1(vf)
strategy = probdiffeq.strategy_filter()
solver = probdiffeq.solver(strategy=strategy, constraint=ts1)
error = probdiffeq.error_state_std(constraint=ts1)
control = ivpsolve.control_proportional_integral()
solve_fn = ivpsolve.solve_adaptive_terminal_values(
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve_fn(iwp, t0=t0, t1=t1, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])
return param_to_solution
def solver_blockdiag(*, num_derivatives: int) -> Callable:
"""Construct a solver that wraps ProbDiffEq's block-diagonal routines."""
@probdiffeq.ode
def vf(u, /, *, t): # noqa: ARG001
"""Viscous Burgers equation."""
dx = 1.0 / N
u_bc = jnp.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = jnp.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
t0, t1 = 0.0, 1.0
ssm = probdiffeq.state_space_model_blockdiag()
@jax.jit
def param_to_solution(tol):
jetexpand = probdiffeq.jetexpand_ode_unroll(num=num_derivatives)
tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
iwp = ssm.prior_wiener_integrated(tcoeffs)
ts1 = ssm.constraint_ode_ts1(vf)
strategy = probdiffeq.strategy_filter()
solver = probdiffeq.solver(strategy=strategy, constraint=ts1)
error = probdiffeq.error_state_std(constraint=ts1)
control = ivpsolve.control_proportional_integral()
solve_fn = ivpsolve.solve_adaptive_terminal_values(
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve_fn(iwp, t0=t0, t1=t1, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])
return param_to_solution
In [10]:
Copied!
def solver_matfree(*, num_derivatives: int) -> Callable:
"""Construct a solver that wraps ProbDiffEq's matrix-free routines."""
@probdiffeq.ode
def vf(u, /, *, t): # noqa: ARG001
"""Viscous Burgers equation."""
dx = 1.0 / N
u_bc = jnp.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = jnp.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
t0, t1 = 0.0, 1.0
key = jax.random.PRNGKey(1)
num_ensembles = (num_derivatives + 1) * 2
ssm = probdiffeq.state_space_model_matfree(key=key, num_ensembles=num_ensembles)
@jax.jit
def param_to_solution(tol):
jetexpand = probdiffeq.jetexpand_ode_unroll(num=num_derivatives)
tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
iwp = ssm.prior_wiener_integrated(tcoeffs)
ts1 = ssm.constraint_ode_ts1(vf)
strategy = probdiffeq.strategy_filter()
solver = probdiffeq.solver(strategy=strategy, constraint=ts1)
error = probdiffeq.error_state_std(constraint=ts1)
control = ivpsolve.control_proportional_integral()
solve_fn = ivpsolve.solve_adaptive_terminal_values(
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve_fn(iwp, t0=t0, t1=t1, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])
return param_to_solution
def solver_matfree(*, num_derivatives: int) -> Callable:
"""Construct a solver that wraps ProbDiffEq's matrix-free routines."""
@probdiffeq.ode
def vf(u, /, *, t): # noqa: ARG001
"""Viscous Burgers equation."""
dx = 1.0 / N
u_bc = jnp.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = jnp.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
t0, t1 = 0.0, 1.0
key = jax.random.PRNGKey(1)
num_ensembles = (num_derivatives + 1) * 2
ssm = probdiffeq.state_space_model_matfree(key=key, num_ensembles=num_ensembles)
@jax.jit
def param_to_solution(tol):
jetexpand = probdiffeq.jetexpand_ode_unroll(num=num_derivatives)
tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
iwp = ssm.prior_wiener_integrated(tcoeffs)
ts1 = ssm.constraint_ode_ts1(vf)
strategy = probdiffeq.strategy_filter()
solver = probdiffeq.solver(strategy=strategy, constraint=ts1)
error = probdiffeq.error_state_std(constraint=ts1)
control = ivpsolve.control_proportional_integral()
solve_fn = ivpsolve.solve_adaptive_terminal_values(
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve_fn(iwp, t0=t0, t1=t1, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])
return param_to_solution
In [11]:
Copied!
def solver_dense(*, num_derivatives: int) -> Callable:
"""Construct a solver that wraps ProbDiffEq's matrix-free routines."""
@probdiffeq.ode
def vf(u, /, *, t): # noqa: ARG001
"""Viscous Burgers equation."""
dx = 1.0 / N
u_bc = jnp.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = jnp.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
t0, t1 = 0.0, 1.0
ssm = probdiffeq.state_space_model_dense()
@jax.jit
def param_to_solution(tol):
jetexpand = probdiffeq.jetexpand_ode_unroll(num=num_derivatives)
tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
iwp = ssm.prior_wiener_integrated(tcoeffs)
ts1 = ssm.constraint_ode_ts1(vf)
strategy = probdiffeq.strategy_filter()
solver = probdiffeq.solver(strategy=strategy, constraint=ts1)
error = probdiffeq.error_state_std(constraint=ts1)
control = ivpsolve.control_proportional_integral()
solve_fn = ivpsolve.solve_adaptive_terminal_values(
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve_fn(iwp, t0=t0, t1=t1, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])
return param_to_solution
def solver_dense(*, num_derivatives: int) -> Callable:
"""Construct a solver that wraps ProbDiffEq's matrix-free routines."""
@probdiffeq.ode
def vf(u, /, *, t): # noqa: ARG001
"""Viscous Burgers equation."""
dx = 1.0 / N
u_bc = jnp.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = jnp.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
t0, t1 = 0.0, 1.0
ssm = probdiffeq.state_space_model_dense()
@jax.jit
def param_to_solution(tol):
jetexpand = probdiffeq.jetexpand_ode_unroll(num=num_derivatives)
tcoeffs, _ = jetexpand(vf, (u0,), t=t0)
iwp = ssm.prior_wiener_integrated(tcoeffs)
ts1 = ssm.constraint_ode_ts1(vf)
strategy = probdiffeq.strategy_filter()
solver = probdiffeq.solver(strategy=strategy, constraint=ts1)
error = probdiffeq.error_state_std(constraint=ts1)
control = ivpsolve.control_proportional_integral()
solve_fn = ivpsolve.solve_adaptive_terminal_values(
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve_fn(iwp, t0=t0, t1=t1, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])
return param_to_solution
In [12]:
Copied!
def solver_scipy(*, method: str) -> Callable:
"""Construct a solver that wraps SciPy's solution routines."""
def vf_scipy(_t, u):
"""Viscous Burgers equation, zero Dirichlet BC, conservative advection."""
dx = 1.0 / N
u_bc = np.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = np.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
time_span = np.asarray([0.0, 1.0])
def param_to_solution(tol):
solution = scipy.integrate.solve_ivp(
vf_scipy,
y0=u0,
t_span=time_span,
t_eval=time_span,
atol=1e-3 * tol,
rtol=tol,
method=method,
)
return jnp.asarray(solution.y[:, -1])
return param_to_solution
def solver_scipy(*, method: str) -> Callable:
"""Construct a solver that wraps SciPy's solution routines."""
def vf_scipy(_t, u):
"""Viscous Burgers equation, zero Dirichlet BC, conservative advection."""
dx = 1.0 / N
u_bc = np.pad(u, 1) # zero Dirichlet ghosts
u_left = u_bc[:-2]
u_right = u_bc[2:]
flux = u_bc**2 / 2.0
fluxterm = (flux[2:] - flux[:-2]) / (2.0 * dx)
laplacian = (u_right - 2.0 * u + u_left) / dx**2
return -fluxterm + NU * laplacian
x = np.linspace(0.0, 1.0, N + 1, endpoint=True)[1:-1]
u0 = jnp.sin(3 * jnp.pi * x) ** 3 * (1 - x) ** 1.5
time_span = np.asarray([0.0, 1.0])
def param_to_solution(tol):
solution = scipy.integrate.solve_ivp(
vf_scipy,
y0=u0,
t_span=time_span,
t_eval=time_span,
atol=1e-3 * tol,
rtol=tol,
method=method,
)
return jnp.asarray(solution.y[:, -1])
return param_to_solution
In [13]:
Copied!
main()
main()
0%| | 0/3 [00:00<?, ?it/s]
TS1($3$, dense): 0%| | 0/3 [00:00<?, ?it/s]
TS1($3$, dense): 33%|███▎ | 1/3 [00:12<00:24, 12.07s/it]
TS1($3$, blockdiag): 33%|███▎ | 1/3 [00:12<00:24, 12.07s/it]
TS1($3$, blockdiag): 67%|██████▋ | 2/3 [00:13<00:06, 6.00s/it]
TS1($3$, matfree): 67%|██████▋ | 2/3 [00:13<00:06, 6.00s/it]
TS1($3$, matfree): 100%|██████████| 3/3 [00:21<00:00, 6.77s/it]
TS1($3$, matfree): 100%|██████████| 3/3 [00:21<00:00, 7.17s/it]