
Coming from other ODE solver libraries?

This guide helps you get started with Probdiffeq for solving ordinary differential equations (ODEs), especially if you are familiar with other probabilistic or non-probabilistic ODE solvers in Python or Julia.

Probdiffeq is a JAX library that focuses on state-space-model formulations of probabilistic initial value problem (IVP) solvers. For background on what this means, have a look at this thesis.

Probabilistic ODE solvers in a nutshell: Unlike traditional solvers, which return a single point estimate of the solution, probabilistic solvers return a posterior distribution. This built-in uncertainty quantification reflects the numerical error (and other modelling choices) and helps you make better decisions during the simulation and in downstream tasks, for example in adaptive time-stepping, parameter estimation, or physics-informed machine learning.
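To make "a posterior instead of a point estimate" concrete, here is a minimal from-scratch sketch in plain JAX. It does not use Probdiffeq's API, and the function names `iwp1_transition` and `ek0_solve` are hypothetical. It is a fixed-step Kalman filter with a once-integrated Wiener process prior and zeroth-order linearization (often called EK0) for a scalar ODE y' = f(y). It assumes a fixed step size and an uncalibrated output scale of 1; a real solver would also calibrate the output scale and adapt the step size.

```python
import jax
import jax.numpy as jnp


def iwp1_transition(h, output_scale):
    """Discretised once-integrated Wiener process over a step h; state is [y, y']."""
    A = jnp.array([[1.0, h], [0.0, 1.0]])
    Q = output_scale**2 * jnp.array([[h**3 / 3, h**2 / 2], [h**2 / 2, h]])
    return A, Q


def ek0_solve(f, y0, t0, t1, num_steps, output_scale=1.0):
    """Fixed-step EK0 filter for a scalar ODE y' = f(y).

    Returns the posterior mean and standard deviation of y at each grid point
    t0 + h, t0 + 2h, ..., t1.
    """
    h = (t1 - t0) / num_steps
    A, Q = iwp1_transition(h, output_scale)
    H = jnp.array([0.0, 1.0])  # zeroth-order linearization: the residual only sees y'

    # Exact initialisation: y(t0) = y0 and y'(t0) = f(y0), with zero covariance.
    m0 = jnp.array([y0, f(y0)])
    P0 = jnp.zeros((2, 2))

    def step(carry, _):
        m, P = carry
        # Predict with the prior.
        m_pred = A @ m
        P_pred = A @ P @ A.T + Q
        # Condition on the residual y' - f(y) = 0, with f frozen at the predicted mean.
        z = H @ m_pred - f(m_pred[0])  # residual at the predicted mean
        s = H @ P_pred @ H  # residual variance (scalar)
        k = P_pred @ H / s  # Kalman gain
        m_new = m_pred - k * z
        P_new = P_pred - s * jnp.outer(k, k)
        # Clip tiny negative round-off before taking the square root.
        std = jnp.sqrt(jnp.maximum(P_new[0, 0], 0.0))
        return (m_new, P_new), (m_new[0], std)

    _, (means, stds) = jax.lax.scan(step, (m0, P0), None, length=num_steps)
    return means, stds
```

For example, `ek0_solve(lambda y: y * (1.0 - y), 0.1, 0.0, 2.0, num_steps=100)` returns 100 posterior means close to the logistic solution together with 100 standard deviations. Without calibration those standard deviations only have the right shape, not the right scale.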

From traditional (non-probabilistic) ODE solvers

If you're coming from traditional ODE solvers like SciPy's scipy.integrate.solve_ivp, JAX's jax.experimental.ode.odeint, or Diffrax, you'll notice some fundamental differences:

Key differences:

  • Solutions as distributions: Probdiffeq returns posterior distributions instead of point estimates. You automatically get uncertainty quantification, which you can use for sensitivity analysis, model selection, or downstream decision-making.
  • Fine-grained control: Probdiffeq lets you customise the probabilistic model (prior distribution, calibration method, linearization order), giving you more control over solver behaviour. Because these modelling choices matter, you usually assemble your own solver from these components, and there are few default configurations.
  • Explicit solver modes: Instead of a single solve() function, Probdiffeq offers specialised functions for targeting terminal values, checkpoints, or fixed grids. This makes the code easier to maintain and also improves performance, because each function can be optimised for its task and use specialised default parameters (e.g., whether time steps should be clipped before checkpoints).
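One example of the calibration choices mentioned above: if all prior covariances scale with output_scale**2 and the ODE residual is observed without noise, the filter means do not depend on the output scale. A single global output scale can then be estimated in closed form from the residuals z_n and their covariances S_n computed with output_scale = 1. Maximising the sum of log N(z_n; 0, s^2 S_n) over s^2 gives s^2 = (1 / (N d)) Σ z_n^T S_n^{-1} z_n. The sketch below computes that quasi-maximum-likelihood estimate. The function name is hypothetical and this is not Probdiffeq's API.

```python
import jax.numpy as jnp


def mle_output_scale(innovations, innovation_covs):
    """Quasi-maximum-likelihood estimate of a global output scale (hypothetical helper).

    innovations: array (N, d) of residuals z_n computed with output_scale = 1.
    innovation_covs: array (N, d, d) of residual covariances S_n (output_scale = 1).
    Assumes z_n ~ N(0, output_scale**2 * S_n), i.e. all prior covariances scale
    with output_scale**2 and there is no observation noise.
    """
    N, d = innovations.shape
    # Mahalanobis terms z_n^T S_n^{-1} z_n, one per step (batched linear solve).
    solved = jnp.linalg.solve(innovation_covs, innovations[..., None])[..., 0]
    quad = jnp.sum(innovations * solved, axis=-1)
    # Zero of the log-likelihood's derivative with respect to output_scale**2.
    return jnp.sqrt(jnp.sum(quad) / (N * d))
```

This is a global (fixed) calibration. A dynamic calibration would instead re-estimate the scale locally at every step.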

Mapping from Diffrax methods: If you're switching from Diffrax, here's how to achieve similar accuracy levels by adjusting Taylor coefficients and linearization order:

| Diffrax method | Probdiffeq approach |
|---|---|
| Heun(), Midpoint() | Use 2 Taylor coefficients with zeroth-order linearization |
| Tsit5(), Dopri5() | Use 5 Taylor coefficients with zeroth-order linearization |
| Dopri8() | Use 8 Taylor coefficients with zeroth-order linearization |
| Kvaerno3(), Kvaerno5() | Use 2 to 5 Taylor coefficients with first-order linearization |
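The table counts Taylor coefficients because solvers of this type carry the solution together with its first few time derivatives, and they are usually initialised with those derivatives at t0. As an illustration only (not Probdiffeq's routine; `taylor_derivatives` is a hypothetical name), the sketch below computes these derivatives for an autonomous ODE y' = f(y) by applying the chain rule repeatedly with `jax.jvp`. Dividing the k-th entry by k! gives the normalised Taylor coefficients.

```python
import jax
import jax.numpy as jnp


def _differentiate_along_flow(g, f):
    # Chain rule: d/dt g(y(t)) = Dg(y) y'(t) = Dg(y) f(y), computed as a JVP.
    return lambda y: jax.jvp(g, (y,), (f(y),))[1]


def taylor_derivatives(f, y0, num):
    """Stack y(t0), y'(t0), ..., y^(num-1)(t0) for the autonomous ODE y' = f(y)."""
    derivatives = [y0]
    g = f  # g maps y to the next time derivative of the solution
    for _ in range(1, num):
        derivatives.append(g(y0))
        g = _differentiate_along_flow(g, f)
    return jnp.stack(derivatives)
```

Nesting JVPs like this becomes expensive quickly as `num` grows. Taylor-mode differentiation (for example `jax.experimental.jet`) is the more scalable way to obtain the same quantities.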

Tidbit: Probabilistic solvers based on once-integrated Wiener or Ornstein-Uhlenbeck (OU) processes are closely related to (different versions of) the trapezoidal rule (Schober et al., 2019; Bosch et al., 2023). Higher-order methods connect to more general linear multistep methods (Schober et al., 2019).

  • Michael Schober, Simo Särkkä & Philipp Hennig (2019). A probabilistic model for the numerical solution of initial value problems. Statistics and Computing, 29(1), 99–122.

  • Nathanael Bosch, Philipp Hennig & Filip Tronarp (2023). Probabilistic exponential integrators. Advances in Neural Information Processing Systems, 36, 40450–40467.
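The once-integrated Wiener process from the tidbit above is the case q = 1 of the q-times integrated Wiener process. Its discretisation over a step of length h has a closed form: the state stacks y, y', ..., y^(q), and x(t + h) given x(t) is Gaussian with mean A x(t) and covariance Q. A minimal sketch of these matrices follows. The function name `iwp_transition` is hypothetical and this is not Probdiffeq's implementation.

```python
import math

import jax.numpy as jnp


def iwp_transition(num_derivatives, h, output_scale=1.0):
    """Return (A, Q) for the q-times integrated Wiener process over a step h."""
    q = num_derivatives
    A = [[0.0] * (q + 1) for _ in range(q + 1)]
    Q = [[0.0] * (q + 1) for _ in range(q + 1)]
    for i in range(q + 1):
        for j in range(q + 1):
            # Taylor-polynomial transition: A[i][j] = h^(j-i) / (j-i)! for j >= i.
            if j >= i:
                A[i][j] = h ** (j - i) / math.factorial(j - i)
            # Process noise: Q[i][j] = h^p / (p (q-i)! (q-j)!) with p = 2q + 1 - i - j.
            p = 2 * q + 1 - i - j
            Q[i][j] = h**p / (p * math.factorial(q - i) * math.factorial(q - j))
    return jnp.array(A), output_scale**2 * jnp.array(Q)
```

For q = 1 this gives A = [[1, h], [0, 1]] and Q = output_scale² · [[h³/3, h²/2], [h²/2, h]]. The output scale enters only as the factor output_scale² on Q, which matches the IWP(diffusion=x^2) and output_scale=x correspondence in the ProbNumDiffEq.jl table.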

Note: Probdiffeq is not a drop-in replacement for these solvers; the probabilistic approach is fundamentally different. However, you can match their accuracy and performance by tuning the solver configuration (see the examples in the documentation).

From other probabilistic ODE solvers

If you're familiar with other probabilistic solver libraries, here is how they compare:

From ProbNum (Python, NumPy): ProbNum is a general-purpose probabilistic numerics library, whereas Probdiffeq specialises in ODE solving in pure JAX. Advantages of Probdiffeq:

  • Greater efficiency through JAX's just-in-time (JIT) compilation, plus automatic differentiation through the solver
  • More mature ODE algorithms (state-space factorisations, improved adaptive time-stepping)
  • Richer outputs (sampling, marginal likelihoods, marginal-likelihood losses, etc.)

From ProbNumDiffEq.jl (Julia): ProbNumDiffEq.jl is the closest Julia counterpart to Probdiffeq (the two libraries are developed independently), with similar features but slightly different APIs. Here's how to translate between them:

| ProbNumDiffEq.jl concept | Probdiffeq concept |
|---|---|
| EK0 / EK1 | constraint_ode_ts0() / constraint_ode_ts1() |
| DynamicDiffusion / FixedDiffusion | solver_dynamic() / solver_mle() |
| IWP(diffusion=x^2) | prior_wiener_integrated(output_scale=x) |
| smooth=false / smooth=true | strategy_filter() / strategy_smoother_fixedpoint() or strategy_smoother_fixedinterval() |
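The EK0/EK1 and ts0/ts1 row refers to how the ODE residual y' - f(y) = 0 is linearized around the current mean: with a zeroth-order Taylor expansion of f (EK0, ts0) or a first-order one (EK1, ts1). A minimal sketch of the difference follows. The helper name `linearize_ode_residual` is hypothetical and this is not Probdiffeq's API. Both variants return an affine approximation H_y y + H_dy y' + b of the residual, and only the first-order variant uses the Jacobian of f.

```python
import jax
import jax.numpy as jnp


def linearize_ode_residual(f, y_mean, order):
    """Affine approximation r(y, dy) ≈ H_y @ y + H_dy @ dy + b of r = dy - f(y)."""
    d = y_mean.shape[0]
    f_mean = f(y_mean)
    H_dy = jnp.eye(d)
    if order == 0:
        # Zeroth order (EK0-style): f(y) ≈ f(y_mean), so y drops out of the residual.
        H_y = jnp.zeros((d, d))
        b = -f_mean
    elif order == 1:
        # First order (EK1-style): f(y) ≈ f(y_mean) + J (y - y_mean), J = Jacobian of f.
        J = jax.jacfwd(f)(y_mean)
        H_y = -J
        b = -f_mean + J @ y_mean
    else:
        raise ValueError("order must be 0 or 1")
    return H_y, H_dy, b
```

The zeroth-order version never evaluates a Jacobian, which keeps each step cheap. The first-order version accounts for how f changes with y, which tends to matter for stiff problems. This is consistent with the Kvaerno row of the Diffrax table pointing to first-order linearization.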

Both libraries are actively evolving; consult their latest API documentation if you're unsure about equivalences.