Skip to content

stats

Interact with IVP solutions.

For example, this module contains functionality to compute off-grid marginals, or to evaluate marginal likelihoods of observations of the solutions.

MarkovSeq ¤

Bases: NamedTuple

Markov sequence.

Source code in probdiffeq/stats.py
15
16
17
18
19
class MarkovSeq(containers.NamedTuple):
    """Markov sequence.

    A named tuple with two fields: an initial value (``init``) and the
    conditional(s) (``conditional``) that accompany it.
    """

    # Starting point of the sequence. Other functions in this module index it
    # with ``x[-1, ...]`` leaf-wise, so it may carry a leading (time) axis
    # -- NOTE(review): confirm the exact layout against the solver output.
    init: Any
    # Conditional(s) of the sequence; scanned over (as ``xs``) when
    # marginalising. Presumably stacked along a leading axis -- TODO confirm.
    conditional: Any

calibrate(x, /, output_scale, *, ssm) ¤

Calibrate a posterior distribution of an IVP solution.

Source code in probdiffeq/stats.py
260
261
262
263
264
265
266
def calibrate(x, /, output_scale, *, ssm):
    """Calibrate a posterior distribution of an IVP solution.

    If ``output_scale`` has more dimensions than the prototype output scale
    of ``ssm``, only its last entry (``output_scale[-1]``) is used.
    A ``MarkovSeq`` input is rescaled via ``_markov_rescale_cholesky``;
    any other input is passed to ``ssm.stats.rescale_cholesky``.
    """
    scale_ndim = np.ndim(output_scale)
    prototype_ndim = np.ndim(ssm.prototypes.output_scale())

    # More dimensions than the prototype: keep only the final entry.
    if scale_ndim > prototype_ndim:
        output_scale = output_scale[-1]

    if not isinstance(x, MarkovSeq):
        return ssm.stats.rescale_cholesky(x, output_scale)
    return _markov_rescale_cholesky(x, output_scale, ssm=ssm)

log_marginal_likelihood(u, /, *, standard_deviation, posterior, ssm) ¤

Compute the log-marginal-likelihood of observations of the IVP solution.

Note

Use log_marginal_likelihood_terminal_values to compute the log-likelihood at the terminal values.

Parameters:

Name Type Description Default
u

Observation. Expected to match the ODE's type/shape.

required
standard_deviation

Standard deviation of the observation. Expected to match 'u's Pytree structure, but every leaf must be a scalar.

required
posterior

Posterior distribution. Expected to correspond to a solution of an ODE with shape (d,).

