Skip to content

taylor

Taylor-expand the solution of an initial value problem (IVP).

odejet_affine(vf: Callable, inits: Sequence[Array], /, num: int) ¤

Evaluate the Taylor series of an affine differential equation.

Compilation time

JIT-compiling this function unrolls a loop of length num.

Source code in probdiffeq/taylor.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
def odejet_affine(vf: Callable, inits: Sequence[Array], /, num: int):
    """Evaluate the Taylor series of an affine differential equation.

    The vector field is linearised once at `inits`. The resulting
    Jacobian-vector product is then reused for all higher derivatives.
    This is exact only if `vf` is affine in its positional arguments.

    `inits` holds one initial value per positional argument of `vf`.
    A single entry means a first-order ODE. For an n-th order ODE
    `u^(n) = vf(u, u', ..., u^(n-1))`, pass `n` entries.

    The result is a list of `len(inits) + num` derivatives
    `u(t0), u'(t0), ...`.
    For `num == 0`, it is a list copy of `inits`.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop of length `num`.

    """
    if num == 0:
        # Always return a fresh list, like every other code path does.
        return list(inits)

    # Pytree-valued initial values: flatten them, recurse on the flat
    # arrays, and unflatten the resulting coefficients.
    if not isinstance(inits[0], ArrayLike):
        _, unravel = tree_util.ravel_pytree(inits[0])
        inits_flat = [tree_util.ravel_pytree(m)[0] for m in inits]

        def vf_wrapped(*ys, **kwargs):
            ys = tree_util.tree_map(unravel, ys)
            return tree_util.ravel_pytree(vf(*ys, **kwargs))[0]

        tcoeffs = odejet_affine(vf_wrapped, inits_flat, num=num)
        return tree_util.tree_map(unravel, tcoeffs)

    # For an affine vf, the linearisation is exact. Differentiating
    # u^(n) = vf(u, ..., u^(n-1)) once more removes the constant term:
    #   u^(k+n) = J (u^(k), ..., u^(k+n-1)),   k >= 1,
    # so every new derivative is one JVP applied to the previous
    # `len(inits)` derivatives. (For first-order ODEs: u^(k+1) = J u^(k).)
    fx, jvp_fn = functools.linearize(vf, *inits)

    num_arguments = len(inits)
    taylor_coeffs = [*inits, fx]
    for _ in range(num - 1):
        taylor_coeffs.append(jvp_fn(*taylor_coeffs[-num_arguments:]))
    return taylor_coeffs

odejet_doubling_unroll(vf: Callable, inits: Sequence[Array], /, num_doublings: int) ¤

Combine Taylor-mode differentiation and Newton's doubling.

Warning: highly EXPERIMENTAL feature!

Support for Newton's doubling is highly experimental. There is no guarantee that it works correctly. It may be removed at any time, without following any deprecation policy.

Compilation time

JIT-compiling this function unrolls a loop.

