
solution

Interact with estimated solutions (on dense grids).

For example, this module contains functionality to compute off-grid marginals, or to evaluate marginal likelihoods of observations of the solutions.

calibrate(x, /, output_scale)

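Calibrate a posterior distribution of an IVP solution by rescaling its Cholesky factor(s) with output_scale. If output_scale has more dimensions than the output-scale prototype (for example, a time series of scales), only its last entry is used.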

Source code in probdiffeq/solvers/solution.py
def calibrate(x, /, output_scale):
    """Calibrate a posterior distribution of an IVP solution."""
    if np.ndim(output_scale) > np.ndim(impl.prototypes.output_scale()):
        output_scale = output_scale[-1]
    if isinstance(x, markov.MarkovSeq):
        return markov.rescale_cholesky(x, output_scale)
    return impl.variable.rescale_cholesky(x, output_scale)
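The effect of calibration on a single Gaussian can be illustrated with a minimal NumPy sketch. It assumes a Gaussian represented by a mean and a lower-triangular Cholesky factor of its covariance, and a scalar output scale, so that any array-valued output_scale is treated as a time series whose last entry is used (mirroring the ndim check above). The helper `calibrate_gaussian` is hypothetical and does not use probdiffeq's random-variable types or its `rescale_cholesky` implementations.

```python
import numpy as np


def calibrate_gaussian(mean, cholesky, output_scale):
    """Rescale the Cholesky factor of N(mean, L L^T) by an output scale.

    Hypothetical helper for illustration only; a scalar output scale is assumed.
    """
    output_scale = np.asarray(output_scale)
    # Mirror `calibrate`: a time series of scales keeps only the last entry.
    if output_scale.ndim > 0:
        output_scale = output_scale[-1]
    # Scaling L by s scales the covariance L L^T by s**2; the mean is unchanged.
    return mean, output_scale * cholesky


mean = np.zeros(2)
cholesky = np.array([[1.0, 0.0], [0.5, 2.0]])
_, calibrated = calibrate_gaussian(mean, cholesky, output_scale=np.array([1.0, 3.0]))
cov_before = cholesky @ cholesky.T
cov_after = calibrated @ calibrated.T
print(np.allclose(cov_after, 9.0 * cov_before))  # True
```

Rescaling the Cholesky factor rather than the covariance keeps the factor triangular and avoids forming and re-factorizing the covariance matrix.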

log_marginal_likelihood(u, /, *, standard_deviation, posterior)

Compute the log-marginal-likelihood of observations of the IVP solution.

Parameters:

| Name | Description | Default |
|---|---|---|
| standard_deviation | Standard deviation of the observation. Expected to have shape (n,). | required |
| u | Observation. Expected to have shape (n, d) for an ODE with shape (d,). | required |
| posterior | Posterior distribution. Expected to correspond to a solution of an ODE with shape (d,). Must be a smoothing posterior (a markov.MarkovSeq); filtering solutions raise a TypeError. | required |

Note

Use log_marginal_likelihood_terminal_values to compute the log-likelihood at the terminal values.
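The first two input checks of `log_marginal_likelihood` can be mirrored in plain NumPy to make the expected shapes concrete. In this sketch, the hypothetical `check_time_series_shapes` stands in for the library's validation, and its `qoi_ndim` argument stands in for np.ndim(impl.prototypes.qoi()); qoi_ndim=1 (an ODE with shape (d,)) is an assumption.

```python
import numpy as np


def check_time_series_shapes(u, standard_deviation, qoi_ndim=1):
    """Mirror the shape validation of `log_marginal_likelihood` (illustration only).

    `qoi_ndim` is a hypothetical stand-in for np.ndim(impl.prototypes.qoi()).
    """
    # One standard deviation per observation: shape (n,) for u of shape (n, d).
    if np.shape(standard_deviation) != np.shape(u)[:1]:
        raise ValueError("standard_deviation must have shape (n,) for u of shape (n, d).")
    # u must carry a leading time axis on top of the QOI dimensions.
    if np.ndim(u) < qoi_ndim + 1:
        raise ValueError("Time-series observations of shape (n, d) expected.")


n, d = 5, 3
check_time_series_shapes(np.zeros((n, d)), np.full((n,), 1e-2))  # passes silently
```

Passing a single observation of shape (d,) together with a standard deviation of shape (d,) passes the first check but fails the second; that case belongs to `log_marginal_likelihood_terminal_values`.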

Source code in probdiffeq/solvers/solution.py
def log_marginal_likelihood(u, /, *, standard_deviation, posterior):
    """Compute the log-marginal-likelihood of observations of the IVP solution.

    Parameters
    ----------
    standard_deviation
        Standard deviation of the observation. Expected to have shape (n,).
    u
        Observation. Expected to have shape (n, d) for an ODE with shape (d,).
    posterior
        Posterior distribution.
        Expected to correspond to a solution of an ODE with shape (d,).

    !!! note
        Use `log_marginal_likelihood_terminal_values`
        to compute the log-likelihood at the terminal values.

    """
    # TODO: complain if it is used with a filter, not a smoother?
    # TODO: allow option for log-posterior

    if np.shape(standard_deviation) != np.shape(u)[:1]:
        msg = (
            f"Observation-noise shape {np.shape(standard_deviation)} "
            f"does not match the observation shape {np.shape(u)}. "
            f"Expected observation-noise shape: {np.shape(u)[:1]}."
        )
        raise ValueError(msg)

    if np.ndim(u) < np.ndim(impl.prototypes.qoi()) + 1:
        msg = (
            f"Time-series solution (ndim=2, shape=(n, d)) expected. "
            f"ndim={np.ndim(u)}, shape={np.shape(u)} received."
        )
        raise ValueError(msg)

    if not isinstance(posterior, markov.MarkovSeq):
        msg1 = "Time-series marginal likelihoods "
        msg2 = "cannot be computed with a filtering solution."
        raise TypeError(msg1 + msg2)

    # Generate an observation-model for the QOI
    model_fun = functools.vmap(
        impl.hidden_model.conditional_to_derivative, in_axes=(None, 0)
    )
    models = model_fun(0, standard_deviation)

    # Select the terminal variable
    rv = tree_util.tree_map(lambda s: s[-1, ...], posterior.init)

    # Run the reverse Kalman filter
    estimator = discrete.kalmanfilter_with_marginal_likelihood()
    (_corrected, _num_data, logpdf), _ = discrete.estimate_rev(
        u,
        init=rv,
        prior_transitions=posterior.conditional,
        observation_model=models,
        estimator=estimator,
    )

    # Return only the logpdf
    return logpdf
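The principle behind this function, that a Kalman filter yields the marginal likelihood as a sum of one-step predictive log-densities, can be checked on a toy problem. The NumPy sketch below is a forward filter on a hypothetical scalar linear-Gaussian model, not probdiffeq's reverse filter over the smoother's backward conditionals (`posterior.conditional`); the per-observation standard deviations have shape (n,), matching the `standard_deviation` argument. It compares the filter's result against the log-density of the joint Gaussian of all observations.

```python
import numpy as np


def gaussian_logpdf(y, mean, var):
    """Log-density of a scalar Gaussian N(mean, var) at y."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (y - mean) ** 2 / var)


def kalman_log_marginal_likelihood(ys, stds, m0, p0, a, q):
    """Scalar Kalman filter accumulating log p(y_1, ..., y_n).

    Hypothetical model: x_1 ~ N(m0, p0), x_k = a * x_{k-1} + N(0, q),
    y_k = x_k + N(0, stds[k] ** 2).
    """
    m, p = m0, p0
    logpdf = 0.0
    for k, (y, std) in enumerate(zip(ys, stds)):
        if k > 0:
            # Predict: push the filtering distribution through the transition.
            m, p = a * m, a * a * p + q
        # The one-step predictive distribution of y_k is N(m, p + std**2).
        s = p + std**2
        logpdf += gaussian_logpdf(y, m, s)
        # Update: condition the state on y_k.
        gain = p / s
        m, p = m + gain * (y - m), (1.0 - gain) * p
    return logpdf


def direct_log_marginal_likelihood(ys, stds, m0, p0, a, q):
    """The same quantity via the joint Gaussian density of (y_1, ..., y_n)."""
    n = len(ys)
    means = m0 * a ** np.arange(n)
    variances = np.empty(n)
    variances[0] = p0
    for k in range(1, n):
        variances[k] = a * a * variances[k - 1] + q
    # Cov(x_i, x_j) = a**(j - i) * Var(x_i) for i <= j.
    cov = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            lo, hi = min(i, j), max(i, j)
            cov[i, j] = a ** (hi - lo) * variances[lo]
    cov += np.diag(np.asarray(stds) ** 2)
    resid = np.asarray(ys) - means
    _, logdet = np.linalg.slogdet(cov)
    quad = resid @ np.linalg.solve(cov, resid)
    return -0.5 * (n * np.log(2.0 * np.pi) + logdet + quad)


ys = np.array([0.3, -0.1, 0.8, 1.1])
stds = np.array([0.5, 0.2, 0.4, 0.3])  # shape (n,): one standard deviation per observation
args = dict(m0=0.0, p0=1.0, a=0.9, q=0.1)
print(np.isclose(kalman_log_marginal_likelihood(ys, stds, **args),
                 direct_log_marginal_likelihood(ys, stds, **args)))  # True
```

The filter needs only scalar operations per step, whereas the direct evaluation factorizes an n-by-n covariance; this is why a sequential (filter-based) evaluation is preferred for long time series.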

log_marginal_likelihood_terminal_values(u, /, *, standard_deviation, posterior)

Compute the log-marginal-likelihood at the terminal value.

Parameters:

| Name | Description | Default |
|---|---|---|
| u | Observation. Expected to have shape (d,) for an ODE with shape (d,). | required |
| standard_deviation | Standard deviation of the observation. Expected to be a scalar. | required |
| posterior | Posterior distribution. Expected to correspond to a solution of an ODE with shape (d,). Either a markov.MarkovSeq (whose init is used) or the terminal random variable itself. | required |

Source code in probdiffeq/solvers/solution.py
def log_marginal_likelihood_terminal_values(u, /, *, standard_deviation, posterior):
    """Compute the log-marginal-likelihood at the terminal value.

    Parameters
    ----------
    u
        Observation. Expected to have shape (d,) for an ODE with shape (d,).
    standard_deviation
        Standard deviation of the observation. Expected to be a scalar.
    posterior
        Posterior distribution.
        Expected to correspond to a solution of an ODE with shape (d,).
    """
    if np.shape(standard_deviation) != ():
        msg = (
            f"Scalar observation noise expected. "
            f"Shape {np.shape(standard_deviation)} received."
        )
        raise ValueError(msg)

    # not valid for scalar or matrix-valued solutions
    if np.ndim(u) > np.ndim(impl.prototypes.qoi()):
        msg = (
            f"Terminal-value solution (ndim=1, shape=(n,)) expected. "
            f"ndim={np.ndim(u)}, shape={np.shape(u)} received."
        )
        raise ValueError(msg)

    # Generate an observation-model for the QOI
    model = impl.hidden_model.conditional_to_derivative(0, standard_deviation)
    rv = posterior.init if isinstance(posterior, markov.MarkovSeq) else posterior

    _corrected, logpdf = _condition_and_logpdf(rv, u, model)
    return logpdf
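For a single terminal observation, the marginal likelihood reduces to one Gaussian density: u is compared with the zeroth derivative of the state, with covariance H P H^T + sigma^2 I. The NumPy sketch below evaluates this density through a Cholesky factorization. The dense-covariance representation, the state layout (solution values first, then higher derivatives) and the helper `terminal_log_marginal_likelihood` are assumptions for illustration; probdiffeq builds the observation model with `impl.hidden_model.conditional_to_derivative(0, standard_deviation)` and its own random-variable types.

```python
import numpy as np


def terminal_log_marginal_likelihood(u, standard_deviation, mean, cov, d):
    """log N(u; H mean, H cov H^T + std**2 I), observing the zeroth derivative.

    Hypothetical state layout: the first d entries hold the solution value,
    the remaining entries hold higher derivatives.
    """
    if np.shape(standard_deviation) != ():
        raise ValueError("Scalar observation noise expected.")
    n = mean.shape[0]
    H = np.hstack([np.eye(d), np.zeros((d, n - d))])  # selects the zeroth derivative
    s = H @ cov @ H.T + standard_deviation**2 * np.eye(d)
    resid = u - H @ mean
    chol = np.linalg.cholesky(s)  # s = chol @ chol.T
    whitened = np.linalg.solve(chol, resid)  # chol^{-1} resid
    logdet = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + whitened @ whitened)


d = 2
mean = np.array([1.0, -0.5, 0.3, 0.2])  # [x_1, x_2, x_1', x_2']
cov = np.diag([0.4, 0.9, 1.0, 1.0])
u = np.array([1.2, -0.1])  # shape (d,)
std = 0.1
value = terminal_log_marginal_likelihood(u, std, mean, cov, d)
# With a diagonal covariance, the result factorizes into scalar Gaussians.
var = np.array([0.4, 0.9]) + std**2
expected = np.sum(-0.5 * (np.log(2 * np.pi * var) + (u - mean[:d]) ** 2 / var))
print(np.isclose(value, expected))  # True
```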

offgrid_marginals_searchsorted(*, ts, solution, solver)

Compute off-grid marginals on a dense grid via jax.numpy.searchsorted.

Warning

The elements in ts and the elements in the solution grid must be disjoint. Otherwise, the behavior is undefined and the returned marginals may be incorrect. This is currently not checked.

Warning

The elements in ts must lie strictly inside (t0, t1): they must neither lie outside the interval nor coincide with its boundaries. This is currently not checked.

Source code in probdiffeq/solvers/solution.py
def offgrid_marginals_searchsorted(*, ts, solution, solver):
    """Compute off-grid marginals on a dense grid via jax.numpy.searchsorted.

    !!! warning
        The elements in ts and the elements in the solution grid must be disjoint.
        Otherwise, the behavior is undefined and the result may be incorrect.
        This is currently not checked.

    !!! warning
        The elements in ts must lie strictly inside (t0, t1).
        They must neither lie outside the interval nor coincide
        with its boundaries.
        This is currently not checked.
    """
    offgrid_marginals_vmap = functools.vmap(_offgrid_marginals, in_axes=(0, None, None))
    return offgrid_marginals_vmap(ts, solution, solver)
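The warnings above can be motivated with a small NumPy sketch of a searchsorted-based lookup, which locates for each t in ts the pair of neighbouring grid points it falls between. It assumes the default side="left" and that index i denotes the interval (grid[i - 1], grid[i]); the exact indexing inside `_offgrid_marginals` is not shown on this page, so this illustrates the lookup only, not probdiffeq's implementation. The `functools.vmap(..., in_axes=(0, None, None))` call above then maps the per-point computation over ts while sharing solution and solver.

```python
import numpy as np

grid = np.array([0.0, 0.5, 1.0, 2.0])  # solution grid with t0 = 0.0, t1 = 2.0
ts = np.array([0.25, 0.75, 1.5])  # strictly inside (t0, t1), disjoint from the grid

# side="left": index of the first grid point >= t, so t lies in (grid[i - 1], grid[i]).
indices = np.searchsorted(grid, ts)
left, right = grid[indices - 1], grid[indices]
print(indices)  # [1 2 3]

# Points that violate the warnings break this assumption:
print(np.searchsorted(grid, 0.0))  # 0: grid[-1] wraps around to t1, not a left neighbour
print(np.searchsorted(grid, 1.0))  # 2: t coincides with grid[2] instead of lying inside
print(np.searchsorted(grid, 3.0))  # 4: grid[4] is out of bounds
```

In NumPy, indexing grid[4] raises an IndexError. JAX instead clamps out-of-bounds indices in gather operations, so under `jax.numpy` such a case would not raise and could silently produce a wrong result.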