required
Source code in probdiffeq/stats.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
def log_marginal_likelihood(u, /, *, standard_deviation, posterior, ssm):
    """Compute the log-marginal-likelihood of observations of the IVP solution.

    !!! note
        Use `log_marginal_likelihood_terminal_values`
        to compute the log-likelihood at the terminal values.

    Parameters
    ----------
    u
        Observation. Expected to match the ODE's type/shape.
    standard_deviation
        Standard deviation of the observation. Expected to match 'u's
        Pytree structure, but every leaf must be a scalar.
    posterior
        Posterior distribution. Must be a `MarkovSeq`.
        Expected to correspond to a solution of an ODE with shape (d,).
    ssm
        State-space model implementation. Used here for
        `ssm.prototypes.qoi()` and `ssm.conditional.to_derivative`.

    Returns
    -------
    The log-pdf value produced by the reverse Kalman filter.

    Raises
    ------
    ValueError
        If the tree structures of `u` and `standard_deviation` differ,
        if the inputs are not time-series shaped, or if their leading
        lengths disagree.
    TypeError
        If `posterior` is not a `MarkovSeq`.
    """
    # Both inputs are expected to flatten to exactly one leaf each.
    [obs], obs_treedef = tree_util.tree_flatten(u)
    [noise], noise_treedef = tree_util.tree_flatten(standard_deviation)

    if obs_treedef != noise_treedef:
        raise ValueError(
            f"Observation-noise tree structure {noise_treedef} "
            f"does not match the observation structure {obs_treedef}. "
        )

    # A time series has exactly one more axis than the flattened QOI,
    # and the noise needs at least one axis.
    qoi_flat, _ = tree_util.ravel_pytree(ssm.prototypes.qoi())
    if not (np.ndim(noise) >= 1 and np.ndim(obs) == np.ndim(qoi_flat) + 1):
        raise ValueError(
            f"Time-series solution expected. "
            f"ndim={np.ndim(obs)}, shape={np.shape(obs)} received."
        )

    if len(obs) != len(np.asarray(noise)):
        raise ValueError(
            f"Observation-noise shape {np.shape(noise)} "
            f"does not match the observation shape {np.shape(obs)}. "
        )

    if not isinstance(posterior, MarkovSeq):
        prefix = "Time-series marginal likelihoods "
        suffix = "cannot be computed with a filtering solution."
        raise TypeError(prefix + suffix)

    # Build one observation model per entry: the first argument (0) is shared,
    # while `u` and `standard_deviation` are mapped over their leading axis.
    build_models = functools.vmap(
        ssm.conditional.to_derivative, in_axes=(None, 0, 0)
    )
    observation_models = build_models(0, u, standard_deviation)

    # The reverse pass starts from the last entry of the initial variable.
    terminal = tree_util.tree_map(lambda leaf: leaf[-1, ...], posterior.init)

    # Run the reverse Kalman filter; the data is all zeros, shaped like the
    # observation leaf.
    kalman = filter_util.kalmanfilter_with_marginal_likelihood(ssm=ssm)
    (_, _, logpdf), _ = filter_util.estimate_rev(
        np.zeros_like(obs),
        init=terminal,
        prior_transitions=posterior.conditional,
        observation_model=observation_models,
        estimator=kalman,
    )

    # Only the log-pdf is of interest to the caller.
    return logpdf

log_marginal_likelihood_terminal_values(u, /, *, standard_deviation, posterior, ssm) ¤

Compute the log-marginal-likelihood at the terminal value.

Parameters:

Name Type Description Default
u

Observation. Expected to have shape (d,) for an ODE with shape (d,).

required
standard_deviation

Standard deviation of the observation. Expected to be a scalar.

required
posterior

Posterior distribution. Expected to correspond to a solution of an ODE with shape (d,).

required
Source code in probdiffeq/stats.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def log_marginal_likelihood_terminal_values(
    u, /, *, standard_deviation, posterior, ssm
):
    """Compute the log-marginal-likelihood at the terminal value.

    Parameters
    ----------
    u
        Observation. Expected to have shape (d,) for an ODE with shape (d,).
    standard_deviation
        Standard deviation of the observation. Expected to be a scalar.
    posterior
        Posterior distribution. If it is a `MarkovSeq`, its `init` field
        is used; otherwise the object itself is used.
        Expected to correspond to a solution of an ODE with shape (d,).
    ssm
        State-space model implementation. Used here for
        `ssm.conditional.to_derivative`.

    Raises
    ------
    ValueError
        If the tree structures of `u` and `standard_deviation` differ.
    """
    # Both inputs are expected to flatten to exactly one leaf each.
    [obs], obs_treedef = tree_util.tree_flatten(u)
    [_noise], noise_treedef = tree_util.tree_flatten(standard_deviation)

    if obs_treedef != noise_treedef:
        raise ValueError(
            f"Observation-noise tree structure {noise_treedef} "
            f"does not match the observation structure {obs_treedef}. "
        )

    # Observation model for the QOI.
    observation_model = ssm.conditional.to_derivative(0, u, standard_deviation)

    if isinstance(posterior, MarkovSeq):
        terminal = posterior.init
    else:
        terminal = posterior

    # 'u' is baked into the observation model, so the data is all zeros.
    zeros = np.zeros_like(obs)
    _, logpdf = _condition_and_logpdf(terminal, zeros, observation_model, ssm=ssm)
    return logpdf

markov_marginals(markov_seq: MarkovSeq, *, reverse, ssm) ¤

Extract the (time-)marginals from a Markov sequence.

