Skip to content

strategy

Interface for estimation strategies.

ExtrapolationImpl

Bases: ABC, Generic[T, R, S]

Extrapolation model interface.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 14-50
class ExtrapolationImpl(abc.ABC, Generic[T, R, S]):
    """Extrapolation model interface.

    Type parameters: ``T`` is the solution type, ``R`` the extrapolation
    state, and ``S`` the auxiliary data carried next to that state. Every
    method below is abstract and must be overridden by a concrete
    extrapolation.
    """

    @abc.abstractmethod
    def initial_condition(self, tcoeffs, /) -> T:
        """Build an initial condition from the Taylor coefficients ``tcoeffs``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def init(self, solution: T, /) -> tuple[R, S]:
        """Turn ``solution`` into a ``(state, aux)`` pair."""
        raise NotImplementedError()

    @abc.abstractmethod
    def begin(self, state: R, aux: S, /, dt) -> tuple[R, S]:
        """Start extrapolating ``(state, aux)`` across a step of length ``dt``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def complete(self, state: R, aux: S, /, output_scale) -> tuple[R, S]:
        """Finish an extrapolation started by ``begin``, using ``output_scale``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def extract(self, state: R, aux: S, /) -> T:
        """Read a solution back out of ``(state, aux)``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale):
        """Interpolate between a state at ``t0`` and a marginal at ``t1``.

        ``dt0`` and ``dt1`` are the distances from the target time back to
        ``t0`` and forward to ``t1``. ``Strategy`` unpacks the result as
        ``(step_from, solution, interp_from)``.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def right_corner(self, state: R, aux: S, /) -> _interp.InterpRes[tuple[R, S]]:
        """Process the state at the checkpoint t = t_n."""
        raise NotImplementedError()

begin(state: R, aux: S, /, dt) -> tuple[R, S]

Begin the extrapolation.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 27-30
@abc.abstractmethod
def begin(self, state: R, aux: S, /, dt) -> tuple[R, S]:
    """Start extrapolating ``(state, aux)`` across a step of length ``dt``."""
    raise NotImplementedError()

complete(state: R, aux: S, /, output_scale) -> tuple[R, S]

Complete the extrapolation.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 32-35
@abc.abstractmethod
def complete(self, state: R, aux: S, /, output_scale) -> tuple[R, S]:
    """Finish an extrapolation started by ``begin``, using ``output_scale``."""
    raise NotImplementedError()

extract(state: R, aux: S, /) -> T

Extract a solution from a state.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 37-40
@abc.abstractmethod
def extract(self, state: R, aux: S, /) -> T:
    """Read a solution back out of ``(state, aux)``."""
    raise NotImplementedError()

init(solution: T, /) -> tuple[R, S]

Initialise a state from a solution.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 22-25
@abc.abstractmethod
def init(self, solution: T, /) -> tuple[R, S]:
    """Turn ``solution`` into a ``(state, aux)`` pair."""
    raise NotImplementedError()

initial_condition(tcoeffs, /) -> T

Compute an initial condition from a set of Taylor coefficients.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 17-20
@abc.abstractmethod
def initial_condition(self, tcoeffs, /) -> T:
    """Build an initial condition from the Taylor coefficients ``tcoeffs``."""
    raise NotImplementedError()

interpolate(state_t0, marginal_t1, *, dt0, dt1, output_scale)

Interpolate.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 42-45
@abc.abstractmethod
def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale):
    """Interpolate between a state at ``t0`` and a marginal at ``t1``.

    ``dt0`` and ``dt1`` are the distances from the target time back to ``t0``
    and forward to ``t1``. ``Strategy`` unpacks the result as
    ``(step_from, solution, interp_from)``.
    """
    raise NotImplementedError()

right_corner(state: R, aux: S, /) -> _interp.InterpRes[tuple[R, S]]

Process the state at checkpoint t=t_n.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 47-50
@abc.abstractmethod
def right_corner(self, state: R, aux: S, /) -> _interp.InterpRes[tuple[R, S]]:
    """Process the state at the checkpoint t = t_n."""
    raise NotImplementedError()

Strategy

Estimation strategy.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 60-177
class Strategy:
    """Estimation strategy.

    A strategy pairs an extrapolation (an ``ExtrapolationImpl``) with a
    correction. It threads a ``_State`` through both: the time ``t``, the
    ``hidden`` state, and each component's auxiliary data (``aux_extra`` and
    ``aux_corr``). The ``is_suitable_for_*`` flags are stored as metadata. Of
    these, only ``is_suitable_for_offgrid_marginals`` is consulted in this
    class, by ``offgrid_marginals``.
    """

    def __init__(
        self,
        extrapolation: ExtrapolationImpl,
        correction,
        *,
        string_repr,
        is_suitable_for_save_at,
        is_suitable_for_save_every_step,
        is_suitable_for_offgrid_marginals,
    ):
        # Content
        self.extrapolation = extrapolation
        self.correction = correction

        # Some meta-information
        self.string_repr = string_repr
        self.is_suitable_for_save_at = is_suitable_for_save_at
        self.is_suitable_for_save_every_step = is_suitable_for_save_every_step
        self.is_suitable_for_offgrid_marginals = is_suitable_for_offgrid_marginals

    def __repr__(self):
        # The representation is supplied by whoever constructs the strategy.
        return self.string_repr

    def initial_condition(self, taylor_coefficients, /):
        """Construct an initial condition from a set of Taylor coefficients.

        Delegates directly to ``self.extrapolation.initial_condition``.
        """
        return self.extrapolation.initial_condition(taylor_coefficients)

    def init(self, t, posterior, /) -> _State:
        """Initialise a state at time ``t`` from a posterior.

        The extrapolation initialises first. Its hidden output is passed to the
        correction, and both components' auxiliary data end up in the state.
        """
        rv, extra = self.extrapolation.init(posterior)
        rv, corr = self.correction.init(rv)
        return _State(t=t, hidden=rv, aux_extra=extra, aux_corr=corr)

    def predict_error(self, state: _State, /, *, dt, vector_field):
        """Predict the error of an upcoming step of size ``dt``.

        Returns:
            ``(error, observed, state)``. The returned state is stamped with
            ``state.t + dt`` and holds the begun, not yet completed,
            extrapolation together with the correction's updated auxiliary data.
        """
        hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
        t = state.t + dt
        error, observed, corr = self.correction.estimate_error(
            hidden, state.aux_corr, vector_field=vector_field, t=t
        )
        state = _State(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr)
        return error, observed, state

    def complete(self, state, /, *, output_scale):
        """Complete the step after the error has been predicted.

        The extrapolation completes with ``output_scale`` first, then the
        correction completes on its result. The time stamp is kept unchanged.
        """
        hidden, extra = self.extrapolation.complete(
            state.hidden, state.aux_extra, output_scale=output_scale
        )
        hidden, corr = self.correction.complete(hidden, state.aux_corr)
        return _State(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr)

    def extract(self, state: _State, /):
        """Extract ``(t, solution)`` from a state.

        The correction extracts from the hidden state first. The extrapolation
        then turns that result into the solution.
        """
        hidden = self.correction.extract(state.hidden, state.aux_corr)
        sol = self.extrapolation.extract(hidden, state.aux_extra)
        return state.t, sol

    def case_right_corner(self, state_t1: _State) -> _interp.InterpRes:
        """Process the solution in case t=t_n.

        Each of the three ``(hidden, aux_extra)`` pairs returned by
        ``extrapolation.right_corner`` is wrapped into a ``_State`` at time
        ``state_t1.t``.
        """
        _tmp = self.extrapolation.right_corner(state_t1.hidden, state_t1.aux_extra)
        step_from, solution, interp_from = _tmp

        def _state(x):
            t = state_t1.t
            # The correction slot holds np.empty_like placeholders with the
            # same tree structure as state_t1.aux_corr.
            corr_like = tree_util.tree_map(np.empty_like, state_t1.aux_corr)
            return _State(t=t, hidden=x[0], aux_extra=x[1], aux_corr=corr_like)

        step_from = _state(step_from)
        solution = _state(solution)
        interp_from = _state(interp_from)
        return _interp.InterpRes(step_from, solution, interp_from)

    def case_interpolate(
        self, t, *, s0: _State, s1: _State, output_scale
    ) -> _interp.InterpRes[_State]:
        """Process the solution in case t>t_n.

        The extrapolation interpolates from ``s0`` towards the hidden state of
        ``s1``. ``step_from`` is stamped with ``s1.t``; ``solution`` and
        ``interp_from`` are stamped with ``t``.
        """
        # Interpolate
        step_from, solution, interp_from = self.extrapolation.interpolate(
            state_t0=(s0.hidden, s0.aux_extra),
            marginal_t1=s1.hidden,
            dt0=t - s0.t,
            dt1=s1.t - t,
            output_scale=output_scale,
        )

        # Turn outputs into valid states

        def _state(t_, x):
            # Placeholder correction data shaped like s0.aux_corr.
            corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr)
            return _State(t=t_, hidden=x[0], aux_extra=x[1], aux_corr=corr_like)

        step_from = _state(s1.t, step_from)
        solution = _state(t, solution)
        interp_from = _state(t, interp_from)
        return _interp.InterpRes(step_from, solution, interp_from)

    def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale):
        """Compute marginals at an off-grid time ``t`` between ``t0`` and ``t1``.

        The strategy is re-initialised from ``posterior_t0`` at ``t0`` and
        interpolated towards ``marginals_t1``. Only the middle (solution)
        entry of the interpolation result is used.

        Returns:
            ``(u, marginals)`` with ``u = impl.hidden_model.qoi(marginals)``.

        Raises:
            NotImplementedError: if ``is_suitable_for_offgrid_marginals`` is
                false for this strategy.
        """
        if not self.is_suitable_for_offgrid_marginals:
            # Name the offending strategy instead of raising a bare exception.
            msg = f"The strategy {self!r} does not support offgrid marginals."
            raise NotImplementedError(msg)

        dt0 = t - t0
        dt1 = t1 - t
        state_t0 = self.init(t0, posterior_t0)

        _acc, (marginals, _aux), _prev = self.extrapolation.interpolate(
            state_t0=(state_t0.hidden, state_t0.aux_extra),
            marginal_t1=marginals_t1,
            dt0=dt0,
            dt1=dt1,
            output_scale=output_scale,
        )

        u = impl.hidden_model.qoi(marginals)
        return u, marginals

case_interpolate(t, *, s0: _State, s1: _State, output_scale) -> _interp.InterpRes[_State]

Process the solution in case t>t_n.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 135-157
def case_interpolate(
    self, t, *, s0: _State, s1: _State, output_scale
) -> _interp.InterpRes[_State]:
    """Handle a requested time ``t`` past the checkpoint (t > t_n).

    The extrapolation interpolates from ``s0`` towards the hidden state of
    ``s1``. Each of the three resulting ``(hidden, aux_extra)`` pairs is
    wrapped into a ``_State``: ``step_from`` at ``s1.t``, the other two at
    ``t``.
    """
    left_span = t - s0.t
    right_span = s1.t - t
    interpolated = self.extrapolation.interpolate(
        state_t0=(s0.hidden, s0.aux_extra),
        marginal_t1=s1.hidden,
        dt0=left_span,
        dt1=right_span,
        output_scale=output_scale,
    )
    step_from, solution, interp_from = interpolated

    def as_state(time, pair):
        # Fill the correction slot with np.empty_like placeholders that share
        # the tree structure of s0.aux_corr.
        placeholder = tree_util.tree_map(np.empty_like, s0.aux_corr)
        return _State(
            t=time, hidden=pair[0], aux_extra=pair[1], aux_corr=placeholder
        )

    return _interp.InterpRes(
        as_state(s1.t, step_from),
        as_state(t, solution),
        as_state(t, interp_from),
    )

case_right_corner(state_t1: _State) -> _interp.InterpRes

Process the solution in case t=t_n.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 120-133
def case_right_corner(self, state_t1: _State) -> _interp.InterpRes:
    """Handle a requested time that coincides with the checkpoint (t = t_n).

    ``extrapolation.right_corner`` yields three ``(hidden, aux_extra)`` pairs.
    Each is wrapped into a ``_State`` stamped with ``state_t1.t``.
    """
    corner = self.extrapolation.right_corner(state_t1.hidden, state_t1.aux_extra)
    step_from, solution, interp_from = corner

    def as_state(pair):
        # Placeholder correction data with the tree structure of state_t1.aux_corr.
        placeholder = tree_util.tree_map(np.empty_like, state_t1.aux_corr)
        return _State(
            t=state_t1.t, hidden=pair[0], aux_extra=pair[1], aux_corr=placeholder
        )

    return _interp.InterpRes(
        as_state(step_from),
        as_state(solution),
        as_state(interp_from),
    )

complete(state, /, *, output_scale)

Complete the step after the error has been predicted.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 106-112
def complete(self, state, /, *, output_scale):
    """Finish a step whose error has already been predicted.

    The extrapolation completes with ``output_scale`` first, then the
    correction completes on the extrapolated hidden state. The time stamp is
    carried over from ``state`` unchanged.
    """
    extrapolated, aux_extra = self.extrapolation.complete(
        state.hidden, state.aux_extra, output_scale=output_scale
    )
    corrected, aux_corr = self.correction.complete(extrapolated, state.aux_corr)
    return _State(
        t=state.t, hidden=corrected, aux_extra=aux_extra, aux_corr=aux_corr
    )

extract(state: _State, /)

Extract the solution from a state.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 114-118
def extract(self, state: _State, /):
    """Return ``(t, solution)`` read off ``state``.

    The correction extracts from the hidden state first. The extrapolation
    then turns that result into the solution.
    """
    corrected = self.correction.extract(state.hidden, state.aux_corr)
    solution = self.extrapolation.extract(corrected, state.aux_extra)
    return (state.t, solution)

init(t, posterior, /) -> _State

Initialise a state from a posterior.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 90-94
def init(self, t, posterior, /) -> _State:
    """Build a solver state at time ``t`` from ``posterior``.

    The extrapolation initialises first. Its hidden output is handed to the
    correction, and both components' auxiliary data are stored in the state.
    """
    hidden, aux_extra = self.extrapolation.init(posterior)
    hidden, aux_corr = self.correction.init(hidden)
    return _State(t=t, hidden=hidden, aux_extra=aux_extra, aux_corr=aux_corr)

initial_condition(taylor_coefficients, /)

Construct an initial condition from a set of Taylor coefficients.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 86-88
def initial_condition(self, taylor_coefficients, /):
    """Turn Taylor coefficients into an initial condition.

    The work is delegated entirely to the extrapolation.
    """
    extrapolation = self.extrapolation
    return extrapolation.initial_condition(taylor_coefficients)

offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale)

Compute offgrid_marginals.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 159-177
def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale):
    """Compute marginals at an off-grid time ``t`` between ``t0`` and ``t1``.

    The strategy is re-initialised from ``posterior_t0`` at ``t0`` and
    interpolated towards ``marginals_t1``. Only the middle (solution) entry of
    the interpolation result is used.

    Returns:
        ``(u, marginals)`` with ``u = impl.hidden_model.qoi(marginals)``.

    Raises:
        NotImplementedError: if ``is_suitable_for_offgrid_marginals`` is false
            for this strategy.
    """
    if not self.is_suitable_for_offgrid_marginals:
        # Name the offending strategy instead of raising a bare exception.
        msg = f"The strategy {self!r} does not support offgrid marginals."
        raise NotImplementedError(msg)

    dt0 = t - t0
    dt1 = t1 - t
    state_t0 = self.init(t0, posterior_t0)

    _acc, (marginals, _aux), _prev = self.extrapolation.interpolate(
        state_t0=(state_t0.hidden, state_t0.aux_extra),
        marginal_t1=marginals_t1,
        dt0=dt0,
        dt1=dt1,
        output_scale=output_scale,
    )

    u = impl.hidden_model.qoi(marginals)
    return u, marginals

predict_error(state: _State, /, *, dt, vector_field)

Predict the error of an upcoming step.

Source code in probdiffeq/solvers/strategies/strategy.py, lines 96-104
def predict_error(self, state: _State, /, *, dt, vector_field):
    """Begin a step of size ``dt`` and estimate its error.

    Returns ``(error, observed, new_state)``. ``new_state`` is stamped with
    ``state.t + dt`` and holds the begun extrapolation together with the
    correction's updated auxiliary data.
    """
    hidden, aux_extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
    t_next = state.t + dt
    error, observed, aux_corr = self.correction.estimate_error(
        hidden, state.aux_corr, vector_field=vector_field, t=t_next
    )
    new_state = _State(
        t=t_next, hidden=hidden, aux_extra=aux_extra, aux_corr=aux_corr
    )
    return error, observed, new_state