matfree.stochtrace

Stochastic estimation of traces, diagonals, and more.

matfree.stochtrace.estimator_leave_one_out(integrand: Callable, /, sampler: Callable) -> Callable

Construct a leave-one-out stochastic estimator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrand | Callable | An integrand that accepts (matvec, samples, *parameters) where samples has shape (num, n). For example, the return-value of leave_one_out_xtrace. | required |
| sampler | Callable | The sample function, e.g. the return-value of sampler_normal or sampler_signs. | required |

Returns:

| Name | Description |
| --- | --- |
| estimate | A function estimate(matvec, key, *parameters) -> result. This function can be compiled, vectorised, differentiated, or looped over as the user desires. |

Source code in matfree/stochtrace.py
def estimator_leave_one_out(integrand: Callable, /, sampler: Callable) -> Callable:
    """Construct a leave-one-out stochastic estimator.

    Parameters
    ----------
    integrand
        An integrand that accepts ``(matvec, samples, *parameters)`` where
        ``samples`` has shape ``(num, n)``. For example, the return-value of
        [leave_one_out_xtrace][matfree.stochtrace.leave_one_out_xtrace].
    sampler
        The sample function, e.g. the return-value of
        [sampler_normal][matfree.stochtrace.sampler_normal] or
        [sampler_signs][matfree.stochtrace.sampler_signs].

    Returns
    -------
    estimate
        A function ``estimate(matvec, key, *parameters) -> result``.
        This function can be compiled, vectorised, differentiated,
        or looped over as the user desires.

    """

    def estimate(matvec, key, *parameters):
        samples = sampler(key)
        return np.mean(integrand(matvec, samples, *parameters), axis=0)

    return estimate
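To make the wrapper's contract concrete, here is a NumPy-only sketch of the same closure pattern. It is an illustration under stated assumptions, not matfree API: `make_loo_estimator`, `sampler_signs_np`, and `hutchinson_loo` are hypothetical names, an integer seed replaces the JAX key, and samples are plain `(num, n)` arrays rather than PyTrees. The toy integrand returns, for each i, the Hutchinson estimate that leaves out sample i.

```python
import numpy as np


def make_loo_estimator(integrand, sampler):
    """Hypothetical NumPy analogue of estimator_leave_one_out."""

    def estimate(matvec, seed, *parameters):
        samples = sampler(seed)  # array of shape (num, n)
        # One estimate per left-out sample; average them over axis 0.
        return np.mean(integrand(matvec, samples, *parameters), axis=0)

    return estimate


def sampler_signs_np(n, num):
    """Hypothetical Rademacher sampler mapping an integer seed to (num, n) samples."""

    def sample(seed):
        rng = np.random.default_rng(seed)
        return rng.choice([-1.0, 1.0], size=(num, n))

    return sample


def hutchinson_loo(matvec, samples):
    """Toy LOO integrand: entry i averages v_j^T A v_j over all j != i."""
    quad = np.array([v @ matvec(v) for v in samples])
    return (quad.sum() - quad) / (quad.shape[0] - 1)


A = np.random.default_rng(0).normal(size=(20, 20))


def matvec(v):
    return A @ v


estimate = make_loo_estimator(hutchinson_loo, sampler_signs_np(n=20, num=50))
print(estimate(matvec, 1), np.trace(A))
```

For this toy integrand, averaging the leave-one-out means reproduces the ordinary Hutchinson mean exactly; the wrapper only becomes interesting with integrands such as leave_one_out_xtrace, whose per-sample estimates are not simple averages.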

matfree.stochtrace.estimator_leave_one_out_mean_and_sem(integrand: Callable, /, sampler: Callable) -> Callable

Construct a LOO estimator that returns mean and standard error.

Like estimator_leave_one_out, but returns (mean, sem) where sem = std(loo_estimates) / sqrt(num_samples). The LOO integrand produces one estimate per left-out sample, so the spread of these estimates is a natural uncertainty measure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrand | Callable | Any integrand compatible with estimator_leave_one_out. | required |
| sampler | Callable | The sample function. | required |

Returns:

| Name | Description |
| --- | --- |
| estimate | A function that returns (mean, sem), both scalars (or arrays with the same shape as a single LOO estimate). |

Source code in matfree/stochtrace.py
def estimator_leave_one_out_mean_and_sem(
    integrand: Callable, /, sampler: Callable
) -> Callable:
    """Construct a LOO estimator that returns mean and standard error.

    Like [estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out],
    but returns ``(mean, sem)`` where ``sem = std(loo_estimates) / sqrt(num_samples)``.
    The LOO integrand produces one estimate per left-out sample, so the
    spread of these estimates is a natural uncertainty measure.

    Parameters
    ----------
    integrand
        Any integrand compatible with
        [estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out].
    sampler
        The sample function.

    Returns
    -------
    estimate
        A function that returns ``(mean, sem)``, both scalars (or arrays
        with the same shape as a single LOO estimate).
    """

    def estimate(matvec, key, *parameters):
        samples = sampler(key)
        Qs = integrand(matvec, samples, *parameters)
        mean = np.mean(Qs, axis=0)
        sem = np.std(Qs, axis=0) / np.sqrt(Qs.shape[0])
        return mean, sem

    return estimate
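The `(mean, sem)` reduction itself is only a few lines. The NumPy sketch below applies it to a hypothetical vector of leave-one-out estimates; `loo_mean_and_sem` is an illustrative name, and the sketch assumes the backend's `std` follows NumPy's default `ddof=0` (divide by `num`, not `num - 1`).

```python
import numpy as np


def loo_mean_and_sem(loo_estimates):
    """Reduce per-sample LOO estimates (stacked on axis 0) to (mean, standard error)."""
    loo_estimates = np.asarray(loo_estimates)
    num = loo_estimates.shape[0]
    mean = np.mean(loo_estimates, axis=0)
    # np.std uses ddof=0 by default, i.e. it divides by num rather than num - 1.
    sem = np.std(loo_estimates, axis=0) / np.sqrt(num)
    return mean, sem


mean, sem = loo_mean_and_sem([9.0, 10.0, 11.0, 10.0])
print(mean, sem)  # mean is 10.0; sem is sqrt(0.5) / 2
```

Array-valued estimates (for example, one diagonal estimate per left-out sample) reduce the same way, giving a mean and sem with the shape of a single estimate.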

matfree.stochtrace.estimator_monte_carlo(integrand: Callable, /, sampler: Callable) -> Callable

Construct a stochastic trace-/diagonal-estimator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrand | Callable | An integrand function with signature integrand(matvec, vec, *parameters), where vec is a single sample vector (contrast with estimator_leave_one_out, which passes the full sample batch). Use any of the monte_carlo_* constructors, e.g. monte_carlo_trace, monte_carlo_diagonal, monte_carlo_frobeniusnorm_squared, or any of the monte_carlo_funm_* functions from matfree.funm. | required |
| sampler | Callable | The sample function. See below for recommendations. | required |

Returns:

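| Name | Description |
| --- | --- |
| estimate | A function estimate(matvec, key, *parameters) that draws samples from key and returns the average of the integrand over them. This function can be compiled, vectorised, differentiated, or looped over as the user desires. |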

Notes

The statistical efficiency of the estimator for a given sampler depends on properties of the operator, but some general advice applies. For an n-dimensional operator (see references):

  • n > O(100): use sampler_signs.
  • n < O(100): use sampler_signs if the operator is known to be diagonally dominant, or sampler_sphere otherwise.
  • If the operator is complex-valued, pass a complex dtype to the sampler to approximately double the efficiency.

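References
  • Epperly, E. (2023). Stochastic trace estimation.
  • Epperly, E. (2024). Don't use Gaussians in stochastic trace estimation.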
Source code in matfree/stochtrace.py
def estimator_monte_carlo(integrand: Callable, /, sampler: Callable) -> Callable:
    """Construct a stochastic trace-/diagonal-estimator.

    Parameters
    ----------
    integrand
        An integrand function with signature ``integrand(matvec, vec, *parameters)``,
        where ``vec`` is a single sample vector (contrast with
        [estimator_leave_one_out][matfree.stochtrace.estimator_leave_one_out],
        which passes the full sample batch).
        Use any of the ``monte_carlo_*`` constructors, e.g.
        [monte_carlo_trace][matfree.stochtrace.monte_carlo_trace],
        [monte_carlo_diagonal][matfree.stochtrace.monte_carlo_diagonal],
        [monte_carlo_frobeniusnorm_squared][matfree.stochtrace.monte_carlo_frobeniusnorm_squared],
        or any of the ``monte_carlo_funm_*`` functions from [matfree.funm][matfree.funm].
    sampler
        The sample function. See below for recommendations.

    Returns
    -------
    estimate
        A function ``estimate(matvec, key, *parameters)`` that draws samples
        from ``key`` and returns the average of the integrand over them.
        This function can be compiled, vectorised, differentiated,
        or looped over as the user desires.

    Notes
    -----
    The statistical efficiency of the estimator for a given sampler depends on properties
    of the operator, but we can provide some general advice. For an `n`-dimensional operator (see references):

    - `n > O(100)`, use [sampler_signs][matfree.stochtrace.sampler_signs].
    - `n < O(100)`, use [sampler_signs][matfree.stochtrace.sampler_signs] if the operator is known to be diagonally dominant or [sampler_sphere][matfree.stochtrace.sampler_sphere] otherwise.
    - If the operator is complex-valued, pass a complex dtype to the sampler to approximately double the efficiency.

    References
    ----------
    - Epperly, E. (2023). [Stochastic trace estimation](https://www.ethanepperly.com/index.php/2023/01/26/stochastic-trace-estimation/).
    - Epperly, E. (2024). [Don't use Gaussians in stochastic trace estimation](https://www.ethanepperly.com/index.php/2024/01/28/dont-use-gaussians-in-stochastic-trace-estimation/).
    """

    def estimate(matvecs, key, *parameters):
        samples = sampler(key)
        Qs = func.vmap(lambda vec: integrand(matvecs, vec, *parameters))(samples)
        return tree.tree_map(lambda s: np.mean(s, axis=0), Qs)

    return estimate
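The Notes recommend sign (Rademacher) vectors. One reason can be checked directly: for a diagonal operator, every Rademacher quadratic form v^T A v equals tr(A) exactly, so the estimator has zero variance, while Gaussian probes do not. The NumPy sketch below mirrors the inner loop of `estimate` (a Python loop replaces `func.vmap`); `quadratic_forms` and the inline samplers are hypothetical stand-ins, not matfree functions.

```python
import numpy as np


def quadratic_forms(matvec, samples):
    """Per-sample integrand values v^H A v (as in monte_carlo_trace), shape (num,)."""
    return np.array([np.vdot(v, matvec(v)) for v in samples])


rng = np.random.default_rng(0)
n, num = 50, 200
A = np.diag(rng.uniform(1.0, 2.0, size=n))  # diagonal test operator


def matvec(v):
    return A @ v


signs = rng.choice([-1.0, 1.0], size=(num, n))  # Rademacher probes
gauss = rng.normal(size=(num, n))  # Gaussian probes

q_signs = quadratic_forms(matvec, signs)
q_gauss = quadratic_forms(matvec, gauss)
print("trace:", np.trace(A))
print("signs:  mean", q_signs.mean(), "variance", q_signs.var())
print("normal: mean", q_gauss.mean(), "variance", q_gauss.var())
```

For a general real symmetric operator, sign vectors still have nonzero variance; what they remove is the contribution of the diagonal entries, which is why they pair well with diagonally dominant operators.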

matfree.stochtrace.estimator_monte_carlo_mean_and_sem(integrand: Callable, /, sampler: Callable) -> Callable

Construct a stochastic estimator that returns mean and standard error.

Like estimator_monte_carlo, but returns (mean, sem), where sem = std(values) / sqrt(num_samples) is the standard error of the mean over the per-sample integrand values, i.e. the uncertainty of the estimate itself. The number of samples is already encoded in the sampler, so the caller does not need to track it separately.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| integrand | Callable | Any integrand compatible with estimator_monte_carlo. | required |
| sampler | Callable | The sample function. See estimator_monte_carlo for recommendations. | required |

Returns:

| Name | Description |
| --- | --- |
| estimate | A function that returns (mean, sem), both with the same PyTree structure as the integrand output. |

Source code in matfree/stochtrace.py
def estimator_monte_carlo_mean_and_sem(
    integrand: Callable, /, sampler: Callable
) -> Callable:
    """Construct a stochastic estimator that returns mean and standard error.

    Like [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo],
    but returns ``(mean, sem)``, where ``sem = std(values) / sqrt(num_samples)``
    is the standard error of the mean over the per-sample integrand values,
    i.e. the uncertainty of the estimate itself.
    The number of samples is already encoded in the sampler,
    so the caller does not need to track it separately.

    Parameters
    ----------
    integrand
        Any integrand compatible with
        [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo].
    sampler
        The sample function. See [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo]
        for recommendations.

    Returns
    -------
    estimate
        A function that returns ``(mean, sem)``, both with the same
        PyTree structure as the integrand output.
    """

    def estimate(matvecs, key, *parameters):
        samples = sampler(key)
        Qs = func.vmap(lambda vec: integrand(matvecs, vec, *parameters))(samples)
        mean = tree.tree_map(lambda s: np.mean(s, axis=0), Qs)
        sem = tree.tree_map(lambda s: np.std(s, axis=0) / np.sqrt(s.shape[0]), Qs)
        return mean, sem

    return estimate
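Because the reduction is applied leaf by leaf with `tree.tree_map`, a PyTree-valued integrand (such as the output of monte_carlo_trace_and_diagonal) gets a PyTree-valued mean and standard error. A NumPy sketch in which a flat dict of stacked per-sample values stands in for the PyTree; `mean_and_sem_tree` is a hypothetical name, and nested trees would need a real tree-map instead of the dict comprehensions.

```python
import numpy as np


def mean_and_sem_tree(per_sample):
    """Leafwise (mean, sem) over axis 0 for a flat dict of stacked samples."""
    mean = {k: np.mean(s, axis=0) for k, s in per_sample.items()}
    sem = {k: np.std(s, axis=0) / np.sqrt(s.shape[0]) for k, s in per_sample.items()}
    return mean, sem


per_sample = {
    "trace": np.array([3.0, 5.0, 4.0, 4.0]),  # shape (num,)
    "diagonal": np.array([[1.0, 2.0], [1.0, 4.0], [1.0, 3.0], [1.0, 3.0]]),  # (num, n)
}
mean, sem = mean_and_sem_tree(per_sample)
# mean["diagonal"] has shape (2,); sem["diagonal"][0] is 0 because that column is constant.
print(mean, sem)
```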

matfree.stochtrace.leave_one_out_xnystrace(*, nystrom: Callable[[Callable, Array], tuple[Array, Array, Array]] | None = None, apply_resphering: bool = True, qr_r: Callable[[Array], Array] | None = None) -> Callable

Construct an integrand for estimating the trace of a positive semi-definite operator using the XNysTrace algorithm (Epperly et al. 2024).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| nystrom | Callable[[Callable, Array], tuple[Array, Array, Array]] \| None | A callable with signature (matvec_flat, Omega) -> (nystrom_left, downdate, shift), where Omega has shape (n, num_samples), nystrom_left and downdate have shape (n, num_samples), and shift is a scalar. nystrom_left @ nystrom_left.T.conj() approximates the operator (shifted by shift * I), and subtracting outer(downdate[:, i], downdate[:, i].conj()) approximates it leaving out the i-th column of Omega. Usually the return value of nystrom_shifted_cholesky or nystrom_eigh (default: nystrom_eigh). | None |
| apply_resphering | bool | If True (default), project test vectors onto the range of the residual matrix, reducing the variance of the trace estimate. Requires test vectors drawn from a rotationally invariant distribution (e.g. normal or sphere). See Epperly, 2025 for more details. | True |
| qr_r | Callable[[Array], Array] \| None | A callable that computes the R factor of a QR decomposition, used if apply_resphering is True. If not provided, linalg.qr_r is used. | None |

Returns:

| Name | Description |
| --- | --- |
| integrand | An integrand function compatible with estimator_leave_one_out whose input has the signature (matvec, samples, *params) and whose output is a vector of trace estimates with shape (samples.shape[0],). The matvec must be a positive semi-definite operator. That is, vdot(v, matvec(v)) is real and non-negative for all vectors v, and vdot(x, matvec(y)) = vdot(matvec(x), y) for all vectors x and y. |

References
  • Epperly EN, Tropp JA, Webber RJ (2024). XTrace: Making the most of every sample in stochastic trace estimation. SIAM J Matrix Anal A. 45.1: 1-23. doi: 10.1137/23M1548323 arXiv: 2301.07825
  • Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis. arXiv: 2512.15929
Source code in matfree/stochtrace.py
def leave_one_out_xnystrace(
    *,
    nystrom: Callable[[Callable, Array], tuple[Array, Array, Array]] | None = None,
    apply_resphering: bool = True,
    qr_r: Callable[[Array], Array] | None = None,
) -> Callable:
    """Construct an integrand for estimating the trace of a positive semi-definite operator using the XNysTrace algorithm (Epperly et al. 2024).

    Parameters
    ----------
    nystrom
        A callable with signature ``(matvec_flat, Omega) -> (nystrom_left, downdate, shift)``,
        where ``Omega`` has shape ``(n, num_samples)``, ``nystrom_left`` and ``downdate``
        have shape ``(n, num_samples)``, and ``shift`` is a scalar.
        ``nystrom_left @ nystrom_left.T.conj()`` approximates the operator (shifted by ``shift * I``),
        and subtracting ``outer(downdate[:, i], downdate[:, i].conj())``
        approximates it leaving out the ``i``-th column of ``Omega``.
        Usually the return value of [`nystrom_shifted_cholesky`][matfree.stochtrace.nystrom_shifted_cholesky]
        or [`nystrom_eigh`][matfree.stochtrace.nystrom_eigh] (default: `nystrom_eigh`).
    apply_resphering
        If ``True`` (default), project test vectors onto the range of the
        residual matrix, reducing the variance of the trace estimate.
        Requires test vectors drawn from a rotationally invariant distribution
        (e.g. normal or sphere). See Epperly, 2025 for more details.
    qr_r
        A callable that computes the R factor of a QR decomposition, used if `apply_resphering` is `True`.
        If not provided, `linalg.qr_r` is used.

    Returns
    -------
    integrand
        An integrand function compatible with `estimator_leave_one_out` whose input
        has the signature ``(matvec, samples, *params)`` and whose output is a vector
        of trace estimates with shape ``(samples.shape[0],)``.
        The `matvec` must be a positive semi-definite operator. That is,
        `vdot(v, matvec(v))` is real and non-negative for all vectors `v`,
        and `vdot(x, matvec(y)) = vdot(matvec(x), y)` for all vectors `x` and `y`.

    References
    ----------
    - Epperly EN, Tropp JA, Webber RJ (2024). XTrace: Making the most of every sample in stochastic trace estimation.
        SIAM J Matrix Anal A. 45.1: 1-23.
        doi: [10.1137/23M1548323](https://doi.org/10.1137/23M1548323)
        arXiv: [2301.07825](https://arxiv.org/abs/2301.07825)
    - Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis.
        arXiv: [2512.15929](https://arxiv.org/abs/2512.15929)
    """
    # NOTE: The paper and thesis use the shifted Nystrom approximation with a
    # Cholesky decomposition, but empirically, this is brittle and fails for
    # many low-rank operators. The eigh-based Nystrom approximation seems to be
    # more robust.
    if nystrom is None:
        nystrom = nystrom_eigh()

    # NOTE: The paper computes R via the QR decomposition, while for efficiency, the
    # thesis uses the upper Cholesky factor of the Gram matrix. We use the QR approach
    # because it may be less brittle and prone to NaNs than Cholesky.
    if qr_r is None:
        qr_r = linalg.qr_r

    def integrand(matvec, samples, *params):
        sample0 = tree.tree_map(lambda s: s[0], samples)
        _, unflatten = tree.ravel_pytree(sample0)

        Omega = func.vmap(lambda s: tree.ravel_pytree(s)[0])(samples).T
        n, num_samples = Omega.shape

        if num_samples > n:
            raise ValueError(_error_num_samples(num_samples, maxval=n, minval=1))

        def matvec_flat(v):
            return tree.ravel_pytree(matvec(unflatten(v), *params))[0]

        if num_samples == n:
            # It's faster and more accurate to compute the trace exactly and deterministically
            # when num_samples == n
            B_mat = _materialize_operator(matvec_flat, Omega[:, 0])
            trace_samples = np.ones((num_samples,), dtype=B_mat.dtype) * linalg.trace(
                B_mat
            )
            return trace_samples.real

        F, Z, shift = nystrom(matvec_flat, Omega)

        if apply_resphering:
            # Ensure T (the R factor) is square
            T = qr_r(Omega)[:num_samples, :num_samples]
            S = _qr_leave_one_out_factor(T)
            # Omega projected onto the subspace spanned by Q_i, i.e. Q from qr(Omega_{-i}) leaving out Omega[:, i]
            X = T - S * func.vmap(linalg.vdot, in_axes=1)(S, T)
            # residual is B - B_hat_{-i}, where B_hat_{-i} approximates B leaving out Omega[:, i]
            rank_residual = n - num_samples + 1
            # squared norm of each sample after projection to the subspace spanned by the residual
            sqnorm_Omega = np.sum(linalg.abs2(Omega), axis=0)
            sqnorm_X = np.sum(linalg.abs2(X), axis=0)
            sqnorm_samples_projected = sqnorm_Omega - sqnorm_X
            sqnorm_samples_projected = np.where(
                sqnorm_samples_projected == 0.0, 1.0, sqnorm_samples_projected
            )
            residual_scale = rank_residual / sqnorm_samples_projected
        else:
            residual_scale = 1.0

        # Compute the trace estimate, correcting for the shift returned by `nystrom`
        tr_B_hat = np.sum(linalg.abs2(F)) - shift * n
        tr_B_hat_loo = tr_B_hat - np.sum(linalg.abs2(Z), axis=0)
        tr_residual_loo = linalg.abs2(func.vmap(linalg.vdot, in_axes=1)(Z, Omega))
        return tr_B_hat_loo + residual_scale * tr_residual_loo

    return integrand
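For intuition, the leave-one-out estimates that `integrand` returns with `apply_resphering=False` can be written down by brute force: build the Nystrom approximation from every column of Omega except column i, take its trace, and add a one-sample Hutchinson estimate of the residual using the left-out column. The NumPy sketch below does this with m separate Nystrom approximations, whereas the library obtains them cheaply from the rank-one downdates returned by `nystrom`. `nystrom_np` and `xnystrace_loo_bruteforce` are hypothetical helpers; the sketch assumes a real symmetric positive semi-definite matrix and nonsingular Gram matrices, and omits resphering.

```python
import numpy as np


def nystrom_np(B, Omega):
    """Nystrom approximation B_hat = Y pinv(Omega^T Y) Y^T with Y = B Omega."""
    Y = B @ Omega
    return Y @ np.linalg.pinv(Omega.T @ Y, hermitian=True) @ Y.T


def xnystrace_loo_bruteforce(B, Omega):
    """LOO trace estimates without resphering, one per column of Omega."""
    n, m = Omega.shape
    estimates = np.empty(m)
    for i in range(m):
        B_hat_loo = nystrom_np(B, np.delete(Omega, i, axis=1))
        omega = Omega[:, i]
        # tr(B_hat_{-i}) plus a Hutchinson estimate of tr(B - B_hat_{-i})
        estimates[i] = np.trace(B_hat_loo) + omega @ (B - B_hat_loo) @ omega
    return estimates


rng = np.random.default_rng(0)
n, m = 40, 12
U, _ = np.linalg.qr(rng.normal(size=(n, n)))
B = (U * 0.5 ** np.arange(n)) @ U.T  # PSD with fast-decaying eigenvalues
Omega = rng.normal(size=(n, m))
loo = xnystrace_loo_bruteforce(B, Omega)
print(loo.mean(), np.trace(B))
```

Because the Nystrom approximation built from all of Omega reproduces omega_i^T B omega_i exactly, the residual term above equals the squared inner product of the i-th downdate vector with omega_i, which is the `tr_residual_loo` expression in the source.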

matfree.stochtrace.leave_one_out_xtrace(*, apply_resphering: bool = True) -> Callable

Construct an integrand for estimating the trace using the XTrace algorithm (Epperly et al. 2024).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| apply_resphering | bool | If True (default), project test vectors onto the range of the residual matrix, reducing the variance of the trace estimate. Requires test vectors drawn from a rotationally invariant distribution (e.g. normal or sphere). See Epperly, 2025 for more details. | True |

Returns:

| Name | Description |
| --- | --- |
| integrand | An integrand function compatible with estimator_leave_one_out whose input has the signature (matvec, samples, *params) and whose output is a vector of trace estimates with one estimate per sample. |

Notes

The number of samples must be less than or equal to the dimension of the operator. Additionally, the algorithm assumes that the samples are unique. For low-dimensional operators, samples generated from sampler_signs may violate this assumption, and it is recommended to use a different sampler instead.

References
  • Epperly EN, Tropp JA, Webber RJ (2024). XTrace: Making the most of every sample in stochastic trace estimation. SIAM J Matrix Anal A. 45.1: 1-23. doi: 10.1137/23M1548323 arXiv: 2301.07825
  • Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis. arXiv: 2512.15929
Source code in matfree/stochtrace.py
def leave_one_out_xtrace(*, apply_resphering: bool = True) -> Callable:
    """Construct an integrand for estimating the trace using the XTrace algorithm (Epperly et al. 2024).

    Parameters
    ----------
    apply_resphering
        If ``True`` (default), project test vectors onto the range of the
        residual matrix, reducing the variance of the trace estimate.
        Requires test vectors drawn from a rotationally invariant distribution
        (e.g. normal or sphere). See Epperly, 2025 for more details.

    Returns
    -------
    integrand
        An integrand function compatible with `estimator_leave_one_out` whose input
        has the signature ``(matvec, samples, *params)`` and whose output is a vector
        of trace estimates with one estimate per sample.

    Notes
    -----
    The number of samples must be less than or equal to the dimension of the operator.
    Additionally, the algorithm assumes that the samples are unique. For low-dimensional
    operators, samples generated from `sampler_signs` may violate this assumption, and
    it is recommended to use a different sampler instead.

    References
    ----------
    - Epperly EN, Tropp JA, Webber RJ (2024). XTrace: Making the most of every sample in stochastic trace estimation.
        SIAM J Matrix Anal A. 45.1: 1-23.
        doi: [10.1137/23M1548323](https://doi.org/10.1137/23M1548323)
        arXiv: [2301.07825](https://arxiv.org/abs/2301.07825)
    - Epperly EN (2025). Make the most of what you have: Resource-efficient randomized algorithms for matrix computations. PhD Thesis.
        arXiv: [2512.15929](https://arxiv.org/abs/2512.15929)
    """

    def integrand(matvec, samples, *params):
        sample0 = tree.tree_map(lambda s: s[0], samples)
        _, unflatten = tree.ravel_pytree(sample0)

        Omega = func.vmap(lambda s: tree.ravel_pytree(s)[0])(samples).T
        n, num_samples = Omega.shape

        if num_samples > n:
            raise ValueError(_error_num_samples(num_samples, maxval=n, minval=1))

        def matvec_flat(v):
            return tree.ravel_pytree(matvec(unflatten(v), *params))[0]

        if 2 * num_samples >= n:
            # It's faster, more accurate, and allocates less memory to compute the trace exactly
            # and deterministically on the materialized operator when num_samples >= n/2
            B_mat = _materialize_operator(matvec_flat, Omega[:, 0])
            return np.ones((num_samples,), dtype=B_mat.dtype) * linalg.trace(B_mat)

        matmat = func.vmap(matvec_flat, in_axes=-1, out_axes=-1)

        Y = matmat(Omega)
        Q, R = linalg.qr_reduced(Y)
        Z = matmat(Q)

        def _trace_exact():
            tr_B = linalg.vdot(Q, Z)
            return np.ones((num_samples,), dtype=R.dtype) * tr_B

        def _trace_estimate():
            S = _qr_leave_one_out_factor(R)

            Q_H = Q.T.conj()
            # tr(H) == tr(B_hat), where B_hat = Q @ Q.H @ B is a low-rank approximation to the operator B
            H = Q_H @ Z
            W = Q_H @ Omega
            T = Z.T.conj() @ Omega
            W_vd_S = func.vmap(linalg.vdot, in_axes=1)(W, S)
            # Omega projected onto the subspace spanned by Q_i, i.e. Q formed leaving out Omega[:, i]
            X = W - S * W_vd_S.conj()
            T_vd_X = func.vmap(linalg.vdot, in_axes=1)(T, X)
            X_vd_HX = func.vmap(linalg.vdot, in_axes=1)(X, H @ X)
            S_vd_R = func.vmap(linalg.vdot, in_axes=1)(S, R)

            if apply_resphering:
                # residual is B - B_hat_{-i}, where B_hat_{-i} approximates B leaving out Omega[:, i]
                rank_residual = n - num_samples + 1
                # squared norm of each sample after projection to the subspace spanned by the residual
                sqnorm_Omega = np.sum(linalg.abs2(Omega), axis=0)
                sqnorm_X = np.sum(linalg.abs2(X), axis=0)
                sqnorm_samples_projected = sqnorm_Omega - sqnorm_X
                has_zero_norm = sqnorm_samples_projected == 0.0
                sqnorm_samples_projected = np.where(
                    has_zero_norm, 1.0, sqnorm_samples_projected
                )
                residual_scale = rank_residual / sqnorm_samples_projected
            else:
                residual_scale = 1.0

            tr_B_hat = linalg.trace(H)
            # tr(B_hat) leaving out one sample
            tr_B_hat_loo = tr_B_hat - func.vmap(linalg.vdot, in_axes=1)(S, H @ S)
            # Hutchinson estimate of tr(B - B_hat_{-i}) using samples[i, :] as probes.
            tr_residual_loo = -T_vd_X + X_vd_HX + S_vd_R * W_vd_S
            return tr_B_hat_loo + residual_scale * tr_residual_loo

        Y_rank = np.sum(np.abs(linalg.diagonal(R)) > np.finfo_eps(R.dtype))

        # NOTE: If Y_rank < num_samples, then Y is rank-deficient because either:
        # 1. rank(B) < num_samples, and/or
        # 2. the samples are not unique.
        # This check assumes samples are unique, which can be violated for low n and Rademacher samples.
        # If rank(B) < num_samples, then B_hat=B_hat_{-i}=B, so the residual is zero and tr(B_hat)=tr(B).
        return control_flow.cond(Y_rank < num_samples, _trace_exact, _trace_estimate)

    return integrand
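Each XTrace leave-one-out estimate can likewise be written by brute force, following the leave-one-out formula of Epperly et al. (2024) as we read it: orthonormalize B Omega_{-i}, take the trace of the compressed operator, and add a Hutchinson estimate of the residual using the left-out column with two-sided projections. The NumPy sketch below omits resphering; `xtrace_loo_bruteforce` is a hypothetical name, `B` is a dense real matrix, and the library instead derives all m estimates from a single QR factorization.

```python
import numpy as np


def xtrace_loo_bruteforce(B, Omega):
    """LOO trace estimates without resphering, one per column of Omega."""
    n, m = Omega.shape
    estimates = np.empty(m)
    for i in range(m):
        # Orthonormal basis whose span contains the range of B @ Omega_{-i}
        Q, _ = np.linalg.qr(B @ np.delete(Omega, i, axis=1))
        omega = Omega[:, i]
        r = omega - Q @ (Q.T @ omega)  # (I - Q Q^T) omega
        # tr(Q^T B Q) + omega^T (I - Q Q^T) B (I - Q Q^T) omega
        estimates[i] = np.trace(Q.T @ B @ Q) + r @ B @ r
    return estimates


rng = np.random.default_rng(1)
n, m = 30, 8
B_low = rng.normal(size=(n, 3)) @ rng.normal(size=(3, n))  # rank 3, not symmetric
Omega = rng.normal(size=(n, m))
print(xtrace_loo_bruteforce(B_low, Omega), np.trace(B_low))
```

When rank(B) is at most num_samples - 1, every Q already spans the range of B, so the residual vanishes and each estimate equals tr(B) up to rounding. This is the situation the `_trace_exact` branch in the source short-circuits.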

matfree.stochtrace.monte_carlo_diagonal()

Construct the integrand for estimating the diagonal.

Use with estimator_monte_carlo. The result will be an Array or PyTree of Arrays with the same tree-structure as matvec(*args_like) where *args_like is an argument of the sampler.

Source code in matfree/stochtrace.py
def monte_carlo_diagonal():
    """Construct the integrand for estimating the diagonal.

    Use with [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo].
    The result will be an Array or PyTree of Arrays with the same tree-structure as
    ``matvec(*args_like)`` where ``*args_like`` is an argument of the sampler.
    """

    def integrand(matvec, v, *parameters):
        Qv = matvec(v, *parameters)
        v_flat, unflatten = tree.ravel_pytree(v)
        Qv_flat, _unflatten = tree.ravel_pytree(Qv)
        return unflatten(v_flat.conj() * Qv_flat)

    return integrand
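The diagonal integrand relies on E[conj(v) * (A v)] = diag(A) whenever E[v v^H] = I. With real sign vectors this can be checked exactly by averaging over all 2^n sign patterns instead of sampling. A NumPy sketch for a flat real vector (no PyTree flattening); `diagonal_integrand` is a hypothetical name.

```python
import itertools

import numpy as np


def diagonal_integrand(matvec, v):
    """Elementwise product conj(v) * (A v), as in monte_carlo_diagonal."""
    return v.conj() * matvec(v)


A = np.arange(16.0).reshape(4, 4)


def matvec(v):
    return A @ v


all_signs = np.array(list(itertools.product([-1.0, 1.0], repeat=4)))  # (16, 4)
exact_average = np.mean([diagonal_integrand(matvec, v) for v in all_signs], axis=0)
print(exact_average, np.diag(A))
```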

matfree.stochtrace.monte_carlo_frobeniusnorm_squared()

Construct the integrand for estimating the squared Frobenius norm.

Use with estimator_monte_carlo.

Source code in matfree/stochtrace.py
def monte_carlo_frobeniusnorm_squared():
    """Construct the integrand for estimating the squared Frobenius norm.

    Use with [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo].
    """

    def integrand(matvec, vec, *parameters):
        x = matvec(vec, *parameters)
        v_flat, _unflatten = tree.ravel_pytree(x)
        return linalg.inner(v_flat.conj(), v_flat)

    return integrand
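The squared Frobenius norm integrand uses E[||A v||^2] = tr(A^H A) = ||A||_F^2, which again holds whenever E[v v^H] = I. A NumPy check by exhaustive averaging over sign vectors; `frobenius_integrand` is a hypothetical name, and taking `.real` is a choice made here for a real-valued result.

```python
import itertools

import numpy as np


def frobenius_integrand(matvec, v):
    """Squared norm of A v for a flat vector v (real part of vdot(x, x))."""
    x = matvec(v)
    return np.vdot(x, x).real


A = np.array([[1.0, 2.0, 0.0], [0.0, -1.0, 3.0], [4.0, 0.0, 1.0]])  # ||A||_F^2 = 32


def matvec(v):
    return A @ v


all_signs = np.array(list(itertools.product([-1.0, 1.0], repeat=3)))
average = np.mean([frobenius_integrand(matvec, v) for v in all_signs])
print(average, np.linalg.norm(A, "fro") ** 2)
```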

matfree.stochtrace.monte_carlo_trace()

Construct the integrand for estimating the trace.

Use with estimator_monte_carlo.

Source code in matfree/stochtrace.py
def monte_carlo_trace():
    """Construct the integrand for estimating the trace.

    Use with [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo].
    """

    def integrand(matvec, v, *parameters):
        Qv = matvec(v, *parameters)
        v_flat, _unflatten = tree.ravel_pytree(v)
        Qv_flat, _unflatten = tree.ravel_pytree(Qv)
        return linalg.inner(v_flat.conj(), Qv_flat)

    return integrand
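The `conj()` in the trace integrand matters for complex operators: with complex probes, v^H A v (not v^T A v) is the unbiased quadratic form. A NumPy check that averaging v^H A v over all probes with entries in {1, -1, 1j, -1j} recovers tr(A) exactly, whereas dropping the conjugate averages to zero. `trace_integrand` is a hypothetical name for this sketch.

```python
import itertools

import numpy as np


def trace_integrand(matvec, v):
    """Quadratic form v^H A v, as in monte_carlo_trace."""
    return np.vdot(v, matvec(v))  # np.vdot conjugates its first argument


rng = np.random.default_rng(0)
n = 3
A = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))


def matvec(v):
    return A @ v


phases = [1.0, -1.0, 1j, -1j]
all_probes = np.array(list(itertools.product(phases, repeat=n)))  # (4**n, n), complex

with_conj = np.mean([trace_integrand(matvec, v) for v in all_probes])
without_conj = np.mean([v @ matvec(v) for v in all_probes])
print(with_conj, np.trace(A))  # equal up to rounding
print(without_conj)  # zero up to rounding, since E[v_i**2] = 0 for these phases
```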

matfree.stochtrace.monte_carlo_trace_and_diagonal()

Construct the integrand for estimating the trace and diagonal jointly.

Use with estimator_monte_carlo.

Source code in matfree/stochtrace.py
def monte_carlo_trace_and_diagonal():
    """Construct the integrand for estimating the trace and diagonal jointly.

    Use with [estimator_monte_carlo][matfree.stochtrace.estimator_monte_carlo].
    """

    def integrand(matvec, v, *parameters):
        Qv = matvec(v, *parameters)
        v_flat, unflatten = tree.ravel_pytree(v)
        Qv_flat, _unflatten = tree.ravel_pytree(Qv)
        trace_form = linalg.inner(v_flat.conj(), Qv_flat)
        diagonal_form = unflatten(v_flat.conj() * Qv_flat)
        return {"trace": trace_form, "diagonal": diagonal_form}

    return integrand
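The joint integrand shares a single matrix-vector product between both outputs, and for every probe the trace form equals the sum of the diagonal form, so the two averaged estimates stay mutually consistent. A NumPy sketch for a flat real vector; `trace_and_diagonal_integrand` is a hypothetical name.

```python
import numpy as np


def trace_and_diagonal_integrand(matvec, v):
    """Return both quadratic forms from one matrix-vector product."""
    Av = matvec(v)
    return {"trace": np.vdot(v, Av), "diagonal": v.conj() * Av}


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))


def matvec(x):
    return A @ x


v = rng.choice([-1.0, 1.0], size=5)
out = trace_and_diagonal_integrand(matvec, v)
print(out["trace"], out["diagonal"].sum())  # equal up to rounding
```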

matfree.stochtrace.nystrom_eigh(eigenvalues_rtol: float | None = None, leverage_rtol: float | None = None, symmetrize_input: bool = True)

Construct a Nystrom approximation of an operator using a Hermitian eigendecomposition.

Parameters:

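| Name | Type | Description | Default |
| --- | --- | --- | --- |
| eigenvalues_rtol | float \| None | A relative tolerance used to determine which eigenvalues are close enough to 0. Eigenvalues below eigenvalues_rtol times the absolute value of the largest eigenvalue are treated as zero. If None, eps * num_samples is used (the same heuristic as jax.numpy.linalg.lstsq). | None |
| leverage_rtol | float \| None | A relative tolerance used in computing the leverage scores to determine which test vectors are essential. A test vector counts as essential if its leverage score plus leverage_rtol exceeds 1. If None, sqrt(eps) is used. | None |
| symmetrize_input | bool | If True (default), symmetrizes the Gram matrix Omega^H B Omega internally before computing the eigendecomposition. | True |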

Returns:

| Name | Description |
| --- | --- |
| nystrom | A function that computes the Nystrom approximation of an operator using a Hermitian eigendecomposition. It has the signature (matvec_flat, Omega) -> (nystrom_left, downdate, shift), where nystrom_left, of shape (n, num_samples), is a left factor such that nystrom_left @ nystrom_left.T.conj() approximates the operator; downdate, of shape (n, num_samples), holds the downdate vectors for the leave-one-out approximations in its columns; and shift is always 0 (returned for a common API). |

Source code in matfree/stochtrace.py
def nystrom_eigh(
    eigenvalues_rtol: float | None = None,
    leverage_rtol: float | None = None,
    symmetrize_input: bool = True,
):
    """Construct a Nystrom approximation of an operator using a Hermitian eigendecomposition.

    Parameters
    ----------
    eigenvalues_rtol
        A relative tolerance used to determine which eigenvalues are close enough to 0.
    leverage_rtol
        A relative tolerance used in computing the leverage scores to determine which
        test vectors are essential.
    symmetrize_input
        If ``True`` (default), internally symmetrizes before computing the eigendecomposition.

    Returns
    -------
    nystrom
        A function that computes the Nystrom approximation of an operator using a Hermitian eigendecomposition.
        The function has the signature `(matvec_flat, Omega) -> (nystrom_left, downdate, shift)`,
        where `nystrom_left` is a left factor of the Nystrom approximation matrix of shape ``(n, num_samples)``,
        such that `nystrom_left @ nystrom_left.T.conj()` approximates the operator,
        `downdate` is a matrix of shape ``(n, num_samples)`` whose columns are downdate vectors for the Nystrom approximation,
        and `shift=0` is the shift used (for common API).
    """

    def nystrom(matvec_flat, Omega):
        num_samples = Omega.shape[1]
        Y = func.vmap(matvec_flat, in_axes=-1, out_axes=-1)(Omega)
        # select rtol using same heuristic as jax.numpy.linalg.lstsq
        if eigenvalues_rtol is None:
            vals_rtol = np.finfo_eps(Y.dtype) * num_samples
        else:
            vals_rtol = eigenvalues_rtol
        H = Omega.T.conj() @ Y

        # Compute left-square-root of pinv(H)
        if symmetrize_input:
            H = _symmetrize(H)
        H_eigh = linalg.eigh(H)
        vals = H_eigh.eigenvalues
        vecs = H_eigh.eigenvectors
        mask = vals >= vals_rtol * np.abs(vals[-1])
        inv_sqrt_vals = np.where(mask, vals ** (-0.5), 0.0)
        vecs = np.where(mask, vecs, 0.0)
        H_pinv_left_sqrt = vecs * inv_sqrt_vals

        # Compute left-square-root of Nystrom approximation
        nystrom_left = Y @ H_pinv_left_sqrt

        # Compute the leverage scores of each column
        leverage = np.sum(np.abs(vecs) ** 2, axis=1)
        if leverage_rtol is None:
            _leverage_rtol = np.sqrt(np.finfo_eps(leverage.dtype))
        else:
            _leverage_rtol = leverage_rtol
        is_essential = leverage + _leverage_rtol > 1.0

        # Compute downdate Z s.t. B_hat_{-i} = B_hat - outer(Z[:, i], Z[:, i].conj()).
        # Since pinv(P H P') = P pinv(H) P', WLOG take i=k with Hk = H without row/col k.
        # Non-essential k, rank(Hk) = rank(H): B_hat_{-k} = B_hat, so Z[:, k] = 0.
        # Essential k, rank(Hk) = rank(H) - 1: by Albert (1969) Thm. 3,
        #   pinv(H) = [pinv(Hk) 0; 0 0] + a v v'  (v = pinv(H)[:, k], a = 1/pinv(H)[k, k]).
        # So B_hat = Y_{-k} pinv(Hk) Y_{-k}' + a (Yv)(Yv)' = B_hat_{-k} + a (Yv)(Yv)',
        # giving Z[:, k] = sqrt(a) Yv = F L[k, :]' / norm(L[k, :])
        # (F, L are left-sqrt factors of B_hat and pinv(H)).
        # Albert (1969). SIAM J. Appl. Math. 17(2), 434-440. doi:10.1137/0117041
        norms = func.vmap(linalg.vector_norm, in_axes=0)(H_pinv_left_sqrt)
        downdate = (nystrom_left @ H_pinv_left_sqrt.T.conj()) / norms
        downdate = np.where(is_essential, downdate, 0.0)

        return nystrom_left, downdate, np.asarray(0.0).astype(vals.dtype)

    return nystrom
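
The pseudo-inverse construction and the leave-one-out downdate derived in the comments above can be checked numerically. The following is a minimal NumPy sketch, not matfree's API: it assumes a dense matrix `A` in place of `matvec_flat`, and the helpers `nystrom_eigh_sketch` and `nystrom_direct` are hypothetical. It mirrors the eigenvalue masking and leverage-score test of the listing, then compares each downdate against a Nystrom approximation recomputed with that test vector removed.

```python
import numpy as np


def nystrom_eigh_sketch(A, Omega, rtol=None):
    """Hypothetical NumPy mirror of nystrom_eigh for a dense matrix A.

    Returns (F, Z): F @ F.conj().T equals Y pinv(Omega^H Y) Y^H with Y = A @ Omega,
    and B_hat_{-j} = B_hat - outer(Z[:, j], Z[:, j].conj()) for each column j.
    """
    num_samples = Omega.shape[1]
    Y = A @ Omega
    H = Omega.conj().T @ Y
    H = 0.5 * (H + H.conj().T)  # symmetrize the core matrix
    vals, vecs = np.linalg.eigh(H)  # eigenvalues in ascending order
    if rtol is None:
        rtol = np.finfo(vals.dtype).eps * num_samples
    mask = vals >= rtol * np.abs(vals[-1])  # eigenvalues treated as nonzero
    safe_vals = np.where(mask, vals, 1.0)  # avoid sqrt of tiny or negative values
    inv_sqrt_vals = np.where(mask, safe_vals ** (-0.5), 0.0)
    vecs = np.where(mask, vecs, 0.0)  # drop eigenvectors of discarded eigenvalues
    L = vecs * inv_sqrt_vals  # left square root of pinv(H)
    F = Y @ L  # left factor of the Nystrom approximation

    # A column is essential if removing it lowers the rank of H.
    leverage = np.sum(np.abs(vecs) ** 2, axis=1)
    is_essential = leverage + np.sqrt(np.finfo(leverage.dtype).eps) > 1.0
    norms = np.linalg.norm(L, axis=1)  # row norms of L
    Z = (F @ L.conj().T) / np.where(is_essential, norms, 1.0)
    Z = np.where(is_essential, Z, 0.0)  # non-essential columns need no downdate
    return F, Z


def nystrom_direct(A, Omega):
    """Reference: A Omega pinv(Omega^H A Omega) (A Omega)^H."""
    Y = A @ Omega
    return Y @ np.linalg.pinv(Omega.conj().T @ Y) @ Y.conj().T


rng = np.random.default_rng(0)
n, rank, k = 8, 5, 4
G = rng.standard_normal((n, rank))
A = G @ G.T  # positive semidefinite with rank 5
Omega = rng.standard_normal((n, k))

F, Z = nystrom_eigh_sketch(A, Omega)
B_hat = F @ F.T
for j in range(k):
    B_minus_j = nystrom_direct(A, np.delete(Omega, j, axis=1))
    print(j, np.allclose(B_hat - np.outer(Z[:, j], Z[:, j]), B_minus_j))
```

With fewer test vectors than the rank of `A`, the core matrix is invertible and every test vector is essential, so each comparison should hold. With more test vectors than the rank (and a looser `rtol`), no single test vector is essential and the sketch returns all-zero downdates, matching the non-essential case in the comments.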

matfree.stochtrace.nystrom_shifted_cholesky(shift: float | None = None, rtol: float | None = None, symmetrize_input: bool = True)

Construct a Nystrom approximation of a shifted operator using a Cholesky decomposition.

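Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| shift | float \| None | A small positive shift added to the operator so that the shifted operator is positive definite, as the Cholesky decomposition requires. If not provided, the shift is set to rtol * norm(Y) / sqrt(n), where Y is the operator applied to the test vectors and n is the length of each test vector. | None |
| rtol | float \| None | A relative tolerance used to compute the shift when shift is not provided. If rtol is also not provided, the machine epsilon of the dtype is used. | None |
| symmetrize_input | bool | If True (default), internally symmetrizes before computing the Cholesky factor. | True |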

Returns:

| Name | Description |
|---|---|
| nystrom | A function that computes the Nystrom approximation of a shifted operator using a Cholesky decomposition. It has the signature (matvec_flat, Omega) -> (nystrom_left, downdate, shift), where nystrom_left is a left factor of the Nystrom approximation matrix of shape (n, num_samples), such that nystrom_left @ nystrom_left.T.conj() approximates the shifted operator (the operator plus shift times the identity); downdate is a matrix of shape (n, num_samples) whose columns are the leave-one-out downdate vectors of the approximation; and shift is the shift that was used. |
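Written out, the source listing computes the following. The notation is introduced here: A is the operator applied by matvec_flat, Ω is Omega, F is nystrom_left, Z is downdate, and R is the upper Cholesky factor (H_cholu in the code). The Frobenius norm assumes the backend's vector_norm treats the matrix Y as one flat vector; rtol defaults to machine epsilon.

```latex
\begin{aligned}
Y &= A\Omega, \qquad
\mu = \begin{cases}
  \texttt{shift}, & \text{if given},\\
  \texttt{rtol}\,\lVert Y\rVert_F/\sqrt{n}, & \text{otherwise},
\end{cases}\\
Y_\mu &= Y + \mu\,\Omega = (A + \mu I)\,\Omega, \qquad H = \Omega^{H} Y_\mu = R^{H} R,\\
F &= Y_\mu R^{-1}, \qquad F F^{H} = Y_\mu H^{-1} Y_\mu^{H} \approx A + \mu I,\\
Z_{:,k} &= \frac{F\,\big((R^{-1})_{k,:}\big)^{H}}{\lVert (R^{-1})_{k,:}\rVert_2}.
\end{aligned}
```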

Source code in matfree/stochtrace.py
def nystrom_shifted_cholesky(
    shift: float | None = None, rtol: float | None = None, symmetrize_input: bool = True
):
    """Construct a Nystrom approximation of a shifted operator using a Cholesky decomposition.

    Parameters
    ----------
    shift
        A small positive shift to add to the operator to ensure the resulting operator
        is positive definite for Cholesky decomposition.
        If not provided, the `rtol` is used to compute the shift.
    rtol
        A relative tolerance used in computing the shift.
    symmetrize_input
        If ``True`` (default), internally symmetrizes before computing the Cholesky factor.

    Returns
    -------
    nystrom
        A function that computes the Nystrom approximation of a shifted operator using a Cholesky decomposition.
        The function has the signature `(matvec_flat, Omega) -> (nystrom_left, downdate, shift)`,
        where `nystrom_left` is a left factor of the Nystrom approximation matrix of shape ``(n, num_samples)``,
        such that `nystrom_left @ nystrom_left.T.conj()` approximates the operator,
        `downdate` is a matrix of shape ``(n, num_samples)`` whose columns are downdate vectors for the Nystrom approximation,
        and `shift` is the shift used.
    """

    def nystrom(matvec_flat, Omega):
        n = Omega.shape[0]
        Y = func.vmap(matvec_flat, in_axes=-1, out_axes=-1)(Omega)
        Y_norm = linalg.vector_norm(Y)
        if shift is None:
            shift_rtol = np.finfo_eps(Y_norm.dtype) if rtol is None else rtol
            mu = shift_rtol * Y_norm / n**0.5
        else:
            mu = shift
        Y_shifted = Y + mu * Omega
        H = Omega.T.conj() @ Y_shifted

        # Compute left-square-root of inv(H)
        if symmetrize_input:
            H = _symmetrize(H)
        H_cholu = linalg.cholesky(H).T.conj()
        Id = np.eye(H_cholu.shape[0], dtype=H_cholu.dtype)
        H_inv_left_sqrt = linalg.solve_triangular(H_cholu, Id)

        # Compute left-square-root of Nystrom approximation
        nystrom_right = linalg.solve_triangular(H_cholu, Y_shifted.T.conj(), trans=2)
        nystrom_left = nystrom_right.T.conj()

        norms = func.vmap(linalg.vector_norm, in_axes=0)(H_inv_left_sqrt)
        downdate = linalg.solve_triangular(H_cholu, nystrom_right).T.conj()
        downdate = downdate / norms[None, :]

        return nystrom_left, downdate, mu

    return nystrom
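
The shifted Cholesky variant can be checked the same way. This is a minimal NumPy sketch under the assumption of a dense matrix `A` in place of `matvec_flat`; `nystrom_shifted_cholesky_sketch` and `shifted_nystrom_direct` are hypothetical helpers, and `np.linalg.solve` stands in for a triangular solver. It verifies that the left factor reproduces the Nystrom approximation of A + mu I and that each downdate removes exactly one test vector.

```python
import numpy as np


def nystrom_shifted_cholesky_sketch(A, Omega, shift=None, rtol=None):
    """Hypothetical NumPy mirror of nystrom_shifted_cholesky for a dense A.

    Returns (F, Z, mu) where F @ F.conj().T approximates A + mu * I.
    """
    n = Omega.shape[0]
    Y = A @ Omega
    if shift is None:
        # Default shift: rtol * Frobenius norm of Y / sqrt(n); rtol defaults to eps.
        rtol = np.finfo(Y.dtype).eps if rtol is None else rtol
        mu = rtol * np.linalg.norm(Y) / np.sqrt(n)
    else:
        mu = shift
    Y_shifted = Y + mu * Omega  # equals (A + mu * I) @ Omega
    H = Omega.conj().T @ Y_shifted
    H = 0.5 * (H + H.conj().T)  # symmetrize the core matrix
    R = np.linalg.cholesky(H).conj().T  # upper factor: H = R^H R
    R_inv = np.linalg.solve(R, np.eye(R.shape[0], dtype=R.dtype))  # left sqrt of inv(H)
    F = Y_shifted @ R_inv  # F F^H = Y_shifted inv(H) Y_shifted^H
    norms = np.linalg.norm(R_inv, axis=1)  # row norms of R^{-1}
    Z = (F @ R_inv.conj().T) / norms  # leave-one-out downdate vectors
    return F, Z, mu


def shifted_nystrom_direct(A, Omega, mu):
    """Reference: Y_s inv(Omega^H Y_s) Y_s^H with Y_s = (A + mu I) Omega."""
    Y_shifted = (A + mu * np.eye(A.shape[0])) @ Omega
    return Y_shifted @ np.linalg.inv(Omega.conj().T @ Y_shifted) @ Y_shifted.conj().T


rng = np.random.default_rng(1)
n, k = 8, 4
G = rng.standard_normal((n, n))
A = G @ G.T + np.eye(n)  # symmetric positive definite
Omega = rng.standard_normal((n, k))

F, Z, mu = nystrom_shifted_cholesky_sketch(A, Omega, shift=1e-3)
B_hat = F @ F.T
for j in range(k):
    B_minus_j = shifted_nystrom_direct(A, np.delete(Omega, j, axis=1), mu)
    print(j, np.allclose(B_hat - np.outer(Z[:, j], Z[:, j]), B_minus_j))
```

Because the shift makes H positive definite, H has full rank and removing any test vector lowers it, so every test vector is essential. That is consistent with the listing applying no leverage-score mask, unlike the eigendecomposition variant.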

matfree.stochtrace.sampler_normal(*args_like, num)

Construct a function that samples from a standard-normal distribution.

Source code in matfree/stochtrace.py
def sampler_normal(*args_like, num):
    """Construct a function that samples from a standard-normal distribution."""
    return _sampler_from_jax_random(prng.normal, *args_like, num=num)
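
`sampler_normal` returns a function of a random key. As a rough analogue of that pattern, here is a NumPy sketch in which a hypothetical factory `make_normal_sampler` returns a function of a NumPy `Generator` (standing in for a JAX key) that draws `num` flat samples of length `n`, matching the `(num, n)` sample shape used by the estimators. Because standard-normal samples satisfy E[x x^T] = I, averaging the quadratic forms x^T A x gives an unbiased estimate of tr(A) (the Girard-Hutchinson estimator); `hutchinson_trace` is likewise hypothetical, not matfree's integrand API.

```python
import numpy as np


def make_normal_sampler(n, num):
    """Hypothetical factory: returns sample(rng) -> array of shape (num, n)."""

    def sample(rng):
        return rng.standard_normal((num, n))

    return sample


def hutchinson_trace(A, samples):
    """Average of x^T A x over the rows x of `samples`."""
    quad_forms = np.einsum("si,ij,sj->s", samples, A, samples)
    return quad_forms.mean()


A = np.diag(np.arange(1.0, 6.0))  # trace is 15
sample = make_normal_sampler(n=5, num=100_000)
estimate = hutchinson_trace(A, sample(np.random.default_rng(0)))
print(estimate)  # close to 15, up to Monte Carlo error
```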

matfree.stochtrace.sampler_signs(*args_like, num)

Construct a function that samples signs uniformly.

For real dtypes, this samples from a Rademacher distribution (uniformly over {-1, 1}). For complex dtypes, this samples from a Steinhaus distribution on the complex unit circle.

Source code in matfree/stochtrace.py
def sampler_signs(*args_like, num):
    """Construct a function that samples signs uniformly.

    For real dtypes, this samples from a Rademacher distribution (uniformly over `{-1, 1}`). For complex dtypes, this samples from a Steinhaus distribution on the complex unit circle.
    """
    return _sampler_from_jax_random(_uniform_signs, *args_like, num=num)
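
Both distributions named in the docstring have unit-modulus, mean-zero, independent entries, so E[x x^H] = I while every diagonal entry of x x^H equals 1 exactly. The sketch below draws them with NumPy; the helpers `rademacher` and `steinhaus` are hypothetical and are not matfree's `_uniform_signs`, whose implementation is not shown here.

```python
import numpy as np


def rademacher(rng, num, n):
    """Real case: entries drawn uniformly from {-1, 1}."""
    return rng.choice(np.array([-1.0, 1.0]), size=(num, n))


def steinhaus(rng, num, n):
    """Complex case: entries exp(2*pi*i*u) with u ~ Uniform[0, 1)."""
    return np.exp(2j * np.pi * rng.random((num, n)))


rng = np.random.default_rng(0)
x_real = rademacher(rng, num=4, n=3)
x_complex = steinhaus(rng, num=4, n=3)
# Every entry lies on the unit circle, so |x_i|^2 = 1 for each sample.
print(x_real)
print(np.abs(x_complex))
```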

matfree.stochtrace.sampler_sphere(*args_like, num)

Construct a function that samples uniformly from the sphere of radius sqrt(n), where n is the size of the flattened input, so that the samples have identity covariance.

Source code in matfree/stochtrace.py
def sampler_sphere(*args_like, num):
    """Construct a function that samples from a unit sphere scaled to have identity covariance."""
    x_flat, unflatten = tree.ravel_pytree(*args_like)
    dtype = x_flat.dtype
    rdtype = dtype.type(0).real.dtype
    n = x_flat.shape[0]
    sqrtn = np.sqrt(n).astype(rdtype)

    def sample(key):
        samples = prng.normal(key, shape=(num, n), dtype=dtype)
        return func.vmap(lambda x: unflatten(x * (sqrtn / linalg.vector_norm(x))))(
            samples
        )

    return sample
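
Normalizing a standard-normal vector gives a direction that is uniform on the unit sphere, with E[u u^T] = I / n; rescaling by sqrt(n), as the listing does, restores identity covariance. The following NumPy sketch mirrors that step for flat arrays only (the listing additionally handles pytrees via `ravel_pytree` and takes a JAX key); `sphere_samples` is a hypothetical helper.

```python
import numpy as np


def sphere_samples(rng, num, n):
    """Normal samples rescaled to lie on the sphere of radius sqrt(n)."""
    x = rng.standard_normal((num, n))
    return x * (np.sqrt(n) / np.linalg.norm(x, axis=1, keepdims=True))


rng = np.random.default_rng(0)
x = sphere_samples(rng, num=200_000, n=4)
norms = np.linalg.norm(x, axis=1)  # all equal sqrt(4) = 2 up to rounding
cov = x.T @ x / x.shape[0]  # empirical second moment, close to the identity
print(norms[:3])
print(np.round(cov, 2))
```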