matfree.stochtrace

Stochastic estimation of traces, diagonals, and more.

matfree.stochtrace.estimator(integrand: Callable, /, sampler: Callable) -> Callable

Construct a stochastic trace-/diagonal-estimator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrand | Callable | The integrand function, for example the return value of integrand_trace(). Any other integrand works, too. | required |
| sampler | Callable | The sampler function, usually the return value of sampler_normal or sampler_rademacher. | required |

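Returns:

| Name | Description |
| --- | --- |
| estimate | A function estimate(matvec, key, *parameters) that maps a matrix-vector product, a random key, and any further matvec parameters to an estimate. This function can be compiled, vectorised, differentiated, or looped over as desired. |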

Source code in matfree/stochtrace.py
def estimator(integrand: Callable, /, sampler: Callable) -> Callable:
    """Construct a stochastic trace-/diagonal-estimator.

    Parameters
    ----------
    integrand
        The integrand function. For example, the return-value of
        [integrand_trace][matfree.stochtrace.integrand_trace].
        But any other integrand works, too.
    sampler
        The sample function. Usually, either
        [sampler_normal][matfree.stochtrace.sampler_normal] or
        [sampler_rademacher][matfree.stochtrace.sampler_rademacher].

    Returns
    -------
    estimate
        A function that maps a matvec, a random key, and any further
        matvec parameters to an estimate.
        This function can be compiled, vectorised, differentiated,
        or looped over as the user desires.

    """

    def estimate(matvecs, key, *parameters):
        samples = sampler(key)
        Qs = func.vmap(lambda vec: integrand(matvecs, vec, *parameters))(samples)
        return tree.tree_map(lambda s: np.mean(s, axis=0), Qs)

    return estimate
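To illustrate the claim that the returned function can be compiled, vectorised, and differentiated, here is a minimal standalone sketch in JAX that mirrors the source above without importing matfree. It assumes jax is installed; the helper sampler_rademacher_flat and the parametrised diagonal matvec are hypothetical stand-ins chosen so that the result is easy to check.

```python
import jax
import jax.numpy as jnp


def estimator(integrand, /, sampler):
    # Mirrors the source above: draw samples, evaluate the integrand on each
    # sample with vmap, and average every output leaf over the sample axis.
    def estimate(matvec, key, *parameters):
        samples = sampler(key)
        Qs = jax.vmap(lambda vec: integrand(matvec, vec, *parameters))(samples)
        return jax.tree_util.tree_map(lambda s: jnp.mean(s, axis=0), Qs)

    return estimate


def integrand_trace(matvec, v, *parameters):
    # Hutchinson integrand v^T A v for a flat (1-D) sample v.
    return jnp.dot(v, matvec(v, *parameters))


def sampler_rademacher_flat(n, num):
    # Hypothetical helper: maps a key to an array of shape (num, n) with +-1 entries.
    def sample(key):
        return jax.random.rademacher(key, (num, n), dtype=jnp.float32)

    return sample


def matvec(x, scale):
    # Parametrised diagonal matrix A(scale) = scale * diag(1, 2, 3, 4).
    return scale * jnp.arange(1.0, 5.0) * x


estimate = estimator(integrand_trace, sampler=sampler_rademacher_flat(4, num=8))

# Close over matvec so that jit only sees array arguments (key, scale).
estimate_trace = jax.jit(lambda key, scale: estimate(matvec, key, scale))

key = jax.random.PRNGKey(0)
trace = estimate_trace(key, 2.0)

# Vectorise over several keys, and differentiate with respect to the parameter.
keys = jax.random.split(key, 5)
traces = jax.vmap(estimate_trace, in_axes=(0, None))(keys, 2.0)
dtrace = jax.grad(estimate_trace, argnums=1)(key, 2.0)
```

For a diagonal matrix and Rademacher samples every term v^T A v equals the trace, so trace is 20.0 (= 2 * (1 + 2 + 3 + 4)) and dtrace is 10.0; for a general matrix each call returns a noisy estimate instead.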

matfree.stochtrace.integrand_diagonal()

Construct the integrand for estimating the diagonal.

When plugged into the Monte Carlo estimator, the result will be an Array or PyTree of Arrays with the same tree structure as matvec(*args_like), where *args_like are the arguments passed to the sampler.

Source code in matfree/stochtrace.py
def integrand_diagonal():
    """Construct the integrand for estimating the diagonal.

    When plugged into the Monte-Carlo estimator,
    the result will be an Array or PyTree of Arrays with the
    same tree-structure as ``matvec(*args_like)``,
    where ``*args_like`` are the arguments of the sampler.
    """

    def integrand(matvec, v, *parameters):
        Qv = matvec(v, *parameters)
        v_flat, unflatten = tree.ravel_pytree(v)
        Qv_flat, _unflatten = tree.ravel_pytree(Qv)
        return unflatten(v_flat * Qv_flat)

    return integrand
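The diagonal integrand returns v ⊙ (Av); its expectation is diag(A) because the i-th entry is Σ_j A_ij E[v_i v_j] = A_ii whenever E[v v^T] = I. A minimal standalone JAX sketch of this estimate follows; it reimplements the integrand with jax.flatten_util.ravel_pytree in place of matfree's tree backend, and the test matrix is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def integrand_diagonal(matvec, v, *parameters):
    # Mirrors the source above: elementwise product v * (A v), computed on the
    # flattened arrays and reshaped back to the (PyTree) structure of v.
    Qv = matvec(v, *parameters)
    v_flat, unflatten = ravel_pytree(v)
    Qv_flat, _ = ravel_pytree(Qv)
    return unflatten(v_flat * Qv_flat)


# An arbitrary symmetric test matrix, chosen only for illustration.
A = jnp.array([[4.0, 1.0, 0.5], [1.0, 3.0, 0.2], [0.5, 0.2, 2.0]])


def matvec(x):
    return A @ x


# Average the integrand over many Rademacher samples (Monte Carlo estimate).
key = jax.random.PRNGKey(0)
samples = jax.random.rademacher(key, (20_000, 3), dtype=jnp.float32)
values = jax.vmap(lambda v: integrand_diagonal(matvec, v))(samples)
diag_estimate = jnp.mean(values, axis=0)
```

With this many samples, diag_estimate should be close to jnp.diag(A) = [4, 3, 2]; the remaining error comes only from the off-diagonal entries of A.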

matfree.stochtrace.integrand_frobeniusnorm_squared()

Construct the integrand for estimating the squared Frobenius norm.

Source code in matfree/stochtrace.py
def integrand_frobeniusnorm_squared():
    """Construct the integrand for estimating the squared Frobenius norm."""

    def integrand(matvec, vec, *parameters):
        x = matvec(vec, *parameters)
        v_flat, unflatten = tree.ravel_pytree(x)
        return linalg.inner(v_flat, v_flat)

    return integrand
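This integrand returns ||Av||², whose expectation is E[v^T A^T A v] = tr(A^T A) = ||A||_F² for samples with E[v v^T] = I. Because only the matvec output is flattened, A may be rectangular, with the samples drawn from its input space. The standalone JAX sketch below checks this with a 3×2 matrix; it is a reimplementation for illustration, not the matfree code path.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def integrand_frobeniusnorm_squared(matvec, vec, *parameters):
    # Mirrors the source above: squared Euclidean norm of the flattened A v.
    x = matvec(vec, *parameters)
    x_flat, _ = ravel_pytree(x)
    return jnp.dot(x_flat, x_flat)


# A rectangular (3x2) matrix: samples live in R^2, outputs in R^3.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])


def matvec(x):
    return A @ x


key = jax.random.PRNGKey(1)
samples = jax.random.normal(key, (50_000, 2))
values = jax.vmap(lambda v: integrand_frobeniusnorm_squared(matvec, v))(samples)
estimate = jnp.mean(values)
exact = jnp.sum(A**2)  # ||A||_F^2 = 1 + 4 + 9 + 16 + 25 + 36 = 91
```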

matfree.stochtrace.integrand_trace()

Construct the integrand for estimating the trace.

Source code in matfree/stochtrace.py
def integrand_trace():
    """Construct the integrand for estimating the trace."""

    def integrand(matvec, v, *parameters):
        Qv = matvec(v, *parameters)
        v_flat, unflatten = tree.ravel_pytree(v)
        Qv_flat, _unflatten = tree.ravel_pytree(Qv)
        return linalg.inner(v_flat, Qv_flat)

    return integrand
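The trace integrand returns the quadratic form v^T (Av) on flattened arrays, and the estimator averages it over N samples. The standard argument for why this is unbiased holds for both standard-normal and Rademacher samples, since both satisfy E[v v^T] = I:

```latex
\mathbb{E}\left[v^\top A v\right]
  = \mathbb{E}\left[\operatorname{tr}\left(A v v^\top\right)\right]
  = \operatorname{tr}\left(A\,\mathbb{E}\left[v v^\top\right]\right)
  = \operatorname{tr}(A)
  \quad\text{whenever } \mathbb{E}\left[v v^\top\right] = I,
\qquad
\operatorname{tr}(A) \approx \frac{1}{N}\sum_{n=1}^{N} v_n^\top A v_n .
```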

matfree.stochtrace.integrand_trace_and_diagonal()

Construct the integrand for estimating the trace and diagonal jointly.

Source code in matfree/stochtrace.py
def integrand_trace_and_diagonal():
    """Construct the integrand for estimating the trace and diagonal jointly."""

    def integrand(matvec, v, *parameters):
        Qv = matvec(v, *parameters)
        v_flat, unflatten = tree.ravel_pytree(v)
        Qv_flat, _unflatten = tree.ravel_pytree(Qv)
        trace_form = linalg.inner(v_flat, Qv_flat)
        diagonal_form = unflatten(v_flat * Qv_flat)
        return {"trace": trace_form, "diagonal": diagonal_form}

    return integrand
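The joint integrand reuses a single matvec per sample for both quadratic forms and returns a dictionary; since a dictionary is a PyTree, the estimator's tree_map averages each entry separately. The standalone JAX sketch below reimplements this integrand and the averaging step for illustration; the diagonal test matrix is an assumption chosen so that both estimates are exact.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def integrand_trace_and_diagonal(matvec, v, *parameters):
    # One matvec per sample feeds both the trace and the diagonal form.
    Qv = matvec(v, *parameters)
    v_flat, unflatten = ravel_pytree(v)
    Qv_flat, _ = ravel_pytree(Qv)
    trace_form = jnp.dot(v_flat, Qv_flat)
    diagonal_form = unflatten(v_flat * Qv_flat)
    return {"trace": trace_form, "diagonal": diagonal_form}


d = jnp.array([1.0, -2.0, 5.0])


def matvec(x):
    # Diagonal matrix diag(d).
    return d * x


key = jax.random.PRNGKey(3)
samples = jax.random.rademacher(key, (16, 3), dtype=jnp.float32)
Qs = jax.vmap(lambda v: integrand_trace_and_diagonal(matvec, v))(samples)

# Average every leaf of the output dictionary over the sample axis.
result = jax.tree_util.tree_map(lambda s: jnp.mean(s, axis=0), Qs)
```

For a diagonal matrix and Rademacher samples both estimates are exact: result["trace"] equals 4.0 and result["diagonal"] equals d. Sample by sample, the entries of the diagonal form also sum to the trace form.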

matfree.stochtrace.integrand_wrap_moments(integrand, /, moments)

Wrap an integrand into another integrand that computes moments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrand |  | Any integrand function compatible with Hutchinson-style estimation. | required |
| moments |  | Any PyTree (tuples, lists, dictionaries) whose leaves are valid inputs to lambda m: x**m for an array x (usually of data type float, but that depends on the wrapped integrand). For example, moments=4, moments=(1,2), or moments={"a": 1, "b": 2}. | required |

Returns:

| Name | Description |
| --- | --- |
| integrand | An integrand function compatible with Hutchinson-style estimation whose output has a PyTree structure that mirrors the structure of the moments argument. |

Source code in matfree/stochtrace.py
def integrand_wrap_moments(integrand, /, moments):
    """Wrap an integrand into another integrand that computes moments.

    Parameters
    ----------
    integrand
        Any integrand function compatible with Hutchinson-style estimation.
    moments
        Any Pytree (tuples, lists, dictionaries) whose leaves are
        valid inputs to ``lambda m: x**m`` for an array ``x``,
        usually, with data-type ``float``
        (but that depends on the wrapped integrand).
        For example, ``moments=4``, ``moments=(1,2)``,
        or ``moments={"a": 1, "b": 2}``.

    Returns
    -------
    integrand
        An integrand function compatible with Hutchinson-style estimation whose
        output has a PyTree-structure that mirrors the structure of the ``moments``
        argument.

    """

    def integrand_wrapped(vec, *parameters):
        Qs = integrand(vec, *parameters)
        return tree.tree_map(moment_fun, Qs)

    def moment_fun(x, /):
        return tree.tree_map(lambda m: x**m, moments)

    return integrand_wrapped
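A typical use of the moment wrapper is to estimate the variance of an integrand alongside its mean: with moments 1 and 2, the averaged outputs are the first two sample moments, and their difference E[x²] − E[x]² is the sample variance. The standalone JAX sketch below mirrors the wrapper (forwarding all arguments with *args) and applies it to a trace integrand; the keys "first" and "second" and the test matrix are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def integrand_wrap_moments(integrand, /, moments):
    # Mirrors the source above: every output leaf x of the wrapped integrand
    # is replaced by a PyTree with the structure of `moments`, holding x**m.
    def moment_fun(x, /):
        return jax.tree_util.tree_map(lambda m: x**m, moments)

    def integrand_wrapped(*args):
        # Forward all arguments (matvec, sample, parameters) unchanged.
        return jax.tree_util.tree_map(moment_fun, integrand(*args))

    return integrand_wrapped


def integrand_trace(matvec, v):
    # Hutchinson integrand v^T A v.
    return jnp.dot(v, matvec(v))


A = jnp.array([[2.0, 1.0], [1.0, 3.0]])


def matvec(x):
    return A @ x


integrand = integrand_wrap_moments(integrand_trace, moments={"first": 1, "second": 2})

key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, (100_000, 2))
Qs = jax.vmap(lambda v: integrand(matvec, v))(samples)
m = jax.tree_util.tree_map(lambda s: jnp.mean(s, axis=0), Qs)

# Sample variance of the trace integrand from its first two moments.
variance = m["second"] - m["first"] ** 2
```

For standard-normal samples and this symmetric A, m["first"] should be close to tr(A) = 5 and variance close to 2 ||A||_F² = 30.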

matfree.stochtrace.sampler_normal(*args_like, num)

Construct a function that samples from a standard-normal distribution.

Source code in matfree/stochtrace.py
def sampler_normal(*args_like, num):
    """Construct a function that samples from a standard-normal distribution."""
    return _sampler_from_jax_random(prng.normal, *args_like, num=num)
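The helper _sampler_from_jax_random is not shown on this page, so the following is only a guess at its behaviour: a minimal sketch, assuming a sampler flattens an example input, draws num flat samples, and unflattens each one back to the input's PyTree structure. The name sampler_normal_sketch is hypothetical, and it takes a single example input rather than *args_like to keep the structure unambiguous.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def sampler_normal_sketch(x_like, *, num):
    # Hypothetical stand-in for sampler_normal: flatten the example input once,
    # then map a key to `num` standard-normal samples with x_like's structure.
    x_flat, unflatten = ravel_pytree(x_like)

    def sample(key):
        flat = jax.random.normal(key, (num, *x_flat.shape), dtype=x_flat.dtype)
        # Unflatten each row, giving a PyTree whose leaves have a leading num axis.
        return jax.vmap(unflatten)(flat)

    return sample


x_like = {"a": jnp.zeros((2,)), "b": jnp.zeros((3, 3))}
sample = sampler_normal_sketch(x_like, num=5)
samples = sample(jax.random.PRNGKey(0))
```

Here samples is a dictionary with samples["a"] of shape (5, 2) and samples["b"] of shape (5, 3, 3), which is the layout the estimator's vmap expects.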

matfree.stochtrace.sampler_rademacher(*args_like, num)

Construct a function that samples from a Rademacher distribution.

Source code in matfree/stochtrace.py
def sampler_rademacher(*args_like, num):
    """Construct a function that samples from a Rademacher distribution."""
    return _sampler_from_jax_random(prng.rademacher, *args_like, num=num)
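Rademacher entries are ±1, so v_i² = 1 and the diagonal of A contributes no variance to v^T A v. For symmetric A the variance is 2 Σ_{i≠j} A_ij² with Rademacher samples, compared with 2 ||A||_F² with standard-normal samples. The sketch below compares the two empirically using jax.random directly rather than matfree's samplers; the 2×2 matrix is an arbitrary example.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])


def quadratic_form(v):
    # Hutchinson integrand v^T A v.
    return v @ (A @ v)


key_n, key_r = jax.random.split(jax.random.PRNGKey(0))
num = 100_000
normal = jax.random.normal(key_n, (num, 2))
rademacher = jax.random.rademacher(key_r, (num, 2), dtype=jnp.float32)

# Expected: about 2 * ||A||_F^2 = 30 for normal samples.
var_normal = jnp.var(jax.vmap(quadratic_form)(normal))
# Expected: about 2 * (||A||_F^2 - sum_i A_ii^2) = 4 for Rademacher samples.
var_rademacher = jnp.var(jax.vmap(quadratic_form)(rademacher))
```

With Rademacher samples each term equals 5 ± 2, so the variance is close to 4; this lower variance is one reason Rademacher sampling is a common default for Hutchinson-type estimators.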