Source code in probdiffeq/taylor.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
def odejet_doubling_unroll(vf: Callable, inits: Sequence[Array], /, num_doublings: int):
    """Combine Taylor-mode differentiation and Newton's doubling.

    Instead of computing one Taylor coefficient per `jet()` call, each
    doubling step linearises an "embedded" jet around the coefficients
    known so far. It then computes a whole block of new coefficients
    from that single linearisation. Each step roughly doubles the number
    of known coefficients.

    Only first-order ODEs are supported. `inits` must contain exactly
    one initial value; the unpacking `(u0,) = inits` fails otherwise.

    Internally, the recursion runs on normalised Taylor coefficients
    (see `jet_embedded`). The result is unnormalised before it is
    returned. Assuming `_normalise`/`_unnormalise` preserve the number of
    coefficients (TODO confirm), the result has
    `2**(num_doublings + 1) - 1` entries.

    !!! warning "Warning: highly EXPERIMENTAL feature!"
        Support for Newton's doubling is highly experimental.
        There is no guarantee that it works correctly.
        It may be removed at any time,
        without following any deprecation policy.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop.

    """
    # Pytree-valued initial values: flatten them, recurse on the flat
    # arrays, and unflatten the resulting coefficients.
    if not isinstance(inits[0], ArrayLike):
        _, unravel = tree_util.ravel_pytree(inits[0])
        inits_flat = [tree_util.ravel_pytree(m)[0] for m in inits]

        def vf_wrapped(*ys, **kwargs):
            ys = tree_util.tree_map(unravel, ys)
            return tree_util.ravel_pytree(vf(*ys, **kwargs))[0]

        tcoeffs = odejet_doubling_unroll(
            vf_wrapped, inits_flat, num_doublings=num_doublings
        )
        return tree_util.tree_map(unravel, tcoeffs)

    # First-order ODEs only: exactly one initial value is expected.
    (u0,) = inits
    zeros = np.zeros_like(u0)

    def jet_embedded(*c, degree):
        """Call a modified jet().

        The modifications include:
        * We merge "primals" and "series" into a single set of coefficients
        * We expect and return _normalised_ Taylor coefficients.

        The reason for the latter is that the doubling-recursion
        simplifies drastically for normalised coefficients
        (compared to unnormalised coefficients).
        """
        # Append `degree` zero coefficients so that the output of jet()
        # is long enough to hold the coefficients of the next block.
        coeffs_emb = [*c] + [zeros] * degree
        p, *s = _unnormalise(*coeffs_emb)
        p_new, s_new = functools.jet(vf, (p,), (s,))
        return _normalise(p_new, *s_new)

    taylor_coefficients = [u0]
    # Degrees 1, 3, 7, 15, ... (partial sums of powers of two).
    # At the start of every iteration, `deg` equals the number of
    # coefficients known so far (1, then 3, then 7, ...).
    degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings))))
    for deg in degrees:
        jet_embedded_deg = tree_util.Partial(jet_embedded, degree=deg)
        # `fx` holds the normalised output coefficients of the embedded jet.
        # `jvp` applies its Jacobian with respect to the known coefficients.
        # That Jacobian takes one tangent per entry of `taylor_coefficients`.
        fx, jvp = functools.linearize(jet_embedded_deg, *taylor_coefficients)

        # Compute the next set of coefficients.
        # TODO: can we fori_loop() this loop?
        #  the running variable (cs_padded) should have constant size
        # With normalised coefficients, c_{k+1} = fx_k / (k + 1).
        # The first new coefficient only depends on known coefficients.
        cs = [(fx[deg - 1] / deg)]
        # Zero-pad to `deg` entries so that `jvp` receives one tangent per input.
        cs_padded = cs + [zeros] * (deg - 1)
        for i, fx_i in enumerate(fx[deg : 2 * deg]):
            # The Jacobian of the embedded jet is block-banded,
            # i.e., of the form (for j=3)
            # (A0, 0, 0; A1, A0, 0; A2, A1, A0; *, *, *; *, *, *; *, *, *)
            # Thus, by attaching zeros to the current set of coefficients
            # until the input and output shapes match, we compute
            # the convolution-like sum of matrix-vector products with
            # a single call to the JVP function.
            # Bettencourt et al. (2019;
            # "Taylor-mode autodiff for higher-order derivatives in JAX")
            # explain details.
            # i = k - deg
            linear_combination = jvp(*cs_padded)[i]
            # Keep the i + 1 new coefficients computed so far, then append
            # coefficient number deg + i + 1.
            cs_ = cs_padded[: (i + 1)]
            cs_ += [(fx_i + linear_combination) / (i + deg + 1)]
            # Re-pad to `deg` entries for the next JVP call. In the final
            # iteration, the padding length is negative, so no zeros are
            # added and cs_padded holds all deg + 1 new coefficients.
            cs_padded = cs_ + [zeros] * (deg - i - 2)

        # Store all new coefficients
        taylor_coefficients.extend(cs_padded)

    # Convert from normalised back to unnormalised coefficients.
    return _unnormalise(*taylor_coefficients)

odejet_padded_scan(vf: Callable, inits: Sequence[Array], /, num: int) ¤

Taylor-expand the solution of an IVP with Taylor-mode differentiation.

Unlike odejet_unroll(), this function implements the loop via a scan, which comes at the price of padding the loop variable with zeros as appropriate. It is expected to compile more quickly than odejet_unroll() but may execute more slowly.

The differences should be small. Consult the benchmarks if performance is critical.