Source code in probdiffeq/stats.py
80
81
82
83
84
85
86
87
88
89
90
def markov_marginals(markov_seq: MarkovSeq, *, reverse, ssm):
    """Extract the (time-)marginals from a Markov sequence.

    Scans over ``markov_seq.conditional`` starting from ``markov_seq.init``
    and marginalises at each step; ``reverse`` is forwarded to the scan.
    """
    _assert_filtering_solution_removed(markov_seq)

    def marginalise_step(state, transition):
        # The new marginal is both the carry and the per-step output.
        marginal = ssm.conditional.marginalise(state, transition)
        return marginal, marginal

    _, marginals = control_flow.scan(
        marginalise_step,
        init=markov_seq.init,
        xs=markov_seq.conditional,
        reverse=reverse,
    )
    return marginals

markov_sample(key, markov_seq: MarkovSeq, *, reverse, ssm, shape=()) ¤

Sample from a Markov sequence.

Source code in probdiffeq/stats.py
22
23
24
25
26
27
28
29
30
31
def markov_sample(key, markov_seq: MarkovSeq, *, reverse, ssm, shape=()):
    """Sample from a Markov sequence.

    Draws standard-normal base samples of shape ``shape`` followed by the
    sequence's own sample shape, then transforms them into samples of
    the Markov sequence.
    """
    _assert_filtering_solution_removed(markov_seq)

    # On-grid smoother samples: draw i.i.d. values for the terminal RV x_N
    # and for the backward noises z_(1:N), then combine them backwards via
    # x_(n-1) = l_n @ x_n + z_n, for n=1,...,N.
    sequence_shape = _sample_shape(markov_seq, ssm=ssm)
    full_shape = shape + sequence_shape
    unit_samples = random.normal(key, shape=full_shape)
    return _transform_unit_sample(
        markov_seq, unit_samples, reverse=reverse, ssm=ssm
    )

markov_select_terminal(markov_seq: MarkovSeq) -> MarkovSeq ¤

Discard all intermediate filtering solutions from a Markov sequence.

This function is useful to convert a smoothing-solution into a Markov sequence that is compatible with sampling or marginalisation.

Source code in probdiffeq/stats.py
70
71
72
73
74
75
76
77
def markov_select_terminal(markov_seq: MarkovSeq) -> MarkovSeq:
    """Discard all intermediate filtering solutions from a Markov sequence.

    Keeps only the last entry (along the leading axis) of every leaf in
    ``markov_seq.init`` and leaves ``markov_seq.conditional`` untouched.
    This is useful to convert a smoothing-solution into a Markov sequence
    that is compatible with sampling or marginalisation.
    """

    def take_last(leaf):
        return leaf[-1, ...]

    terminal = tree_util.tree_map(take_last, markov_seq.init)
    return MarkovSeq(terminal, markov_seq.conditional)

offgrid_marginals_searchsorted(*, ts, solution, solver) ¤

Compute off-grid marginals on a dense grid via jax.numpy.searchsorted.

Warning

The elements in ts and the elements in the solution grid must be disjoint. Otherwise, anything can happen and the solution will be incorrect. At the moment, we do not check this.

Warning

The elements in ts must be strictly in (t0, t1). They must not lie outside the interval, and they must not coincide with the interval boundaries. At the moment, we do not check this.

Source code in probdiffeq/stats.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def offgrid_marginals_searchsorted(*, ts, solution, solver):
    """Compute off-grid marginals on a dense grid via jax.numpy.searchsorted.

    Maps ``_offgrid_marginals`` over the leading axis of ``ts`` while
    ``solution`` and ``solver`` are shared across all time points.

    !!! warning
        The elements in ts and the elements in the solution grid must be disjoint.
        Otherwise, anything can happen and the solution will be incorrect.
        At the moment, we do not check this.

    !!! warning
        The elements in ts must be strictly in (t0, t1).
        They must not lie outside the interval, and they must not coincide
        with the interval boundaries.
        At the moment, we do not check this.
    """
    # Only the time points are batched (axis 0); the rest is broadcast.
    batched_marginals = functools.vmap(
        _offgrid_marginals,
        in_axes=(0, None, None),
    )
    return batched_marginals(ts, solution, solver)