Probdiffeq Migration Guide
This guide helps you get started with Probdiffeq for solving ordinary differential equations (ODEs), especially if you are familiar with other probabilistic or non-probabilistic ODE solvers in Python or Julia.
Probdiffeq is a JAX library that focuses on state-space-model-based formulations of probabilistic IVP solvers. For what this means, have a look at this thesis.
Transitioning from ProbNumDiffEq.jl (Julia)
ProbNumDiffEq.jl is a library for probabilistic IVP solvers in Julia, similar to Probdiffeq. However, the two libraries are unrelated. The main differences:
- Probdiffeq is Python/JAX-based; ProbNumDiffEq is Julia-based.
- Probdiffeq provides additional solvers, dense output, and posterior sampling.
- ProbNumDiffEq handles mass-matrix problems and callbacks, which Probdiffeq does not (yet).
To translate ProbNumDiffEq.jl code to Probdiffeq code:
| ProbNumDiffEq.jl | ProbDiffEq Equivalent |
|---|---|
| `EK0` / `EK1` | `ts0()` / `ts1()` |
| `DynamicDiffusion` / `FixedDiffusion` | `ivpsolvers.solver_dynamic()` or `ivpsolvers.solver_mle()` |
| `IWP(diffusion=x^2)` | `prior_wiener_integrated(output_scale=x)` |
| Filtering and smoothing via `smooth=true/false` | Solver strategy constructions, including one for fixed-point smoothing |
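
The `IWP(diffusion=x^2)` row implies that the diffusion is the square of the output scale. The following is a minimal sketch of that conversion, reading the table row literally. The helper names are hypothetical and are not part of either library, and whether further rescaling is needed in practice is not covered here.

```python
import math


def diffusion_to_output_scale(diffusion: float) -> float:
    """Map a diffusion value x^2 to the corresponding output scale x.

    Assumption: the relation is exactly the one in the table row,
    IWP(diffusion=x^2) <-> prior_wiener_integrated(output_scale=x).
    """
    if diffusion < 0:
        raise ValueError("diffusion must be non-negative")
    return math.sqrt(diffusion)


def output_scale_to_diffusion(output_scale: float) -> float:
    """Inverse direction: map an output scale x to the diffusion x^2."""
    return output_scale**2


# Hypothetical example values, chosen only for illustration.
print(diffusion_to_output_scale(4.0))  # 2.0
print(output_scale_to_diffusion(3.0))  # 9.0
```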
Both libraries are evolving; consult the latest API documentation when in doubt.
Transitioning from ProbNum (Python, Numpy)
ProbNum is a general probabilistic numerics library based on Numpy. Probdiffeq specializes in IVP solvers using pure JAX, offering:
- Greater efficiency for ODE problems because of JAX (e.g., `jit` compilation).
- More mature solvers with generally faster algorithms (e.g., state-space model factorisations, improved adaptive step-size selection).
- More solvers and somewhat richer outputs (sampling, marginal likelihoods, etc.).
Transitioning from Diffrax
Diffrax is a JAX-based library for differential equations. Key differences:
- Diffrax solvers are non-probabilistic; Probdiffeq solvers are probabilistic.
- Vector fields: Diffrax uses `ODETerm()`; Probdiffeq uses plain functions with signature `(*ys, t)`.
- Solver construction: Diffrax requires a solver object (e.g., `diffrax.Tsit5()`); Probdiffeq constructs probabilistic state-space models.
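
To make the vector-field difference concrete, here is a small sketch that adapts a function written in the `(t, y, args)` convention that Diffrax's `ODETerm` wraps into a plain function with signature `(*ys, t)`. It imports neither library, so it runs with plain Python floats; the same functions should also accept JAX arrays. The adapter `to_plain_vf` is a hypothetical helper, and whether `t` has to be passed by keyword depends on the Probdiffeq version, so check the API documentation.

```python
def diffrax_style_vf(t, y, args):
    """Logistic growth in the (t, y, args) convention used with ODETerm."""
    rate = args
    return rate * y * (1.0 - y)


def to_plain_vf(f, args=None):
    """Wrap f(t, y, args) as a plain function vf(*ys, t).

    Hypothetical helper, not part of Diffrax or Probdiffeq.
    This sketch only handles a single state argument.
    """

    def vf(*ys, t):
        (y,) = ys  # exactly one state argument is expected here
        return f(t, y, args)

    return vf


vf = to_plain_vf(diffrax_style_vf, args=2.0)
print(vf(0.5, t=0.0))  # 2.0 * 0.5 * (1 - 0.5) = 0.5
```

Closing over `args` keeps the plain function free of extra parameters, which is the usual way to pass problem constants to functions that JAX transforms.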
Approximate solver mapping:
| Diffrax | ProbDiffEq Equivalent |
|---|---|
| `Heun()`, `Midpoint()` | `prior_ibm(num_derivatives=1)` or `ts0()` |
| `Tsit5()`, `Dopri5()` | Increase `num_derivatives=4` |
| `Dopri8()` | Increase `num_derivatives=5-7`; `ts1()` recommended but not required |
| `Kvaerno3()`–`Kvaerno5()` | Use `num_derivatives=2-4` with `ts1()` correction |
| Other methods | Work in progress |
General differences from conventional ODE solvers (e.g., SciPy, jax.odeint)
- Solutions are posterior distributions instead of point estimates, enabling uncertainty quantification and more sophisticated models (e.g., an easy switch to second-order problems).
- Solver modes are explicit: `simulate_terminal_values()`, `solve_adaptive_save_every_step()`, `solve_adaptive_save_at()` instead of a one-size-fits-all `solve()` method.
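
The distinction between the three solver modes can be illustrated with a toy, non-probabilistic sketch based on fixed-step explicit Euler. This is not Probdiffeq's API or algorithm: every function name below is hypothetical, and the real routines are adaptive and return posterior distributions. The sketch only shows what each mode keeps: the final state, every step, or the states at requested time points.

```python
def euler_steps(vf, y0, t0, t1, dt):
    """Yield (t, y) at t0 and after each explicit Euler step up to t1."""
    t, y = t0, y0
    yield t, y
    while t < t1 - 1e-12:
        h = min(dt, t1 - t)  # shorten the last step to land exactly on t1
        y = y + h * vf(y, t=t)
        t = t + h
        yield t, y


def toy_terminal_value(vf, y0, t0, t1, dt):
    """Mode 1: discard intermediate steps and keep only the state at t1."""
    y = y0
    for _, y in euler_steps(vf, y0, t0, t1, dt):
        pass
    return y


def toy_save_every_step(vf, y0, t0, t1, dt):
    """Mode 2: keep every (t, y) pair the solver visits."""
    return list(euler_steps(vf, y0, t0, t1, dt))


def toy_save_at(vf, y0, save_at, dt):
    """Mode 3: keep states only at the requested, increasing time points.

    Assumption of this toy: save_at[0] is the initial time.
    """
    t, y = save_at[0], y0
    saved = [y]
    for t_next in save_at[1:]:
        y = toy_terminal_value(vf, y, t, t_next, dt)
        t = t_next
        saved.append(y)
    return saved


def decay(y, t):
    """Linear test problem y' = -y."""
    return -y


print(toy_terminal_value(decay, 1.0, 0.0, 1.0, 0.25))
print(len(toy_save_every_step(decay, 1.0, 0.0, 1.0, 0.25)))
print(toy_save_at(decay, 1.0, [0.0, 0.5, 1.0], 0.25))
```

Keeping the modes separate lets each one store only what it needs, which is one plausible reason for exposing them as distinct functions rather than through a single `solve()` entry point.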