Source code in probdiffeq/taylor.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def odejet_padded_scan(vf: Callable, inits: Sequence[Array], /, num: int):
    """Taylor-expand the solution of an IVP with Taylor-mode differentiation.

    Unlike `odejet_unroll()`, this function implements the loop via a scan,
    which comes at the price of padding the loop variable with zeros as appropriate.
    It is expected to compile more quickly than `odejet_unroll()`, but may
    execute more slowly.

    The differences should be small.
    Consult the benchmarks if performance is critical.

    `inits` holds one initial value per positional argument of `vf`.
    More than one entry means a higher-order ODE.
    For `num == 0`, only the initial values are returned, as a list.
    Otherwise, the coefficient list is zero-padded to
    `len(inits) + num` entries before the scan fills it in.
    """
    # Nothing to compute. Without this early exit, the scan below would
    # be asked to run for `num - 1 = -1` steps.
    if num == 0:
        return list(inits)

    # Pytree-valued initial values: flatten them, recurse on the flat
    # arrays, and unflatten the resulting coefficients.
    if not isinstance(inits[0], ArrayLike):
        _, unravel = tree_util.ravel_pytree(inits[0])
        inits_flat = [tree_util.ravel_pytree(m)[0] for m in inits]

        def vf_wrapped(*ys, **kwargs):
            ys = tree_util.tree_map(unravel, ys)
            return tree_util.ravel_pytree(vf(*ys, **kwargs))[0]

        tcoeffs = odejet_padded_scan(vf_wrapped, inits_flat, num=num)
        return tree_util.tree_map(unravel, tcoeffs)

    # Number of positional arguments in f
    num_arguments = len(inits)

    # Initial Taylor series (u_0, u_1, ..., u_k)
    primals = vf(*inits)
    taylor_coeffs = [*inits, primals]

    def body(tcoeffs, _):
        # Pad the Taylor coefficients with zeros, call jet, and return the solution.
        # This works, because the $i$th output coefficient of jet()
        # is independent of the $i+j$th input coefficient
        # (see also the explanation in odejet_doubling_unroll)
        series = _subsets(tcoeffs[1:], num_arguments)  # for high-order ODEs
        p, s_new = functools.jet(vf, primals=inits, series=series)

        # The final values in s_new are nonsensical
        # (well, they are not; but we don't care about them)
        # so we remove them. This keeps the scan carry at a constant length.
        tcoeffs = [*inits, p, *s_new[:-1]]
        return tcoeffs, None

    # Pad the initial Taylor series with zeros
    num_outputs = num_arguments + num
    zeros = np.zeros_like(primals)
    taylor_coeffs = _pad_to_length(taylor_coeffs, length=num_outputs, value=zeros)

    # Early exit for num=1.
    #  Why? because zero-length scan and disable_jit() don't work together.
    if num == 1:
        return taylor_coeffs

    # Compute all coefficients with scan().
    taylor_coeffs, _ = control_flow.scan(
        body, init=taylor_coeffs, xs=None, length=num - 1
    )
    return taylor_coeffs

odejet_unroll(vf: Callable, inits: Sequence[Array], /, num: int) ¤

Taylor-expand the solution of an IVP with Taylor-mode differentiation.

Unlike odejet_padded_scan(), this function does not zero-pad the coefficients; the price is unrolling a loop of length num-1. It is expected to compile more slowly than odejet_padded_scan() but to execute more quickly.

The differences should be small. Consult the benchmarks if performance is critical.

Compilation time

JIT-compiling this function unrolls a loop.

Source code in probdiffeq/taylor.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def odejet_unroll(vf: Callable, inits: Sequence[Array], /, num: int):
    """Taylor-expand the solution of an IVP with Taylor-mode differentiation.

    Unlike `odejet_padded_scan()`, this function does not zero-pad
    the coefficients; the price is unrolling a loop of length `num-1`.
    It is expected to compile more slowly than `odejet_padded_scan()`,
    but to execute more quickly.

    The differences should be small.
    Consult the benchmarks if performance is critical.

    `inits` holds one initial value per positional argument of `vf`.
    More than one entry means a higher-order ODE.
    For `num == 0`, only the initial values are returned, as a list.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop.

    """
    # Nothing to compute: return only the initial values. Without this
    # exit, `vf(*inits)` would still be appended below.
    if num == 0:
        return list(inits)

    # Pytree-valued initial values: flatten them, recurse on the flat
    # arrays, and unflatten the resulting coefficients.
    if not isinstance(inits[0], ArrayLike):
        _, unravel = tree_util.ravel_pytree(inits[0])
        inits_flat = [tree_util.ravel_pytree(m)[0] for m in inits]

        def vf_wrapped(*ys, **kwargs):
            ys = tree_util.tree_map(unravel, ys)
            return tree_util.ravel_pytree(vf(*ys, **kwargs))[0]

        tcoeffs = odejet_unroll(vf_wrapped, inits_flat, num=num)
        return tree_util.tree_map(unravel, tcoeffs)

    # Number of positional arguments in f
    num_arguments = len(inits)

    # Initial Taylor series (u_0, u_1, ..., u_k)
    primals = vf(*inits)
    taylor_coeffs = [*inits, primals]

    # Each pass re-runs jet() on all coefficients known so far. The
    # output series is one element longer than the previous one.
    for _ in range(num - 1):
        series = _subsets(taylor_coeffs[1:], num_arguments)  # for high-order ODEs
        p, s_new = functools.jet(vf, primals=inits, series=series)
        taylor_coeffs = [*inits, p, *s_new]
    return taylor_coeffs

odejet_via_jvp(vf: Callable, inits: Sequence[Array], /, num: int) ¤

Taylor-expand the solution of an IVP with recursive forward-mode differentiation.

Compilation time

JIT-compiling this function unrolls a loop.

