An easy example
Let's have a look at an easy example.
"""Solve the logistic equation."""
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.impl import impl
jax.config.update("jax_platform_name", "cpu")
Create a problem:
@jax.jit
def vf(y, *, t):  # noqa: ARG001
    """Evaluate the vector field."""
    return 0.5 * y * (1 - y)
u0 = jnp.asarray([0.1])
t0, t1 = 0.0, 1.0
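The logistic equation has a closed-form solution, which gives a handy reference for judging any numerical result. The sketch below evaluates it in plain Python; the helper name logistic_exact is hypothetical and not part of ProbDiffEq.

```python
import math


def logistic_exact(t, *, u0=0.1, rate=0.5, t0=0.0):
    """Closed-form solution of y' = rate * y * (1 - y) with y(t0) = u0."""
    growth = math.exp(rate * (t - t0))
    return u0 * growth / (1.0 - u0 + u0 * growth)


# Reference value at the final time t1 = 1.0 of the problem above.
y_t1 = logistic_exact(1.0)
print(f"exact y(t1) = {y_t1:.6f}")
```

The last entry of solution.u printed at the end of this tutorial is 0.15482746, which agrees with this reference value to within roughly 1e-6.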
ProbDiffEq contains three levels of implementations:
Low: Implementations of random-variable-arithmetic (marginalisation, conditioning, etc.)
Medium: Probabilistic IVP solver components (this is what you're here for).
High: ODE-solving routines.
There are several random-variable implementations (read: state-space model factorisations) which model different correlations between variables. All factorisations can be used interchangeably, but they have different speed, stability, and uncertainty-quantification properties. Since the chosen implementation powers almost everything, we choose one (and only one) of them, assign it to a global variable, and call it the "impl(ementation)".
impl.select("dense", ode_shape=(1,))
# But don't worry, this configuration does not make the library
# any less light-weight. It merely affects the shapes of the arrays
# describing means and covariances of Gaussian random variables,
# and assigns functions that know how to manipulate those parameters.
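To make the remark about array shapes concrete, here is a rough sketch that counts how many covariance entries a state with $\nu$ derivatives needs under three common structural assumptions. The function name, the labels, and the counts for the non-dense cases are illustrative assumptions about typical state-space factorisations, not ProbDiffEq's API.

```python
def covariance_entries(ode_dim, num_derivatives):
    """Count stored covariance entries under three structural assumptions."""
    n = num_derivatives + 1  # the state holds y, y', ..., y^(nu) per ODE dimension
    return {
        "dense": (ode_dim * n) ** 2,  # one full (d*n) x (d*n) matrix
        "block-diagonal": ode_dim * n**2,  # d independent n x n blocks
        "isotropic": n**2,  # one n x n block shared by all dimensions
    }


print(covariance_entries(ode_dim=1, num_derivatives=4))
print(covariance_entries(ode_dim=100, num_derivatives=4))
```

For the scalar problem in this tutorial all three counts coincide (25 entries, consistent with the 5-by-5 Cholesky factors in the printed solution); the differences only matter for higher-dimensional ODEs.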
Configuring a probabilistic IVP solver is a little more involved than configuring your favourite Runge-Kutta method: we must choose a prior distribution and a correction scheme, then we put them together as a filter or smoother, wrap everything into a solver, and (finally) make the solver adaptive.
ibm = ivpsolvers.prior_ibm(num_derivatives=4)
ts1 = ivpsolvers.correction_ts1(ode_order=1)
strategy = ivpsolvers.strategy_smoother(ibm, ts1)
solver = ivpsolvers.solver(strategy)
adaptive_solver = ivpsolve.adaptive(solver)
Why so many layers?
- Prior distributions incorporate prior knowledge; better prior knowledge should "improve the simulation" (which might mean something different for different applications, hence the quotation marks).
- Different correction schemes imply different stability concerns.
- Filters and smoothers are optimised estimators for either forward-only or time-series estimation.
- Calibration schemes affect the behaviour of the solver.
- Not all solution routines expect adaptive solvers.
The granularity of constructing a solver is an asset, not a drawback.
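As an illustration of what the prior contributes, the sketch below builds the textbook transition matrix of a $\nu$-times integrated Wiener process and uses it to extrapolate the initial state over one step. This shows only the prior mean prediction; the correction step and all covariances are omitted, the helper names are hypothetical, and the library's internal representation may differ.

```python
import math


def ibm_transition(num_derivatives, dt):
    """Transition matrix of an integrated Wiener process over a step dt."""
    n = num_derivatives + 1
    return [
        [dt ** (j - i) / math.factorial(j - i) if j >= i else 0.0 for j in range(n)]
        for i in range(n)
    ]


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) with a vector (list of floats)."""
    return [sum(a * x for a, x in zip(row, vector)) for row in matrix]


# Initial state (y, y', ..., y'''') at t0, rounded from the printed solution.
state0 = [0.1, 0.045, 0.018, 0.005175, -0.00036]
dt = 0.02221729  # first step taken in the printed solution

predicted = matvec(ibm_transition(4, dt), state0)
print(predicted[:2])
```

Each row of the matrix is a truncated Taylor expansion, so the prior mean prediction for y and y' lands close to the printed values 0.10100418 and 0.04540117 at the first grid point.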
Finally, we must prepare one last component before we can solve the differential equation:
The probabilistic IVP solvers in ProbDiffEq are state-space-model-based; this means that, as an initial condition, we must provide a data structure that represents the initial state in this model. For all current solvers, this amounts to computing a $\nu$-th order Taylor approximation of the IVP solution (here $\nu = 4$, matching num_derivatives=4 in the prior) and wrapping this approximation into a state-space-model variable.
Use the following functions:
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4)
output_scale = 1.0 # or any other value with the same shape
init = solver.initial_condition(tcoeffs, output_scale)
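To see what such Taylor coefficients look like for this problem, the following sketch computes the first four derivatives of the solution at t0 with a hand-rolled power-series recursion. It only works for this specific polynomial vector field, whereas the taylor routines above handle general vector fields; the function name is hypothetical, and the comparison with the printed output assumes that the first row of the marginal means stores unnormalised derivatives.

```python
import math


def logistic_taylor_derivatives(u0, num, rate=0.5):
    """Derivatives y, y', ..., y^(num) at t0 for y' = rate * y * (1 - y)."""
    coeffs = [u0]  # normalised coefficients c_k = y^(k)(t0) / k!
    for k in range(num):
        # k-th power-series coefficient of y * y (Cauchy product)
        square_k = sum(coeffs[i] * coeffs[k - i] for i in range(k + 1))
        # Match coefficients in y' = rate * (y - y * y)
        coeffs.append(rate * (coeffs[k] - square_k) / (k + 1))
    # Undo the 1/k! normalisation to obtain plain derivatives
    return [c * math.factorial(k) for k, c in enumerate(coeffs)]


derivatives = logistic_taylor_derivatives(0.1, num=4)
print(derivatives)
```

Up to floating-point rounding, the result matches the first row of the marginal means in the printed solution: 0.1, 0.045, 0.018, 0.005175, -0.00036.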
Other software packages that implement probabilistic IVP solvers do a lot of this work implicitly; ProbDiffEq requires the user to make these decisions, not only because this simplifies the solver implementations (quite a lot, actually), but also because it shows how easily we can build a custom solver for our favourite problem (consult the other tutorials for examples).
From here on, the rest is standard ODE-solver machinery:
dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), (u0,)) # or use e.g. dt0=0.1
solution = ivpsolve.solve_adaptive_save_every_step(
    vf, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver
)
# Look at the solution
print("u =", solution.u, "\n")
print("solution =", solution)
u = [[0.10000001] [0.10100418] [0.10778747] [0.11259064] [0.1195111 ] [0.12856498] [0.14064184] [0.15482746]] solution = _Solution(t=[0. 0.02221729 0.16736329 0.26535517 0.40031433 0.56703913 0.77451193 1. ],u=[[0.10000001] [0.10100418] [0.10778747] [0.11259064] [0.1195111 ] [0.12856498] [0.14064184] [0.15482746]],output_scale=[1. 1. 1. 1. 1. 1. 1.],marginals=Normal(mean=Array([[ 0.10000001, 0.04499995, 0.01799997, 0.00517498, -0.00036 ], [ 0.10100418, 0.04540117, 0.01812172, 0.00567627, 0. ], [ 0.10778747, 0.04808467, 0.01884079, 0.00537109, 0.01171875], [ 0.11259064, 0.049957 , 0.01936531, 0.00497246, -0.01367188], [ 0.1195111 , 0.0526141 , 0.02000809, 0.00482178, 0.00292969], [ 0.12856498, 0.05601802, 0.02082443, 0.00473022, -0.0065918 ], [ 0.14064184, 0.06043086, 0.02169228, 0.00402832, 0.00048828], [ 0.15482746, 0.0654263 , 0.02262986, 0.00439072, 0.0012207 ]], dtype=float32), cholesky=Array([[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], [[-5.40208434e-11, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-2.15543694e-11, -2.08688801e-15, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 2.03164791e-06, 1.49874602e-09, 6.67584743e-07, 0.00000000e+00, 0.00000000e+00], [ 4.48327919e-04, 7.70927841e-07, 2.94516416e-04, -8.68116040e-05, 0.00000000e+00], [ 2.69052312e-02, 1.72822794e-04, 4.19507436e-02, -4.37253006e-02, -2.88737584e-02]], [[-1.77799777e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-6.97354423e-08, 3.25504597e-12, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 6.54255418e-05, 8.30784117e-08, -3.42869425e-05, 0.00000000e+00, 0.00000000e+00], 
[-2.46706302e-04, 3.33999924e-06, -9.87979467e-04, -8.88306531e-04, 0.00000000e+00], [-6.63579106e-02, -7.50574764e-05, 5.70422746e-02, -2.70840526e-02, 6.22604676e-02]], [[ 1.29194234e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 5.00512769e-08, -1.60090725e-11, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-3.79636840e-05, 2.95818609e-07, -9.59915633e-05, 0.00000000e+00, 0.00000000e+00], [-7.19281030e-04, 6.11662563e-06, -1.14296260e-03, -1.57083164e-03, 0.00000000e+00], [ 3.64819318e-02, -1.27584208e-04, 7.61297941e-02, -2.71387883e-02, -7.39993230e-02]], [[ 4.99107330e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 1.89903531e-07, -8.41459454e-12, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-1.72293672e-04, 8.95660435e-07, 7.75992521e-05, 0.00000000e+00, 0.00000000e+00], [-1.06356828e-03, 4.10452449e-05, 2.14708387e-03, -1.72005978e-03, 0.00000000e+00], [ 8.52025300e-02, 3.69630288e-04, -4.09095883e-02, -3.89944613e-02, 7.98959583e-02]], [[ 1.15785338e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 4.30070997e-07, -4.87518949e-11, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-3.39636812e-04, 5.06329457e-07, -1.16019786e-04, 0.00000000e+00, 0.00000000e+00], [-2.43879436e-03, 2.01948224e-05, -3.09603708e-03, -2.54919217e-03, 0.00000000e+00], [ 1.00888848e-01, 1.13697417e-04, 2.68875603e-02, -4.26336154e-02, -9.28991139e-02]], [[ 3.96288533e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 1.42407907e-06, -4.30233432e-10, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-8.15153413e-04, 1.37843017e-07, 2.13126623e-04, 0.00000000e+00, 0.00000000e+00], [-7.41946790e-03, 5.74358273e-06, 6.37688674e-03, -3.21438373e-03, 0.00000000e+00], [ 1.05688281e-01, 4.79742885e-05, 1.83327310e-03, -7.38707334e-02, 1.02758668e-01]], [[ 1.80909610e-05, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 8.73967583e-05, -1.10021747e-05, 0.00000000e+00, 
0.00000000e+00, 0.00000000e+00], [-2.17991229e-03, 1.64914120e-04, -9.13287004e-05, 0.00000000e+00, 0.00000000e+00], [-4.59418520e-02, 1.36794588e-02, 4.88730986e-03, -1.18775666e-03, 0.00000000e+00], [-3.24873686e-01, 2.38483191e-01, 1.66421458e-01, 1.01145357e-01, -5.79760484e-02]]], dtype=float32)),posterior=MarkovSeq(init=Normal(mean=Array([[ 0.1 , 0.04499996, 0.01799997, 0.00517498, -0.00036 ], [ 0.10100419, 0.04540117, 0.01812308, 0.00605633, 0.04968028], [ 0.1077876 , 0.04808472, 0.01872323, 0.00102929, -0.04753059], [ 0.1125906 , 0.04995698, 0.01945034, 0.00811751, 0.03415165], [ 0.11951123, 0.05261415, 0.01992667, 0.00250585, -0.02289676], [ 0.12856483, 0.05601796, 0.02089252, 0.00635088, 0.01156631], [ 0.14064199, 0.06043091, 0.02164194, 0.00293808, -0.00945758], [ 0.15482746, 0.0654263 , 0.02262986, 0.00439072, 0.0012207 ]], dtype=float32), cholesky=Array([[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], [[ 7.91196362e-12, -6.26754343e-11, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 3.15480975e-12, -2.50078170e-11, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-3.28596570e-07, 2.60301704e-06, 7.83470796e-07, 0.00000000e+00, 0.00000000e+00], [-8.28969060e-05, 6.56676071e-04, 3.70238035e-04, 9.56417643e-05, 0.00000000e+00], [-8.39758571e-03, 6.65223226e-02, 6.66565746e-02, 5.16286157e-02, 2.98265778e-02]], [[-5.31815907e-08, 4.20827490e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-2.08617550e-08, 1.65053734e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 4.34265094e-05, -3.43636610e-04, -1.03429273e-04, 0.00000000e+00, 0.00000000e+00], [ 1.55896263e-03, 
-1.23361507e-02, -6.93825353e-03, -1.78607600e-03, 0.00000000e+00], [ 2.23839302e-02, -1.77125201e-01, -1.77221656e-01, -1.37049288e-01, -7.90882632e-02]], [[-1.47024046e-08, 2.58920068e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-5.67988678e-09, 1.00309364e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 2.09359896e-05, -3.68675857e-04, -1.93847183e-04, 0.00000000e+00, 0.00000000e+00], [ 8.14600440e-04, -1.43450079e-02, -6.99010864e-03, -3.51823447e-03, 0.00000000e+00], [ 1.22726113e-02, -2.16120183e-01, -1.05068095e-01, -1.76491082e-01, -9.51540545e-02]], [[-1.02916523e-07, 1.20186542e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-3.91499420e-08, 4.57293908e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 6.16994221e-05, -7.20527314e-04, -1.68529252e-04, 0.00000000e+00, 0.00000000e+00], [ 1.76727260e-03, -2.06382945e-02, -9.10638180e-03, -2.97560450e-03, 0.00000000e+00], [ 1.96096674e-02, -2.29002550e-01, -1.91143245e-01, -1.50939584e-01, -1.00273095e-01]], [[-2.24601379e-07, 2.86749696e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-8.34740703e-08, 1.06509356e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 9.80283585e-05, -1.25156587e-03, -2.25203912e-04, 0.00000000e+00, 0.00000000e+00], [ 2.28672149e-03, -2.91953534e-02, -1.11514935e-02, -4.06775298e-03, 0.00000000e+00], [ 2.06811372e-02, -2.64042199e-01, -2.09386870e-01, -1.54978439e-01, -1.17186829e-01]], [[-6.27795714e-07, 7.83733321e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [-2.25172016e-07, 2.81641042e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 1.72547021e-04, -2.15387065e-03, -3.82097991e-04, 0.00000000e+00, 0.00000000e+00], [ 3.23386281e-03, -4.03687544e-02, -1.60418041e-02, -4.46381560e-03, 0.00000000e+00], [ 2.34576724e-02, -2.92826146e-01, -2.44874418e-01, -1.65839300e-01, -1.21937253e-01]], [[ 1.80909610e-05, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 8.73967583e-05, -1.10021747e-05, 0.00000000e+00, 
0.00000000e+00, 0.00000000e+00], [-2.17991229e-03, 1.64914120e-04, -9.13287004e-05, 0.00000000e+00, 0.00000000e+00], [-4.59418520e-02, 1.36794588e-02, 4.88730986e-03, -1.18775666e-03, 0.00000000e+00], [-3.24873686e-01, 2.38483191e-01, 1.66421458e-01, 1.01145357e-01, -5.79760484e-02]]], dtype=float32)), conditional=Conditional(matmul=Array([[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], [[-4.86423407e-04, 3.33144308e-05, -9.61733463e-07, 1.40875684e-08, -9.01645217e-11], [-1.94083419e-04, 1.32924888e-05, -3.83732527e-07, 5.62095304e-09, -3.59757363e-11], [ 2.30000019e+01, -1.56321883e+00, 4.47964221e-02, -6.51569455e-04, 4.14227407e-06], [ 6.46875244e+03, -4.35523224e+02, 1.23669634e+01, -1.78297862e-01, 1.12391950e-03], [ 7.75261312e+05, -5.07181523e+04, 1.39963989e+03, -1.96157627e+01, 1.20235138e-01]], [[ 1.09494829e+00, -7.16572329e-02, 2.03129416e-03, -3.00331758e-05, 1.99358240e-07], [ 4.29453254e-01, -2.81049181e-02, 7.96700537e-04, -1.17794107e-05, 7.81909719e-08], [-2.15406433e+02, 2.53473263e+01, -1.00328398e+00, 1.84721146e-02, -1.42598336e-04], [ 8.36796191e+03, -7.62727890e+01, -6.27954912e+00, 1.61855996e-01, -1.27943710e-03], [ 1.72093906e+05, -1.99243203e+04, 1.10720935e+03, -2.63099365e+01, 2.40958646e-01]], [[ 6.71298325e-01, -6.28148615e-02, 2.53447867e-03, -5.30355574e-05, 4.95307347e-07], [ 2.60068268e-01, -2.43351627e-02, 9.81884892e-04, -2.05465549e-05, 1.91887480e-07], [ 1.62859634e+02, -4.51557779e+00, -1.84634626e-01, 1.02132000e-02, -1.42791236e-04], [-2.26262970e+02, 2.25851379e+02, -1.31403513e+01, 2.98927635e-01, -2.68927449e-03], [-1.50927078e+05, 8.86273340e+03, 
-5.84144163e+00, -8.11936951e+00, 1.48641229e-01]], [[ 9.81848598e-01, -1.11015067e-01, 5.43915993e-03, -1.38917516e-04, 1.59018214e-06], [ 3.73579711e-01, -4.22396958e-02, 2.06952542e-03, -5.28562014e-05, 6.05042374e-07], [-1.21940346e+02, 1.96845894e+01, -1.20995104e+00, 3.60479318e-02, -4.59014787e-04], [ 4.76848584e+03, -3.27190491e+02, 8.80935001e+00, -1.03741810e-01, 3.24589433e-04], [ 7.12875703e+04, -9.10046875e+03, 5.97274353e+02, -2.01622849e+01, 2.86757648e-01]], [[ 7.35635102e-01, -1.07922606e-01, 6.84794830e-03, -2.25702694e-04, 3.31629758e-06], [ 2.73242831e-01, -4.00865637e-02, 2.54358863e-03, -8.38345586e-05, 1.23179893e-06], [-5.16176453e+01, 1.27113428e+01, -1.06558919e+00, 4.16843295e-02, -6.83933846e-04], [ 3.27595264e+03, -3.16063721e+02, 1.30493059e+01, -2.82863349e-01, 2.84119160e-03], [ 2.38872520e+04, -4.09079517e+03, 3.59470032e+02, -1.57596073e+01, 2.85550624e-01]], [[ 1.16081941e+00, -1.74137533e-01, 1.13278311e-02, -3.84805462e-04, 5.87288423e-06], [ 4.17145073e-01, -6.25770167e-02, 4.07070108e-03, -1.38281350e-04, 2.11044335e-06], [-1.02933937e+02, 1.93165150e+01, -1.47601581e+00, 5.64130694e-02, -9.38129611e-04], [ 2.00339417e+03, -1.85996704e+02, 6.70605612e+00, -1.01921700e-01, 3.03049892e-04], [ 2.90794004e+04, -4.62802539e+03, 3.73554291e+02, -1.60114765e+01, 2.96656549e-01]]], dtype=float32), noise=Normal(mean=Array([[ 1.00000009e-01, 4.49999534e-02, 1.79999657e-02, 5.17498422e-03, -3.59997677e-04], [ 1.01055026e-01, 4.54214588e-02, -2.38666391e+00, -6.76534851e+02, -8.11509609e+04], [-1.19528659e-02, 1.12091354e-03, 2.30246468e+01, -9.38217651e+02, -1.84020996e+04], [ 3.56175378e-02, 2.01367829e-02, -1.92029400e+01, 1.54244213e+01, 1.75712988e+04], [-6.14004966e-04, 6.90816902e-03, 1.46195974e+01, -5.94909912e+02, -8.66763184e+03], [ 3.14781368e-02, 1.99562721e-02, 6.53521538e+00, -4.41913208e+02, -3.12007788e+03], [-2.79463753e-02, -1.51986271e-04, 1.47280397e+01, -2.98158630e+02, -4.20787842e+03]], dtype=float32), 
cholesky=Array([[[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]], [[-4.5030164e-11, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [-1.7967076e-11, 2.0868664e-15, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 1.4843881e-06, -1.6807128e-10, -5.4886624e-07, 0.0000000e+00, 0.0000000e+00], [ 2.6374924e-04, -6.1671322e-08, -2.1967612e-04, -7.7254605e-05, 0.0000000e+00], [-2.7380288e-03, -5.2228825e-06, -1.9210855e-02, -3.5316776e-02, 2.7775509e-02]], [[ 2.1233083e-08, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 8.3279001e-09, 3.2550275e-12, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [-2.0843874e-05, 3.8254706e-09, 9.4913503e-06, 0.0000000e+00, 0.0000000e+00], [-1.0687469e-04, -2.7474189e-07, -3.7957347e-04, 4.2701344e-04, 0.0000000e+00], [ 4.9146257e-02, -3.0767740e-05, -4.6220221e-02, -1.9498399e-02, -4.9097214e-02]], [[ 7.6285659e-08, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 2.9553897e-08, -1.6008949e-11, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [-4.0608593e-05, -5.3421192e-09, -2.4359892e-05, 0.0000000e+00, 0.0000000e+00], [-1.7863374e-04, 7.3233018e-08, 8.5846701e-04, -7.2435755e-04, 0.0000000e+00], [ 5.7514153e-02, 1.6428090e-05, 5.0146773e-02, 2.6745848e-02, -5.8165241e-02]], [[ 2.3060400e-07, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 8.7741626e-08, -8.4134756e-12, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [-8.2769300e-05, -8.1520675e-08, 2.8327368e-05, 0.0000000e+00, 0.0000000e+00], [-1.0342022e-04, 1.7863732e-06, -4.5782633e-04, 8.8693429e-04, 0.0000000e+00], [ 6.7466736e-02, 3.3338429e-04, -5.9037719e-02, -2.0410839e-02, 
-6.3329771e-02]], [[ 6.0008074e-07, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 2.2289281e-07, 4.8750965e-11, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [-1.5172355e-04, -5.3159690e-09, 4.6051595e-05, 0.0000000e+00, 0.0000000e+00], [ 2.1772059e-04, 9.2368595e-08, -6.1109255e-04, 1.2781352e-03, 0.0000000e+00], [ 7.9684466e-02, 1.5346206e-05, -5.9320591e-02, -2.7698932e-02, -7.1936317e-02]], [[ 9.2313553e-07, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [ 3.3173228e-07, -4.3023324e-10, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], [-1.7790566e-04, 6.0121490e-09, 5.7936431e-05, 0.0000000e+00, 0.0000000e+00], [-2.2470618e-04, 2.3389499e-07, -6.4530666e-04, 1.4193625e-03, 0.0000000e+00], [ 7.4166320e-02, -1.8121380e-05, -6.7320928e-02, -2.7070098e-02, -7.4938416e-02]]], dtype=float32)))),num_steps=[1 2 3 4 5 6 7],)