
stats

Interact with IVP solutions.

For example, this module provides functions to compute off-grid marginals and to evaluate the marginal likelihood of observations of a solution.

MarkovSeq

Bases: NamedTuple

Markov sequence.

Source code in probdiffeq/stats.py
class MarkovSeq(containers.NamedTuple):
    """Markov sequence."""

    init: Any
    conditional: Any

calibrate(x, /, output_scale)

Calibrate a posterior distribution of an IVP solution.

Source code in probdiffeq/stats.py
def calibrate(x, /, output_scale):
    """Calibrate a posterior distribution of an IVP solution."""
    if np.ndim(output_scale) > np.ndim(impl.prototypes.output_scale()):
        output_scale = output_scale[-1]
    if isinstance(x, MarkovSeq):
        return _markov_rescale_cholesky(x, output_scale)
    return impl.stats.rescale_cholesky(x, output_scale)
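
Calibration rescales the uncertainty of a posterior by an output scale. The NumPy sketch below assumes a Gaussian stored as a mean and a Cholesky factor, and assumes that rescaling means multiplying the Cholesky factor by the scale (so the covariance scales with its square). The names `calibrate_gaussian` and `prototype_ndim` are hypothetical and not part of probdiffeq; the sketch only mirrors the rule in `calibrate` that a time-series of output scales is reduced to its last entry, and it does not cover the Markov-sequence branch.

```python
import numpy as np


def calibrate_gaussian(mean, cholesky, output_scale, prototype_ndim=0):
    """Hypothetical sketch: rescale a Gaussian N(mean, L L^T) by an output scale.

    prototype_ndim is the assumed ndim of a single output scale
    (0 for a scalar scale).
    """
    # A time-series of output scales (one per step): keep the terminal one,
    # mirroring `output_scale[-1]` in `calibrate`.
    if np.ndim(output_scale) > prototype_ndim:
        output_scale = output_scale[-1]
    # Scaling the Cholesky factor by s scales the covariance by s**2;
    # the mean is unchanged.
    return mean, output_scale * cholesky


mean, chol = np.zeros(2), np.eye(2)
scales = np.array([1.0, 2.0, 3.0])  # e.g. one calibrated scale per step
new_mean, new_chol = calibrate_gaussian(mean, chol, scales)
cov = new_chol @ new_chol.T  # 9 times the identity
```

Working on the Cholesky factor rather than the covariance keeps the representation square-root based, so no re-factorisation is needed after calibration.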

log_marginal_likelihood(u, /, *, standard_deviation, posterior)

Compute the log-marginal-likelihood of observations of the IVP solution.

Note

Use log_marginal_likelihood_terminal_values to compute the log-likelihood at the terminal values.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| standard_deviation | Standard deviation of the observation. Expected to have shape (n,). | required |
| u | Observation. Expected to have shape (n, d) for an ODE with shape (d,). | required |
| posterior | Posterior distribution. Expected to correspond to a solution of an ODE with shape (d,). | required |
Source code in probdiffeq/stats.py
def log_marginal_likelihood(u, /, *, standard_deviation, posterior):
    """Compute the log-marginal-likelihood of observations of the IVP solution.

    !!! note
        Use `log_marginal_likelihood_terminal_values`
        to compute the log-likelihood at the terminal values.

    Parameters
    ----------
    standard_deviation
        Standard deviation of the observation. Expected to have shape (n,).
    u
        Observation. Expected to have shape (n, d) for an ODE with shape (d,).
    posterior
        Posterior distribution.
        Expected to correspond to a solution of an ODE with shape (d,).
    """
    # TODO: complain if it is used with a filter, not a smoother?
    # TODO: allow option for log-posterior

    if np.shape(standard_deviation) != np.shape(u)[:1]:
        msg = (
            f"Observation-noise shape {np.shape(standard_deviation)} "
            f"does not match the observation shape {np.shape(u)}. "
            f"Expected observation-noise shape: {np.shape(u)[:1]}."
        )
        raise ValueError(msg)

    if np.ndim(u) < np.ndim(impl.prototypes.qoi()) + 1:
        msg = (
            f"Time-series solution (ndim=2, shape=(n, m)) expected. "
            f"ndim={np.ndim(u)}, shape={np.shape(u)} received."
        )
        raise ValueError(msg)

    if not isinstance(posterior, MarkovSeq):
        msg1 = "Time-series marginal likelihoods "
        msg2 = "cannot be computed with a filtering solution."
        raise TypeError(msg1 + msg2)

    # Generate an observation-model for the QOI
    model_fun = functools.vmap(impl.conditional.to_derivative, in_axes=(None, 0))
    models = model_fun(0, standard_deviation)

    # Select the terminal variable
    rv = tree_util.tree_map(lambda s: s[-1, ...], posterior.init)

    # Run the reverse Kalman filter
    estimator = filter_util.kalmanfilter_with_marginal_likelihood()
    (_corrected, _num_data, logpdf), _ = filter_util.estimate_rev(
        u,
        init=rv,
        prior_transitions=posterior.conditional,
        observation_model=models,
        estimator=estimator,
    )

    # Return only the logpdf
    return logpdf
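
The function above runs a Kalman filter backwards through the posterior's conditionals and accumulates the log-density of each observation given the later ones. The NumPy sketch below illustrates that prediction-error decomposition under stated assumptions: a linear-Gaussian model whose terminal marginal is N(m_term, P_term), whose backward transitions are x_k | x_{k+1} ~ N(A_k x_{k+1} + b_k, Q_k), and whose observations select the first state component through a matrix `H` (analogous to observing the 0th derivative). All names (`reverse_kalman_logml`, `joint_logml`, `gaussian_logpdf`) are hypothetical; probdiffeq's internal helpers `filter_util.estimate_rev` and `kalmanfilter_with_marginal_likelihood` are not reproduced here. `joint_logml` is a brute-force cross-check that builds the joint Gaussian of all observations.

```python
import numpy as np


def gaussian_logpdf(y, mean, cov):
    """Log-density of N(mean, cov) evaluated at y."""
    diff = np.atleast_1d(y) - np.atleast_1d(mean)
    _, logdet = np.linalg.slogdet(cov)
    quad = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (quad + logdet + diff.size * np.log(2.0 * np.pi))


def reverse_kalman_logml(ys, std, m_term, P_term, As, bs, Qs, H):
    """Hypothetical sketch of a reverse Kalman filter with marginal likelihood.

    Assumed model (k = 0, ..., N-1, with N = len(As)):
        x_N ~ N(m_term, P_term)
        x_k | x_{k+1} ~ N(As[k] @ x_{k+1} + bs[k], Qs[k])
        ys[k] = H @ x_k + eps_k,  eps_k ~ N(0, std[k]**2 * I)
    """
    m, P, logpdf = m_term, P_term, 0.0
    N = len(As)
    for k in range(N, -1, -1):
        if k < N:
            # Predict: marginalise the backward transition x_k | x_{k+1}
            m = As[k] @ m + bs[k]
            P = As[k] @ P @ As[k].T + Qs[k]
        # Update: add log p(ys[k] | ys[k+1:]), then condition on ys[k]
        S = H @ P @ H.T + std[k] ** 2 * np.eye(H.shape[0])
        predicted = H @ m
        logpdf += gaussian_logpdf(ys[k], predicted, S)
        gain = np.linalg.solve(S, H @ P).T  # equals P H^T S^{-1}
        m = m + gain @ (ys[k] - predicted)
        P = P - gain @ S @ gain.T
    return logpdf


def joint_logml(ys, std, m_term, P_term, As, bs, Qs, H):
    """Brute-force cross-check: build the joint Gaussian of all observations."""
    N, d = len(As), m_term.size

    def blk(i, j):
        return slice(i * d, (i + 1) * d), slice(j * d, (j + 1) * d)

    means = [None] * (N + 1)
    means[N] = m_term
    cov = np.zeros(((N + 1) * d, (N + 1) * d))
    cov[blk(N, N)] = P_term
    for k in range(N - 1, -1, -1):
        means[k] = As[k] @ means[k + 1] + bs[k]
        for j in range(k + 1, N + 1):
            # The noise of x_k is independent of x_{k+1}, ..., x_N
            cov[blk(k, j)] = As[k] @ cov[blk(k + 1, j)]
            cov[blk(j, k)] = cov[blk(k, j)].T
        cov[blk(k, k)] = As[k] @ cov[blk(k + 1, k + 1)] @ As[k].T + Qs[k]
    H_all = np.kron(np.eye(N + 1), H)
    R_all = np.kron(np.diag(np.asarray(std) ** 2), np.eye(H.shape[0]))
    S = H_all @ cov @ H_all.T + R_all
    return gaussian_logpdf(np.ravel(ys), H_all @ np.concatenate(means), S)


# Usage: state (x, x') in 2-D, observe x only.
rng = np.random.default_rng(0)
d, N = 2, 4
H = np.array([[1.0, 0.0]])
As = [np.eye(d) + 0.1 * rng.normal(size=(d, d)) for _ in range(N)]
bs = [0.1 * rng.normal(size=d) for _ in range(N)]
Ls = [0.3 * rng.normal(size=(d, d)) for _ in range(N)]
Qs = [L @ L.T + 1e-3 * np.eye(d) for L in Ls]
m_term, P_term = rng.normal(size=d), np.eye(d)
ys = rng.normal(size=(N + 1, 1))  # one observation per grid point
std = 0.5 * np.ones(N + 1)  # shape (n,), as in the docstring

filtered = reverse_kalman_logml(ys, std, m_term, P_term, As, bs, Qs, H)
brute_force = joint_logml(ys, std, m_term, P_term, As, bs, Qs, H)
```

The filter costs O(n) small solves, whereas the brute-force version factorises one dense matrix whose size grows with n; the two should agree up to floating-point error.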

log_marginal_likelihood_terminal_values(u, /, *, standard_deviation, posterior)

Compute the log-marginal-likelihood at the terminal value.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| u | Observation. Expected to have shape (d,) for an ODE with shape (d,). | required |
| standard_deviation | Standard deviation of the observation. Expected to be a scalar. | required |
| posterior | Posterior distribution. Expected to correspond to a solution of an ODE with shape (d,). | required |
Source code in probdiffeq/stats.py
def log_marginal_likelihood_terminal_values(u, /, *, standard_deviation, posterior):
    """Compute the log-marginal-likelihood at the terminal value.

    Parameters
    ----------
    u
        Observation. Expected to have shape (d,) for an ODE with shape (d,).
    standard_deviation
        Standard deviation of the observation. Expected to be a scalar.
    posterior
        Posterior distribution.
        Expected to correspond to a solution of an ODE with shape (d,).
    """
    if np.shape(standard_deviation) != ():
        msg = (
            f"Scalar observation noise expected. "
            f"Shape {np.shape(standard_deviation)} received."
        )
        raise ValueError(msg)

    # not valid for scalar or matrix-valued solutions
    if np.ndim(u) > np.ndim(impl.prototypes.qoi()):
        msg = (
            f"Terminal-value solution (ndim=1, shape=(n,)) expected. "
            f"ndim={np.ndim(u)}, shape={np.shape(u)} received."
        )
        raise ValueError(msg)

    # Generate an observation-model for the QOI
    model = impl.conditional.to_derivative(0, standard_deviation)
    rv = posterior.init if isinstance(posterior, MarkovSeq) else posterior

    _corrected, logpdf = _condition_and_logpdf(rv, u, model)
    return logpdf
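
At the terminal value, the log-marginal-likelihood reduces to a single Gaussian density: log N(u; H m, H P H^T + sigma^2 I), where H extracts the quantity of interest from the state. The sketch below assumes the state stacks derivatives as [x, x', ...] with each block of size d = len(u), so that H selects the 0th derivative; `terminal_logml` is a hypothetical name, and probdiffeq's `_condition_and_logpdf` and state layout may differ.

```python
import numpy as np


def terminal_logml(u, standard_deviation, mean, cov):
    """Hypothetical sketch: log N(u; H mean, H cov H^T + sigma^2 I).

    Assumes the state stacks derivatives as [x, x', ...], each of size
    d = len(u), and that H selects the 0th derivative x.
    """
    if np.shape(standard_deviation) != ():
        raise ValueError("Scalar observation noise expected.")
    d = np.size(u)
    H = np.zeros((d, mean.size))
    H[:, :d] = np.eye(d)
    S = H @ cov @ H.T + standard_deviation**2 * np.eye(d)
    diff = u - H @ mean
    _, logdet = np.linalg.slogdet(S)
    return -0.5 * (diff @ np.linalg.solve(S, diff) + logdet + d * np.log(2 * np.pi))


# Usage: ODE with shape (1,), state (x, x'), unit prior covariance.
mean, cov = np.zeros(2), np.eye(2)
value = terminal_logml(np.zeros(1), 1.0, mean, cov)  # S = 1 + 1 = 2
```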

markov_marginals(markov_seq: MarkovSeq, *, reverse)

Extract the (time-)marginals from a Markov sequence.

Source code in probdiffeq/stats.py
def markov_marginals(markov_seq: MarkovSeq, *, reverse):
    """Extract the (time-)marginals from a Markov sequence."""
    _assert_filtering_solution_removed(markov_seq)

    def step(x, cond):
        extrapolated = impl.conditional.marginalise(x, cond)
        return extrapolated, extrapolated

    init, xs = markov_seq.init, markov_seq.conditional
    _, marg = control_flow.scan(step, init=init, xs=xs, reverse=reverse)
    return marg
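
`markov_marginals` scans over the conditionals and marginalises each one against the running marginal. The NumPy sketch below reproduces that pattern under assumptions: Gaussians are `(mean, cov)` pairs, each conditional is a hypothetical `(A, b, Q)` triple meaning x_k | x_{k+1} ~ N(A x_{k+1} + b, Q), and `scan` is a small stand-in for `jax.lax.scan` that, like the JAX version with `reverse=True`, iterates backwards but returns outputs in the input order.

```python
import numpy as np


def scan(f, init, xs, reverse=False):
    """Minimal stand-in for jax.lax.scan over a Python list of inputs."""
    carry, ys = init, [None] * len(xs)
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    for i in order:
        carry, ys[i] = f(carry, xs[i])  # outputs keep the input order
    return carry, ys


def marginalise(rv, cond):
    """p(x_k) from p(x_{k+1}) = N(m, P) and x_k | x_{k+1} ~ N(A x + b, Q)."""
    (m, P), (A, b, Q) = rv, cond
    return A @ m + b, A @ P @ A.T + Q


def markov_marginals(init, conditionals, *, reverse):
    def step(x, cond):
        extrapolated = marginalise(x, cond)
        return extrapolated, extrapolated

    _, marg = scan(step, init=init, xs=conditionals, reverse=reverse)
    return marg


# Usage: a scalar backward random walk with three transitions.
one = np.eye(1)
init = (np.zeros(1), one)  # terminal marginal N(0, 1)
conditionals = [(one, np.zeros(1), 0.5 * one)] * 3
marginals = markov_marginals(init, conditionals, reverse=True)
variances = [float(P[0, 0]) for _, P in marginals]  # [2.5, 2.0, 1.5]
```

The result contains one marginal per conditional; the initial (terminal) marginal itself is not repeated in the output.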

markov_sample(key, markov_seq: MarkovSeq, *, shape, reverse)

Sample from a Markov sequence.

Source code in probdiffeq/stats.py
def markov_sample(key, markov_seq: MarkovSeq, *, shape, reverse):
    """Sample from a Markov sequence."""
    _assert_filtering_solution_removed(markov_seq)
    # A smoother samples on the grid by sampling i.i.d. values
    # from the terminal RV x_N and the backward noises z_(1:N)
    # and then combining them backwards as
    # x_(n-1) = l_n @ x_n + z_n, for n=1,...,N.
    markov_seq_shape = _sample_shape(markov_seq)
    base_samples = random.normal(key, shape=shape + markov_seq_shape)
    return _transform_unit_sample(markov_seq, base_samples, reverse=reverse)
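
The comment in `markov_sample` describes sampling by transforming i.i.d. standard-normal base samples: one for the terminal variable and one per backward transition, combined as x_(n-1) = l_n @ x_n + z_n. The NumPy sketch below implements that idea under assumptions: the terminal marginal is a `(mean, cov)` pair, each conditional is a hypothetical `(A, b, Q)` triple meaning x_k | x_{k+1} ~ N(A x_{k+1} + b, Q), and the noise enters through the Cholesky factor of Q. It covers only the backward direction and is not probdiffeq's `_transform_unit_sample`.

```python
import numpy as np


def markov_sample(rng, init, conditionals, *, shape):
    """Hypothetical sketch: sample (x_0, ..., x_N) from a backward Markov sequence.

    Assumes init = (m, P) is the marginal of x_N and
    conditionals[k] = (A, b, Q) encodes x_k | x_{k+1} ~ N(A x_{k+1} + b, Q).
    """
    m, P = init
    d, N = m.size, len(conditionals)
    # i.i.d. standard-normal base samples: one for x_N and one per transition
    z = rng.standard_normal(size=shape + (N + 1, d))
    x = m + z[..., N, :] @ np.linalg.cholesky(P).T
    samples = [None] * (N + 1)
    samples[N] = x
    for k in range(N - 1, -1, -1):
        A, b, Q = conditionals[k]
        # Combine backwards: x_k = A x_{k+1} + b + chol(Q) z_k
        x = x @ A.T + b + z[..., k, :] @ np.linalg.cholesky(Q).T
        samples[k] = x
    return np.stack(samples, axis=-2)  # shape + (N + 1, d)


# Usage: a scalar backward random walk with three transitions, 100k samples.
one = np.eye(1)
init = (np.zeros(1), one)
conditionals = [(one, np.zeros(1), 0.5 * one)] * 3
samples = markov_sample(np.random.default_rng(1), init, conditionals, shape=(100_000,))
```

Drawing all base samples up front, as `markov_sample` does with `random.normal`, separates randomness from the deterministic transformation; here the empirical variance of x_0 should be close to 1 + 3 * 0.5 = 2.5.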

markov_select_terminal(markov_seq: MarkovSeq) -> MarkovSeq

Discard all intermediate filtering solutions from a Markov sequence.

This function is useful to convert a smoothing solution into a Markov sequence that is compatible with sampling or marginalisation.

Source code in probdiffeq/stats.py
def markov_select_terminal(markov_seq: MarkovSeq) -> MarkovSeq:
    """Discard all intermediate filtering solutions from a Markov sequence.

    This function is useful to convert a smoothing solution into a Markov sequence
    that is compatible with sampling or marginalisation.
    """
    init = tree_util.tree_map(lambda x: x[-1, ...], markov_seq.init)
    return MarkovSeq(init, markov_seq.conditional)
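
For a flat container of arrays, mapping `lambda x: x[-1, ...]` over `init` keeps only the last entry along the leading (time) axis. The sketch below assumes `init` is a Gaussian with hypothetical fields `mean` and `cov` that carry such a time axis; it iterates over the NamedTuple fields in place of `tree_util.tree_map`, so unlike a pytree map it handles only one level of nesting.

```python
from typing import Any, NamedTuple

import numpy as np


class Gaussian(NamedTuple):
    mean: np.ndarray
    cov: np.ndarray


class MarkovSeq(NamedTuple):
    init: Any
    conditional: Any


def select_terminal(markov_seq):
    """Keep only the last time-entry of every array in `init`."""
    init = type(markov_seq.init)(*(leaf[-1, ...] for leaf in markov_seq.init))
    return MarkovSeq(init, markov_seq.conditional)


# Usage: filtering marginals on 4 grid points, state dimension 2.
means = np.arange(8.0).reshape(4, 2)
covs = np.stack([np.eye(2)] * 4)
seq = MarkovSeq(Gaussian(means, covs), conditional=None)  # conditionals untouched
terminal = select_terminal(seq)
```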

offgrid_marginals_searchsorted(*, ts, solution, solver)

Compute off-grid marginals on a dense grid via jax.numpy.searchsorted.

Warning

The elements in ts and the elements in the solution grid must be disjoint. Otherwise, anything can happen and the solution will be incorrect. At the moment, we do not check this.

Warning

The elements in ts must be strictly in (t0, t1). They must not lie outside the interval, and they must not coincide with the interval boundaries. At the moment, we do not check this.

Source code in probdiffeq/stats.py
def offgrid_marginals_searchsorted(*, ts, solution, solver):
    """Compute off-grid marginals on a dense grid via jax.numpy.searchsorted.

    !!! warning
        The elements in ts and the elements in the solution grid must be disjoint.
        Otherwise, anything can happen and the solution will be incorrect.
        At the moment, we do not check this.

    !!! warning
        The elements in ts must be strictly in (t0, t1).
        They must not lie outside the interval, and they must not coincide
        with the interval boundaries.
        At the moment, we do not check this.
    """
    offgrid_marginals_vmap = functools.vmap(_offgrid_marginals, in_axes=(0, None, None))
    return offgrid_marginals_vmap(ts, solution, solver)
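
The two warnings concern how `searchsorted` locates the grid interval around each requested time. The NumPy sketch below shows only that lookup step, not the marginal computation in `_offgrid_marginals` (which is not shown here); `locate_intervals` is a hypothetical helper that performs the two checks the warnings say are currently skipped. Without them, a time equal to t0 yields index 0, and `grid[idx - 1]` silently wraps around to the last grid point.

```python
import numpy as np


def locate_intervals(ts, grid):
    """Hypothetical sketch: neighbouring grid points for each off-grid time."""
    ts, grid = np.asarray(ts), np.asarray(grid)
    # The two preconditions from the warnings, checked explicitly here
    if np.any(np.isin(ts, grid)):
        raise ValueError("ts must be disjoint from the solution grid.")
    if np.any((ts <= grid[0]) | (ts >= grid[-1])):
        raise ValueError("ts must lie strictly inside (t0, t1).")
    idx = np.searchsorted(grid, ts)  # grid[idx - 1] < t < grid[idx]
    return grid[idx - 1], grid[idx]


grid = np.array([0.0, 0.5, 1.0, 2.0])
left, right = locate_intervals([0.25, 1.5], grid)  # [0.0, 1.0], [0.5, 2.0]
```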