Source code in probdiffeq/taylor.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def odejet_via_jvp(vf: Callable, inits: Sequence[Array], /, num: int):
    """Taylor-expand the solution of an IVP with recursive forward-mode differentiation.

    `inits` holds one initial value per positional argument of `vf`.
    For `num == 0`, only the initial values are returned, as a list.
    Otherwise, `num` further coefficients are appended, one per
    recursion level.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop.

    """
    # Nothing to compute: return only the initial values. Without this
    # exit, `vf(*inits)` would still be appended below.
    if num == 0:
        return list(inits)

    # Pytree-valued initial values: flatten them, recurse on the flat
    # arrays, and unflatten the resulting coefficients.
    if not isinstance(inits[0], ArrayLike):
        _, unravel = tree_util.ravel_pytree(inits[0])
        inits_flat = [tree_util.ravel_pytree(m)[0] for m in inits]

        def vf_wrapped(*ys, **kwargs):
            ys = tree_util.tree_map(unravel, ys)
            return tree_util.ravel_pytree(vf(*ys, **kwargs))[0]

        tcoeffs = odejet_via_jvp(vf_wrapped, inits_flat, num=num)
        return tree_util.tree_map(unravel, tcoeffs)

    g_0 = vf
    g_n = vf
    taylor_coeffs = [*inits, vf(*inits)]
    for _ in range(num - 1):
        # Presumably builds the function that evaluates the next-higher
        # derivative along the solution (see _fwd_recursion_iterate).
        g_n = _fwd_recursion_iterate(fun_n=g_n, fun_0=g_0)
        taylor_coeffs.append(g_n(*inits))
    return taylor_coeffs

runge_kutta_starter(dt, *, num: int, prior, ssm, atol=1e-12, rtol=1e-10) ¤

Create an estimator that uses a Runge-Kutta starter.

Source code in probdiffeq/taylor.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def runge_kutta_starter(dt, *, num: int, prior, ssm, atol=1e-12, rtol=1e-10):
    """Create an estimator that uses a Runge-Kutta starter.

    The returned function has the signature
    `starter(vf, initial_values, /, t)` and works as follows:

    1. Solve the ODE numerically on the `num + 1` equispaced time points
       `t, t + dt, ..., t + num * dt`, using `ode.odeint_and_save_at`
       with tolerances `atol` and `rtol`.
    2. Discretise `prior`, called as `prior(step_size, output_scale)` and
       vmapped over the step sizes, on those time points.
    3. Condition the prior on the computed ODE values with a
       preconditioned fixed-point smoother.
    4. Return the unravelled mean of the resulting estimate
       (presumably the state at `t`, the fixed point).

    Only first-order ODEs are supported: `starter` raises a `ValueError`
    if `initial_values` has more than one entry.
    For `num == 0`, it returns the initial values. For `num == 1`, it
    returns the initial values followed by `vf(*initial_values, t=t)`.
    Neither case solves the ODE.
    """

    def starter(vf, initial_values: tuple, /, t):
        # TODO: higher-order ODEs
        # TODO: allow flexible "solve" method?

        # Assertions and early exits

        if len(initial_values) > 1:
            msg = "Higher-order ODEs are not supported at the moment."
            raise ValueError(msg)

        if num == 0:
            return [*initial_values]

        if num == 1:
            return [*initial_values, vf(*initial_values, t=t)]

        # Generate data

        k = num + 1  # important: k > num
        ts = np.linspace(t, t + dt * (k - 1), num=k, endpoint=True)
        ys = ode.odeint_and_save_at(
            vf, initial_values, save_at=ts, atol=atol, rtol=rtol
        )

        # The same output scale parametrises the initial condition
        # and the discretised prior, so compute it only once.
        scale = ssm.prototypes.output_scale()

        # Initial condition
        rv_t0 = ssm.normal.standard(num + 1, scale)
        estimator = filter_util.fixedpointsmoother_precon(ssm=ssm)
        conditional_t0 = ssm.conditional.identity(num + 1)
        init = (rv_t0, conditional_t0)

        # Discretised prior
        prior_vmap = functools.vmap(prior, in_axes=(0, None))
        ibm_transitions = prior_vmap(np.diff(ts), scale)

        # Generate an observation-model for the QOI
        # (1e-7 observation noise for nuggets and for reusing existing code)
        model_fun = functools.vmap(ssm.conditional.to_derivative, in_axes=(None, 0, 0))
        std = tree_util.tree_map(lambda s: 1e-7 * np.ones((len(s),)), ys)
        models = model_fun(0, ys, std)

        zeros = np.zeros_like(models.noise.mean)

        # Run the preconditioned fixedpoint smoother
        (corrected, conditional), _ = filter_util.estimate_fwd(
            zeros,
            init=init,
            prior_transitions=ibm_transitions,
            observation_model=models,
            estimator=estimator,
        )
        initial = ssm.conditional.marginalise(corrected, conditional)
        mean = ssm.stats.mean(initial)
        return ssm.unravel(mean)

    return starter