An easy example
Let's have a look at an easy example.
"""Solve the logistic equation."""
import jax
import jax.numpy as jnp
from probdiffeq import adaptive, ivpsolve, timestep
from probdiffeq.impl import impl
from probdiffeq.solvers import uncalibrated
from probdiffeq.solvers.strategies import smoothers
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.taylor import autodiff
jax.config.update("jax_platform_name", "cpu")
Create a problem: the logistic equation $\dot y = 0.5\, y\,(1 - y)$ with $y(0) = 0.1$ on the interval $[0, 1]$.
@jax.jit
def vf(y, *, t):  # noqa: ARG001
    """Evaluate the vector field."""
    return 0.5 * y * (1 - y)
u0 = jnp.asarray([0.1])
t0, t1 = 0.0, 1.0
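The logistic equation has a closed-form solution, $y(t) = 1 / \left(1 + (1/y_0 - 1)\, e^{-0.5 t}\right)$, which is handy for sanity-checking any numerical result. A minimal sketch; the function name `logistic_exact` is hypothetical and not part of ProbDiffEq:

```python
import jax
import jax.numpy as jnp


def logistic_exact(t, y0=0.1, rate=0.5):
    """Closed-form solution of y' = rate * y * (1 - y) with y(0) = y0."""
    return 1.0 / (1.0 + (1.0 / y0 - 1.0) * jnp.exp(-rate * t))


# Evaluate the reference solution at a few time points in [t0, t1].
print(logistic_exact(jnp.asarray([0.0, 0.5, 1.0])))

# The time derivative of the closed form reproduces the vector field.
y1 = logistic_exact(1.0)
print(jax.grad(logistic_exact)(1.0), 0.5 * y1 * (1 - y1))
```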
ProbDiffEq contains three levels of implementations:
- Low: implementations of random-variable arithmetic (marginalisation, conditioning, etc.)
- Medium: probabilistic IVP solver components (this is what you're here for)
- High: ODE-solving routines
There are several random-variable implementations (read: state-space model factorisations) which model different correlations between variables. All factorisations can be used interchangeably, but they have different speed, stability, and uncertainty-quantification properties. Since the chosen implementation powers almost everything, we choose one (and only one) of them, assign it to a global variable, and call it the "impl(ementation)".
impl.select("dense", ode_shape=(1,))
# Don't worry: this configuration does not make the library any less lightweight.
# It merely fixes the shapes of the arrays that describe the means and covariances
# of Gaussian random variables, and assigns the functions that know how to
# manipulate those parameters.
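To give a feel for the "low-level" random-variable arithmetic that the selected implementation provides, here is a generic square-root marginalisation for a dense Gaussian: if $x \sim \mathcal{N}(m, L L^\top)$ and $y = A x + w$ with $w \sim \mathcal{N}(b, L_q L_q^\top)$, then $y \sim \mathcal{N}(A m + b, S S^\top)$, where $S$ comes from a QR decomposition and the covariance is never formed explicitly. This is a sketch of the technique, not ProbDiffEq's code; the function name `marginalise` is hypothetical. With one ODE dimension and the four-derivative prior used below, the state would be five-dimensional, hence `n = 5`.

```python
import jax
import jax.numpy as jnp


def marginalise(m, L, A, b, Lq):
    """Push N(m, L L^T) through y = A x + w, w ~ N(b, Lq Lq^T), in square-root form."""
    mean = A @ m + b
    # Stack the transposed square-root factors. For the reduced QR decomposition
    # stacked = Q R we have R^T R = stacked^T stacked = (A L)(A L)^T + Lq Lq^T,
    # so S = R^T is a lower-triangular square-root factor of cov(y).
    stacked = jnp.concatenate([(A @ L).T, Lq.T], axis=0)
    _, R = jnp.linalg.qr(stacked)
    return mean, R.T


# Random test problem with a five-dimensional state.
key_a, key_l, key_q = jax.random.split(jax.random.PRNGKey(0), 3)
n = 5
A = jax.random.normal(key_a, (n, n))
L = jnp.tril(jax.random.normal(key_l, (n, n)))
Lq = jnp.tril(jax.random.normal(key_q, (n, n)))
m = jnp.ones(n)
b = jnp.zeros(n)
mean, S = marginalise(m, L, A, b, Lq)
```

Working with square-root factors instead of covariances is a common choice in probabilistic ODE solvers because it keeps covariances symmetric and positive semi-definite in floating-point arithmetic.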
Configuring a probabilistic IVP solver is a little more involved than configuring your favourite Runge-Kutta method: we must choose a prior distribution and a correction scheme, then we put them together as a filter or smoother, wrap everything into a solver, and (finally) make the solver adaptive.
ibm = priors.ibm_adaptive(num_derivatives=4)
ts1 = corrections.ts1(ode_order=1)
strategy = smoothers.smoother_adaptive(ibm, ts1)
solver = uncalibrated.solver(strategy)
adaptive_solver = adaptive.adaptive(solver)
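The prior `priors.ibm_adaptive(num_derivatives=4)` is, judging by its name, an integrated Brownian motion (IBM) prior: the state (the value plus its first $q = 4$ derivatives) evolves as a $q$-times integrated Wiener process. For a step of size $h$, the textbook discrete-time transition in unscaled derivative coordinates and with unit diffusion is $x_{n+1} = A x_n + w_n$, $w_n \sim \mathcal{N}(0, Q)$, with $A_{ij} = h^{j-i}/(j-i)!$ for $j \ge i$ and $Q_{ij} = h^{2q+1-i-j} / \left((2q+1-i-j)\,(q-i)!\,(q-j)!\right)$. The sketch below builds these matrices; the helper name `ibm_transition` is hypothetical, and ProbDiffEq may represent them differently internally (for example with preconditioning, or with $Q$ scaled by a squared output scale).

```python
import math

import jax.numpy as jnp


def ibm_transition(num_derivatives, dt):
    """Transition matrix A and process-noise covariance Q of a q-times integrated Wiener process."""
    q = num_derivatives
    idx = range(q + 1)
    # A propagates a truncated Taylor expansion: derivative i picks up dt^(j-i)/(j-i)! * derivative j.
    A = [[dt ** (j - i) / math.factorial(j - i) if j >= i else 0.0 for j in idx] for i in idx]
    # Q is the covariance accumulated by the driving white noise over one step.
    Q = [
        [
            dt ** (2 * q + 1 - i - j)
            / ((2 * q + 1 - i - j) * math.factorial(q - i) * math.factorial(q - j))
            for j in idx
        ]
        for i in idx
    ]
    return jnp.asarray(A), jnp.asarray(Q)


A, Q = ibm_transition(num_derivatives=4, dt=0.5)

# The mean prediction is exact for low-degree polynomials: y(t) = t^2 has
# derivatives [0, 0, 2, 0, 0] at t = 0 and [dt^2, 2 dt, 2, 0, 0] at t = dt.
print(A @ jnp.asarray([0.0, 0.0, 2.0, 0.0, 0.0]))
```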
Why so many layers?
- Prior distributions incorporate prior knowledge; better prior knowledge should "improve the simulation" (which may mean different things for different applications, hence the quotation marks).
- Different correction schemes come with different stability properties.
- Filters and smoothers are optimised for different tasks: forward-only estimation (filters) or estimation of the whole time series (smoothers).
- Calibration schemes affect the behaviour of the solver.
- Not all solution routines expect adaptive solvers.
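To illustrate what a correction scheme does, assume (as the names suggest) that `ts0` and `ts1` denote zeroth- and first-order Taylor-series linearisation of the ODE residual $r(x) = x_1 - f(x_0)$, where $x_0 = y$ and $x_1 = y'$ are the first two entries of the state (`ode_order=1`). Linearising $f$ around the predicted mean $m_0$ gives a linear observation model $r(x) \approx H x + b$; the first-order version keeps the Jacobian of $f$, the zeroth-order version drops it. A sketch for a scalar ODE; the function name `linearise_residual` is hypothetical, and this is not ProbDiffEq's implementation.

```python
import jax
import jax.numpy as jnp


def f(y):
    """Logistic vector field (scalar state, autonomous)."""
    return 0.5 * y * (1 - y)


def linearise_residual(f, m0, num_derivatives, order):
    """Linearise r(x) = x[1] - f(x[0]) around x[0] = m0 as r(x) ~ H @ x + b.

    order=1 keeps the Jacobian of f (first-order Taylor series);
    order=0 drops it (zeroth-order Taylor series).
    """
    fm = f(m0)
    jac = jax.jacfwd(f)(m0) if order == 1 else jnp.zeros_like(m0)
    # r(x) ~ x[1] - (f(m0) + jac * (x[0] - m0)) = -jac * x[0] + x[1] + (jac * m0 - f(m0))
    H = jnp.zeros(num_derivatives + 1).at[0].set(-jac).at[1].set(1.0)
    b = -fm + jac * m0
    return H, b


m0 = jnp.asarray(0.1)
H1, b1 = linearise_residual(f, m0, num_derivatives=4, order=1)
H0, b0 = linearise_residual(f, m0, num_derivatives=4, order=0)
print(H1, b1)
print(H0, b0)
```

The first-order variant costs a Jacobian evaluation per step; in exchange, it is generally reported to cope better with stiff problems than the zeroth-order variant.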
This fine-grained construction of a solver is an asset, not a drawback.
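To make the filter/smoother distinction concrete, here is a textbook scalar Kalman filter followed by a Rauch-Tung-Striebel (RTS) smoother for the linear model $x_k = a x_{k-1} + \mathcal{N}(0, q)$, $y_k = x_k + \mathcal{N}(0, r)$. The filter only ever looks at past data; the smoother adds a backward pass so that every estimate uses all data. This is a generic sketch with hypothetical parameters `a`, `q`, `r`, not ProbDiffEq's implementation; in probabilistic ODE solvers, the "data" are typically the requirement that the ODE residual vanishes.

```python
import jax.numpy as jnp


def kalman_filter(ys, a, q, r, m0, p0):
    """Forward pass: filtering means and variances for a scalar linear Gaussian model."""
    ms, ps = [], []
    m, p = m0, p0
    for y in ys:
        # Predict with the transition model.
        m_pred, p_pred = a * m, a * a * p + q
        # Update: condition the prediction on the observation y.
        gain = p_pred / (p_pred + r)
        m = m_pred + gain * (y - m_pred)
        p = (1.0 - gain) * p_pred
        ms.append(m)
        ps.append(p)
    return jnp.asarray(ms), jnp.asarray(ps)


def rts_smoother(ms, ps, a, q):
    """Backward pass: combine filtering results with information from later data."""
    ms_s, ps_s = [ms[-1]], [ps[-1]]
    for k in range(len(ms) - 2, -1, -1):
        p_pred = a * a * ps[k] + q
        g = ps[k] * a / p_pred
        m_s = ms[k] + g * (ms_s[0] - a * ms[k])
        p_s = ps[k] + g * g * (ps_s[0] - p_pred)
        ms_s.insert(0, m_s)
        ps_s.insert(0, p_s)
    return jnp.asarray(ms_s), jnp.asarray(ps_s)


ys = [1.0, 2.0, 1.5, 0.5]
ms, ps = kalman_filter(ys, a=1.0, q=0.1, r=0.5, m0=0.0, p0=1.0)
ms_s, ps_s = rts_smoother(ms, ps, a=1.0, q=0.1)
print(ps, ps_s)
```

The smoothed variances never exceed the filtered ones, and both agree at the final step; the price is that the smoother needs the results of the complete forward pass.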
Finally, we must prepare one last component before we can solve the differential equation:
The probabilistic IVP solvers in ProbDiffEq implement state-space-model-based IVP solvers; this means that, as an initial condition, we must provide a data structure that represents the initial state in this model. For all current solvers, this amounts to computing a $\nu$-th order Taylor approximation of the IVP solution (here $\nu = 4$, matching `num_derivatives=4` above) and wrapping this approximation into a state-space-model variable.
Use the following functions:
tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)
output_scale = 1.0 # or any other value with the same shape
init = solver.initial_condition(tcoeffs, output_scale)
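As a cross-check of what the initial state contains, the derivatives $y(t_0), y'(t_0), \dots, y^{(4)}(t_0)$ can also be computed by naive, nested forward-mode differentiation, using the chain rule $\tfrac{d}{dt} g(y(t)) = g'(y)\, f(y)$ for an autonomous ODE $y' = f(y)$. The results agree with the first row of the printed marginal mean (0.1, 0.045, 0.018, 0.005175, -0.00036, up to float32 rounding), which suggests that the solver state stores derivatives rather than normalised Taylor coefficients; treat that as an observation, not a documented guarantee. The helper `time_derivatives` is hypothetical; Taylor-mode AD, as in `autodiff.taylor_mode_scan`, avoids the rapidly growing cost of this nesting.

```python
import jax
import jax.numpy as jnp


def vf(y, *, t):  # noqa: ARG001
    """Logistic vector field, as in the example above."""
    return 0.5 * y * (1 - y)


def time_derivatives(f, y0, num):
    """Return [y, y', ..., y^(num)] at y0 for the autonomous ODE y' = f(y).

    Each new derivative applies one forward-mode JVP along the vector field
    to the previous one, so the evaluations are nested num levels deep.
    """
    fns = [lambda y: y]
    for _ in range(num):
        prev = fns[-1]
        fns.append(lambda y, g=prev: jax.jvp(g, (y,), (f(y),))[1])
    return [fn(y0) for fn in fns]


u0 = jnp.asarray([0.1])
derivs = time_derivatives(lambda y: vf(y, t=0.0), u0, num=4)
print(jnp.concatenate(derivs))
```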
Other software packages that implement probabilistic IVP solvers do much of this work implicitly. ProbDiffEq requires the user to make these decisions, not only because doing so simplifies the solver implementations (quite a lot, actually), but also because it shows how easily we can build a custom solver for our favourite problem (consult the other tutorials for examples).
From here on, the rest is standard ODE-solver machinery:
dt0 = timestep.initial(lambda y: vf(y, t=t0), (u0,)) # or use e.g. dt0=0.1
solution = ivpsolve.solve_and_save_every_step(
vf, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver
)
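The adaptive solver decides, step by step, whether to accept a step and how large the next one should be; this is why the time grid `t` in the output is non-uniform. A generic integral ("I-") controller scales the step by $\left(1/\lVert e \rVert\right)^{1/(p+1)}$, where $\lVert e \rVert$ is a tolerance-scaled error estimate and $p$ its order. The sketch below shows that rule; it is not necessarily the controller ProbDiffEq uses, and the names `error_norm` and `propose_step`, the tolerances, the safety factor, and the clipping bounds are all hypothetical.

```python
import jax.numpy as jnp


def error_norm(error, y, atol=1e-4, rtol=1e-2):
    """RMS norm of a local error estimate, scaled by (hypothetical) tolerances."""
    scale = atol + rtol * jnp.abs(y)
    return float(jnp.sqrt(jnp.mean((error / scale) ** 2)))


def propose_step(dt, err_norm, order, safety=0.95, factor_min=0.2, factor_max=10.0):
    """Integral controller: accept iff err_norm <= 1; rescale dt either way."""
    err_norm = max(err_norm, 1e-10)  # avoid division by zero for (near-)exact steps
    factor = safety * (1.0 / err_norm) ** (1.0 / (order + 1))
    factor = min(max(factor, factor_min), factor_max)
    return err_norm <= 1.0, dt * factor


# A small error: accept the step and enlarge the next one.
accept, dt_next = propose_step(0.1, error_norm(jnp.asarray([1e-5]), jnp.asarray([0.1])), order=4)
# A large error: reject the step and retry with a smaller one.
reject_accept, dt_retry = propose_step(0.1, 4.0, order=4)
print(accept, dt_next, reject_accept, dt_retry)
```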
# Look at the solution
print("u =", solution.u, "\n")
print("solution =", solution)
u = [[0.10000001]
 [0.10100418]
 [0.10778747]
 [0.11259064]
 [0.1195111 ]
 [0.12856498]
 [0.14064184]
 [0.15482746]]

solution = Solution(
 t=[0. 0.02221729 0.16736329 0.26535517 0.40031433 0.56703913 0.77451193 1. ],
 u=[[0.10000001] [0.10100418] [0.10778747] [0.11259064] [0.1195111 ] [0.12856498] [0.14064184] [0.15482746]],
 output_scale=[1. 1. 1. 1. 1. 1. 1.],
 marginals=Normal(
  mean=Array([[ 0.10000001, 0.04499995, 0.01799997, 0.00517498, -0.00036 ],
              [ 0.10100418, 0.04540117, 0.01812172, 0.00567627, 0. ],
              [ 0.10778747, 0.04808467, 0.01884079, 0.00537109, 0.01171875],
              [ 0.11259064, 0.049957 , 0.01936531, 0.00497246, -0.01367188],
              [ 0.1195111 , 0.0526141 , 0.02000809, 0.00482178, 0.00292969],
              [ 0.12856498, 0.05601802, 0.02082443, 0.00473022, -0.0065918 ],
              [ 0.14064184, 0.06043086, 0.02169228, 0.00402832, 0.00048828],
              [ 0.15482746, 0.0654263 , 0.02262986, 0.00439072, 0.0012207 ]], dtype=float32),
  cholesky=Array([...], dtype=float32)),
 posterior=MarkovSeq(
  init=Normal(mean=Array([...], dtype=float32), cholesky=Array([...], dtype=float32)),
  conditional=Conditional(
   matmul=Array([...], dtype=float32),
   noise=Normal(mean=Array([...], dtype=float32), cholesky=Array([...], dtype=float32)))),
 num_steps=[1 2 3 4 5 6 7],
)
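The `marginals` field holds, for every time point, a Gaussian described by a `mean` and a `cholesky` factor (field names taken from the printed repr); the first column of the mean coincides with `u`. Assuming the covariance is `cholesky @ cholesky.T`, the standard deviation of component $i$ is the Euclidean norm of row $i$ of the factor, which gives a quick credible band. A sketch with hypothetical helper names `marginal_std` and `credible_band`, demonstrated on a toy factor rather than on the solver output:

```python
import jax.numpy as jnp


def marginal_std(cholesky):
    """Standard deviations of N(mean, L @ L.T), batched over leading axes.

    The variance of component i is sum_j L[i, j] ** 2, i.e. the squared
    Euclidean norm of the i-th row of L.
    """
    return jnp.linalg.norm(cholesky, axis=-1)


def credible_band(mean, cholesky, num_std=2.0):
    """Return (lower, upper) = mean -/+ num_std * std, component-wise."""
    std = marginal_std(cholesky)
    return mean - num_std * std, mean + num_std * std


# Toy example: one time point, two state components, covariance [[9, 12], [12, 25]].
L = jnp.asarray([[[3.0, 0.0], [4.0, 3.0]]])
mean = jnp.asarray([[1.0, 0.0]])
lower, upper = credible_band(mean, L)
print(lower, upper)
```

For the solution above, one would presumably pass `solution.marginals.mean` and `solution.marginals.cholesky` and read off column 0 to obtain a band around `u`.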