
taylor

Taylor-expand the solution of an initial value problem (IVP).
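
As background for the functions below: for an autonomous first-order IVP, the Taylor coefficients follow from repeatedly differentiating the vector field along the solution. The functions on this page return lists of these (unnormalised) coefficients, i.e., the derivatives u(t0), u'(t0), u''(t0), ...; several initial values, as in vf(*inits), correspond to higher-order ODEs. The first few terms, in standard notation that is not taken from the source (J_f is the Jacobian of f, f'' its second derivative), read:

```latex
\begin{aligned}
u'(t) &= f(u(t)), \qquad u(t_0) = u_0, \\
u(t_0 + h) &\approx \sum_{k=0}^{n} \frac{u^{(k)}(t_0)}{k!}\, h^k, \\
u^{(1)} &= f(u), \\
u^{(2)} &= J_f(u)\, f(u), \\
u^{(3)} &= f''(u)\big[f(u), f(u)\big] + J_f(u)\, J_f(u)\, f(u).
\end{aligned}
```

For an affine field f(u) = A u + b, the term f'' vanishes and the recursion reduces to u^{(k+1)} = A u^{(k)} for k >= 1, which is the structure odejet_affine relies on.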

odejet_affine(vf: Callable, initial_values: tuple[Array, ...], /, num: int)

Evaluate the Taylor series of an affine differential equation.

Compilation time

JIT-compiling this function unrolls a loop of length num.

Source code in probdiffeq/taylor.py
def odejet_affine(vf: Callable, initial_values: tuple[Array, ...], /, num: int):
    """Evaluate the Taylor series of an affine differential equation.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop of length `num`.

    """
    if num == 0:
        return initial_values

    fx, jvp_fn = functools.linearize(vf, *initial_values)

    tmp = fx
    fx_evaluations = [tmp := jvp_fn(tmp) for _ in range(num - 1)]
    return [*initial_values, fx, *fx_evaluations]
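
A minimal standalone sketch of the same idea in plain JAX, assuming a first-order ODE with a single initial value (the original accepts a tuple of initial values). Because the Jacobian of an affine map is constant, one call to jax.linearize suffices, and each further derivative costs one JVP. The matrix A, the vector b, and the helper affine_derivatives are hypothetical choices for illustration:

```python
import jax
import jax.numpy as jnp

# Hypothetical affine vector field u' = A u + b.
A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
b = jnp.array([0.5, 0.0])


def vf(u):
    return A @ u + b


def affine_derivatives(vf, u0, num):
    """Return [u0, u'(0), ..., u^(num)(0)] for an affine vector field."""
    if num == 0:
        return [u0]
    # One linearisation suffices: the Jacobian of an affine map is constant.
    fx, jvp_fn = jax.linearize(vf, u0)
    derivs = [u0, fx]
    for _ in range(num - 1):
        # u^(k+1) = A u^(k) for k >= 1, i.e., one JVP per additional order.
        derivs.append(jvp_fn(derivs[-1]))
    return derivs


u0 = jnp.array([1.0, 0.0])
derivs = affine_derivatives(vf, u0, num=4)
```

The design mirrors odejet_affine: the expensive part (tracing and linearising vf) happens once, and the loop only applies the resulting linear map.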

odejet_doubling_unroll(vf: Callable, inits: tuple[Array, ...], /, num_doublings: int)

Combine Taylor-mode differentiation and Newton's doubling.

Warning: highly EXPERIMENTAL feature!

Support for Newton's doubling is highly experimental. There is no guarantee that it works correctly. It might be deleted tomorrow, without following any deprecation policy.

Compilation time

JIT-compiling this function unrolls a loop.

Source code in probdiffeq/taylor.py
def odejet_doubling_unroll(
    vf: Callable, inits: tuple[Array, ...], /, num_doublings: int
):
    """Combine Taylor-mode differentiation and Newton's doubling.

    !!! warning "Warning: highly EXPERIMENTAL feature!"
        Support for Newton's doubling is highly experimental.
        There is no guarantee that it works correctly.
        It might be deleted tomorrow,
        without following any deprecation policy.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop.

    """
    (u0,) = inits
    zeros = np.zeros_like(u0)

    def jet_embedded(*c, degree):
        """Call a modified jet().

        The modifications include:
        * We merge "primals" and "series" into a single set of coefficients
        * We expect and return _normalised_ Taylor coefficients.

        The reason for the latter is that the doubling-recursion
        simplifies drastically for normalised coefficients
        (compared to unnormalised coefficients).
        """
        coeffs_emb = [*c] + [zeros] * degree
        p, *s = _unnormalise(*coeffs_emb)
        p_new, s_new = functools.jet(vf, (p,), (s,))
        return _normalise(p_new, *s_new)

    taylor_coefficients = [u0]
    degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings))))
    for deg in degrees:
        jet_embedded_deg = tree_util.Partial(jet_embedded, degree=deg)
        fx, jvp = functools.linearize(jet_embedded_deg, *taylor_coefficients)

        # Compute the next set of coefficients.
        # TODO: can we fori_loop() this loop?
        #  the running variable (cs_padded) should have constant size
        cs = [(fx[deg - 1] / deg)]
        cs_padded = cs + [zeros] * (deg - 1)
        for i, fx_i in enumerate(fx[deg : 2 * deg]):
            # The Jacobian of the embedded jet is block-banded,
            # i.e., of the form (for deg=3)
            # (A0, 0, 0; A1, A0, 0; A2, A1, A0; *, *, *; *, *, *; *, *, *)
            # Thus, by attaching zeros to the current set of coefficients
            # until the input and output shapes match, we compute
            # the convolution-like sum of matrix-vector products with
            # a single call to the JVP function.
            # Bettencourt et al. (2019;
            # "Taylor-mode autodiff for higher-order derivatives in JAX")
            # explain details.
            # i = k - deg
            linear_combination = jvp(*cs_padded)[i]
            cs_ = cs_padded[: (i + 1)]
            cs_ += [(fx_i + linear_combination) / (i + deg + 1)]
            cs_padded = cs_ + [zeros] * (deg - i - 2)

        # Store all new coefficients
        taylor_coefficients.extend(cs_padded)

    return _unnormalise(*taylor_coefficients)
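
The loop over degrees can be traced without computing any derivatives. The hypothetical helper below (not part of probdiffeq) mirrors only the list bookkeeping of odejet_doubling_unroll, as read from the source above: at the start of each pass the number of known coefficients equals deg, and each pass appends deg + 1 new ones (the final length of cs_padded). Under that reading, num_doublings passes yield 2**(num_doublings + 1) - 1 coefficients, u0 included:

```python
import itertools


def doubling_schedule(num_doublings: int) -> tuple[list[int], list[int]]:
    """Trace how many Taylor coefficients exist after each doubling pass.

    Hypothetical helper: it mirrors the list bookkeeping of the doubling loop
    (degrees 1, 3, 7, ...) but computes no derivatives.
    """
    degrees = list(itertools.accumulate(2**s for s in range(num_doublings)))
    num_known = 1  # the loop starts from [u0]
    sizes = [num_known]
    for deg in degrees:
        # Invariant: before each pass, exactly `deg` coefficients are known.
        assert num_known == deg
        # Each pass extends the list by deg + 1 coefficients.
        num_known += deg + 1
        sizes.append(num_known)
    return degrees, sizes


degrees, sizes = doubling_schedule(4)
print(degrees)  # [1, 3, 7, 15]
print(sizes)  # [1, 3, 7, 15, 31]
```

This is where the name comes from: each pass roughly doubles the number of known coefficients, whereas the other functions on this page add one coefficient per iteration.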

odejet_padded_scan(vf: Callable, inits: tuple[Array, ...], /, num: int)

Taylor-expand the solution of an IVP with Taylor-mode differentiation.

Unlike odejet_unroll(), this function implements the loop via a scan, which comes at the price of padding the loop variable with zeros as appropriate. It is expected to compile more quickly than odejet_unroll() but may execute more slowly.

The differences should be small. Consult the benchmarks if performance is critical.
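
Zero-padding is safe because Taylor-mode propagation is causal: the k-th output coefficient of jet() depends only on the input coefficients 0, ..., k. A small check with jax.experimental.jet (the function f and all numbers are arbitrary choices for illustration):

```python
import jax.numpy as jnp
from jax.experimental import jet


def f(u):
    return jnp.sin(u) * u  # arbitrary nonlinear test function


x0 = jnp.array(0.3, dtype=jnp.float32)
# Two input series with the same first two terms but different tails.
padded = [jnp.array(v, dtype=jnp.float32) for v in (1.0, 2.0, 0.0, 0.0)]
garbage = [jnp.array(v, dtype=jnp.float32) for v in (1.0, 2.0, 5.0, -7.0)]

_, out_padded = jet.jet(f, (x0,), (padded,))
_, out_garbage = jet.jet(f, (x0,), (garbage,))

# Output terms 0 and 1 only see input terms 0 and 1, so they agree ...
print(jnp.allclose(jnp.stack(out_padded[:2]), jnp.stack(out_garbage[:2])))  # True
# ... while later output terms depend on the tail and differ.
print(jnp.allclose(out_padded[2], out_garbage[2]))  # False
```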

Source code in probdiffeq/taylor.py
def odejet_padded_scan(vf: Callable, inits: tuple[Array, ...], /, num: int):
    """Taylor-expand the solution of an IVP with Taylor-mode differentiation.

    Unlike `odejet_unroll()`, this function implements the loop via a scan,
    which comes at the price of padding the loop variable with zeros as appropriate.
    It is expected to compile more quickly than `odejet_unroll()`, but may
    execute more slowly.

    The differences should be small.
    Consult the benchmarks if performance is critical.
    """
    # Number of positional arguments in f
    num_arguments = len(inits)

    # Initial Taylor series (u_0, u_1, ..., u_k)
    primals = vf(*inits)
    taylor_coeffs = [*inits, primals]

    def body(tcoeffs, _):
        # Call jet() on the zero-padded Taylor coefficients and return the update.
        # This works because the $i$th output coefficient of jet()
        # is independent of the $(i+j)$th input coefficient for all j > 0
        # (see also the explanation in odejet_doubling_unroll)
        series = _subsets(tcoeffs[1:], num_arguments)  # for high-order ODEs
        p, s_new = functools.jet(vf, primals=inits, series=series)

        # The final value in s_new is not needed, so we remove it;
        # this also keeps the size of the scan() carry constant.
        tcoeffs = [*inits, p, *s_new[:-1]]
        return tcoeffs, None

    # Pad the initial Taylor series with zeros
    num_outputs = num_arguments + num
    zeros = np.zeros_like(primals)
    taylor_coeffs = _pad_to_length(taylor_coeffs, length=num_outputs, value=zeros)

    # Early exit for num=1, because a zero-length scan()
    # and disable_jit() do not work together.
    if num == 1:
        return taylor_coeffs

    # Compute all coefficients with scan().
    taylor_coeffs, _ = control_flow.scan(
        body, init=taylor_coeffs, xs=None, length=num - 1
    )
    return taylor_coeffs
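
A standalone sketch of the padded-scan strategy for a scalar, first-order ODE, written directly against jax.experimental.jet and jax.lax.scan. The helper taylor_padded_scan and the test problem u' = u^2, u(0) = 1 (solution 1/(1 - t), whose derivatives at t = 0 are k!) are illustrative assumptions; the original additionally handles high-order ODEs. The carry has the fixed size num; each pass recomputes all derivatives from the zero-padded carry and drops the last output term, and after num - 1 passes every entry has been filled in:

```python
import jax
import jax.numpy as jnp
from jax.experimental import jet


def vf(u):
    return u * u  # u' = u^2; with u(0) = 1 the solution is 1 / (1 - t)


def taylor_padded_scan(vf, u0, num):
    """Return [u(0), u'(0), ..., u^(num)(0)] for a scalar ODE (assumes num >= 1)."""
    u1 = vf(u0)
    if num == 1:
        return [u0, u1]
    # Fixed-size carry: u', u'', ..., u^(num), with unknown entries set to zero.
    carry = jnp.zeros((num,), dtype=u1.dtype).at[0].set(u1)

    def body(derivs, _):
        series = [derivs[i] for i in range(num)]
        # Output term k of jet() depends only on input terms 0..k,
        # so the zero padding never corrupts entries that are already correct.
        p, s_new = jet.jet(vf, (u0,), (series,))
        # p = u' and s_new = [u'', ..., u^(num+1)]; drop the last term so the
        # carry keeps its shape, as scan() requires.
        new = jnp.stack([p, *s_new[:-1]]).astype(derivs.dtype)
        return new, None

    carry, _ = jax.lax.scan(body, carry, xs=None, length=num - 1)
    return [u0, *carry]


coeffs = taylor_padded_scan(vf, jnp.array(1.0, dtype=jnp.float32), num=5)
# The derivatives of 1 / (1 - t) at t = 0 are k!, i.e. 1, 1, 2, 6, 24, 120.
```

The price of the constant-size carry is redundant work: every pass recomputes all num terms, including the ones that are already final.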

odejet_unroll(vf: Callable, inits: tuple[Array, ...], /, num: int)

Taylor-expand the solution of an IVP with Taylor-mode differentiation.

Unlike odejet_padded_scan(), this function does not zero-pad the coefficients; the price is unrolling a loop of length num-1. It is expected to compile more slowly than odejet_padded_scan() but to execute more quickly.

The differences should be small. Consult the benchmarks if performance is critical.

Compilation time

JIT-compiling this function unrolls a loop.

Source code in probdiffeq/taylor.py
def odejet_unroll(vf: Callable, inits: tuple[Array, ...], /, num: int):
    """Taylor-expand the solution of an IVP with Taylor-mode differentiation.

    Unlike `odejet_padded_scan()`, this function does not zero-pad
    the coefficients; the price is unrolling a loop of length `num-1`.
    It is expected to compile more slowly than `odejet_padded_scan()`,
    but to execute more quickly.

    The differences should be small.
    Consult the benchmarks if performance is critical.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop.

    """
    # Number of positional arguments in f
    num_arguments = len(inits)

    # Initial Taylor series (u_0, u_1, ..., u_k)
    primals = vf(*inits)
    taylor_coeffs = [*inits, primals]

    for _ in range(num - 1):
        series = _subsets(taylor_coeffs[1:], num_arguments)  # for high-order ODEs
        p, s_new = functools.jet(vf, primals=inits, series=series)
        taylor_coeffs = [*inits, p, *s_new]
    return taylor_coeffs
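
A standalone sketch of the unrolled Taylor-mode recursion for a scalar, first-order ODE, calling jax.experimental.jet directly. The helper taylor_unroll and the test problem u' = u^2, u(0) = 1 (solution 1/(1 - t)) are illustrative assumptions; the original additionally handles high-order ODEs via _subsets(). jet() takes and returns unnormalised coefficients, i.e., plain derivatives, so its outputs feed straight back in as the next input series:

```python
import jax.numpy as jnp
from jax.experimental import jet


def vf(u):
    return u * u  # u' = u^2; with u(0) = 1 the solution is 1 / (1 - t)


def taylor_unroll(vf, u0, num):
    """Return [u(0), u'(0), ..., u^(num)(0)] for a first-order autonomous ODE."""
    coeffs = [u0, vf(u0)]
    for _ in range(num - 1):
        # The known derivatives u', u'', ... form the input series of u(t);
        # jet() returns f(u(t)) and its derivatives, i.e., u', u'', u''', ...,
        # one derivative more than we put in (so no zero padding is needed).
        p, s_new = jet.jet(vf, (u0,), (coeffs[1:],))
        coeffs = [u0, p, *s_new]
    return coeffs


u0 = jnp.array(1.0, dtype=jnp.float32)
coeffs = taylor_unroll(vf, u0, num=5)
# The derivatives of 1 / (1 - t) at t = 0 are k!, i.e. 1, 1, 2, 6, 24, 120.
```

Because the list grows by one entry per iteration, every iteration has different array shapes, which is why a JIT-compiled version has to unroll the loop.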

odejet_via_jvp(vf: Callable, inits: tuple[Array, ...], /, num: int)

Taylor-expand the solution of an IVP with recursive forward-mode differentiation.

Compilation time

JIT-compiling this function unrolls a loop.

Source code in probdiffeq/taylor.py
def odejet_via_jvp(vf: Callable, inits: tuple[Array, ...], /, num: int):
    """Taylor-expand the solution of an IVP with recursive forward-mode differentiation.

    !!! warning "Compilation time"
        JIT-compiling this function unrolls a loop.

    """
    g_n, g_0 = vf, vf
    taylor_coeffs = [*inits, vf(*inits)]
    for _ in range(num - 1):
        g_n = _fwd_recursion_iterate(fun_n=g_n, fun_0=g_0)
        taylor_coeffs = [*taylor_coeffs, g_n(*inits)]
    return taylor_coeffs
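
The recursion behind odejet_via_jvp can be sketched with nested jax.jvp calls. Because u' = f(u), the next derivative is d/dt g_n(u(t)) = Dg_n(u) f(u), i.e., one JVP of g_n in the direction f(u). The helper _fwd_recursion_iterate is not shown on this page, so the sketch below (hypothetical names next_derivative_fn and taylor_via_jvp, scalar test problem u' = u^2 with u(0) = 1) illustrates the technique under those assumptions rather than reproducing the library code:

```python
import jax
import jax.numpy as jnp


def vf(u):
    return u * u  # u' = u^2; with u(0) = 1 the solution is 1 / (1 - t)


def next_derivative_fn(g_n, f):
    """Given g_n with g_n(u(t)) = u^(n)(t), return g_{n+1}."""

    def g_next(u):
        # d/dt g_n(u(t)) = Dg_n(u) u'(t) = Dg_n(u) f(u): one forward-mode product.
        _, tangent = jax.jvp(g_n, (u,), (f(u),))
        return tangent

    return g_next


def taylor_via_jvp(f, u0, num):
    """Return [u(0), u'(0), ..., u^(num)(0)] by recursive forward-mode differentiation."""
    g_n = f
    coeffs = [u0, f(u0)]
    for _ in range(num - 1):
        g_n = next_derivative_fn(g_n, f)  # nests one more jvp() per order
        coeffs.append(g_n(u0))
    return coeffs


coeffs = taylor_via_jvp(vf, jnp.array(1.0, dtype=jnp.float32), num=5)
# The derivatives of 1 / (1 - t) at t = 0 are k!, i.e. 1, 1, 2, 6, 24, 120.
```

Each order adds one level of nesting, so the traced computation grows quickly with num; the Taylor-mode functions above avoid this nesting by propagating all coefficients in a single jet() call.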

runge_kutta_starter(dt, *, atol=1e-12, rtol=1e-10)

Create an estimator that uses a Runge-Kutta starter.

Source code in probdiffeq/taylor.py
def runge_kutta_starter(dt, *, atol=1e-12, rtol=1e-10):
    """Create an estimator that uses a Runge-Kutta starter."""
    # If the accuracy of the initialisation is bad, play around with dt.
    return functools.partial(_runge_kutta_starter, dt0=dt, atol=atol, rtol=rtol)