Skip to content

matfree.decomp

matfree.decomp

Matrix-free matrix decompositions.

This module includes various Lanczos-decompositions of matrices (tri-diagonal, bi-diagonal, etc.).

For stochastic Lanczos quadrature and matrix-function-vector products, see matfree.funm.

matfree.decomp.bidiag(num_matvecs: int, /, materialize: bool = True)

Construct an implementation of bidiagonalisation.

Uses pre-allocation and full reorthogonalisation.

Works for arbitrary matrices. No symmetry required.

Decompose a matrix into a product of orthogonal-bidiagonal-orthogonal matrices. Use this algorithm for approximate singular value decompositions.

Internally, Matfree uses JAX to turn matrix-vector- into vector-matrix-products.

A note about differentiability

Unlike tridiag_sym or hessenberg, this function's reverse-mode derivatives are very inefficient. Custom gradients for bidiagonalisation are a work in progress, and if you need to differentiate the decompositions, consider using tridiag_sym for the time being.

Source code in matfree/decomp.py (lines 584–708)
def bidiag(num_matvecs: int, /, materialize: bool = True):
    """Construct an implementation of **bidiagonalisation**.

    Uses pre-allocation and full reorthogonalisation.

    Works for **arbitrary matrices**. No symmetry required.

    Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
    Use this algorithm for approximate **singular value** decompositions.

    Internally, Matfree uses JAX to turn matrix-vector- into vector-matrix-products.

    ??? note "A note about differentiability"
        Unlike [tridiag_sym][matfree.decomp.tridiag_sym] or
        [hessenberg][matfree.decomp.hessenberg], this function's reverse-mode
        derivatives are very inefficient. Custom gradients for bidiagonalisation
        are a work in progress, and if you need to differentiate the decompositions,
        consider using [tridiag_sym][matfree.decomp.tridiag_sym] for the time being.

    Parameters
    ----------
    num_matvecs
        The number of iterations. Each iteration evaluates one matrix-vector
        and one vector-matrix product.
        Must be non-negative and at most the smaller of the two matrix
        dimensions; otherwise, the returned function raises a ``ValueError``.
    materialize
        If ``True``, the bidiagonal matrix is returned as a dense matrix.
        If ``False``, it is returned as a tuple ``(diagonal, off-diagonal)``.

    Returns
    -------
    estimate
        A function that maps ``(Av, v0, *parameters)`` to the decomposition,
        where ``Av(v, *parameters)`` evaluates the matrix-vector product.
        The decomposition holds the pair of left and right bases
        (basis vectors stored as columns), the bidiagonal matrix,
        the residual, and the inverse of the length of ``v0``.
    """

    def estimate(Av: Callable, v0, *parameters):
        # Infer the size of A from v0 and from the (abstract) output shape of Av
        (ncols,) = np.shape(v0)
        w0_like = func.eval_shape(Av, v0, *parameters)
        (nrows,) = np.shape(w0_like)

        # Complain if the requested depth exceeds what the matrix supports
        max_num_matvecs = min(nrows, ncols)
        if num_matvecs > max_num_matvecs or num_matvecs < 0:
            msg = _error_num_matvecs(num_matvecs, maxval=max_num_matvecs, minval=0)
            raise ValueError(msg)

        v0_norm, length = _normalise(v0)
        state = init(v0_norm, nrows=nrows, ncols=ncols)

        # With zero iterations, the initial state already is the result;
        # the loop (and thus `step`) is skipped entirely.
        if num_matvecs > 0:

            def body_fun(_, s):
                return step(Av, s, *parameters)

            state = control_flow.fori_loop(
                0, num_matvecs, body_fun=body_fun, init_val=state
            )

        uk_all_T, J, vk_all, (beta, vk) = extract(state)
        return _DecompResult(
            Q_tall=(uk_all_T, vk_all.T),
            J_small=J,
            residual=beta * vk,
            init_length_inv=1 / length,
        )

    class State(containers.NamedTuple):
        # Index of the current iteration (the row of Us/Vs that is written next).
        i: int
        # Left basis vectors, one per row; shape (num_matvecs, nrows).
        Us: Array
        # Right basis vectors, one per row; shape (num_matvecs, ncols).
        Vs: Array
        # Diagonal entries of the bidiagonal matrix.
        alphas: Array
        # betas[0] is the initial (zero) beta; betas[1:] are the off-diagonal entries.
        betas: Array
        # Length of the most recent right vector before normalisation.
        beta: Array
        # The most recent (normalised) right vector.
        vk: Array

    def init(init_vec: Array, *, nrows, ncols) -> State:
        """Pre-allocate zero-filled storage and start from the normalised vector."""
        alphas = np.zeros((num_matvecs,))
        betas = np.zeros((num_matvecs,))
        Us = np.zeros((num_matvecs, nrows))
        Vs = np.zeros((num_matvecs, ncols))
        # `estimate` already passes a normalised vector; normalising again only
        # changes it up to rounding.
        v0, _ = _normalise(init_vec)
        return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)

    def step(Av, state: State, *parameters) -> State:
        """Perform one bidiagonalisation step with full reorthogonalisation."""
        i, Us, Vs, alphas, betas, beta, vk = state
        Vs = Vs.at[i].set(vk)
        betas = betas.at[i].set(beta)

        # Use jax.vjp to evaluate the vector-matrix product
        Av_eval, vA = func.vjp(lambda v: Av(v, *parameters), vk)
        # At i == 0, Us[i - 1] is the (still zero) last row and beta is zero,
        # so the subtraction is a no-op in the first iteration.
        uk = Av_eval - beta * Us[i - 1]
        uk, alpha = _normalise(uk)
        uk, *_ = _gram_schmidt_classical(uk, Us)  # full reorthogonalisation
        Us = Us.at[i].set(uk)
        alphas = alphas.at[i].set(alpha)

        (vA_eval,) = vA(uk)
        vk = vA_eval - alpha * vk
        vk, beta = _normalise(vk)
        vk, *_ = _gram_schmidt_classical(vk, Vs)  # full reorthogonalisation

        return State(i + 1, Us, Vs, alphas, betas, beta, vk)

    def extract(state: State, /):
        """Unpack the state into (left basis, bidiagonal, right basis, residual)."""
        _, uk_all, vk_all, alphas, betas, beta, vk = state

        # betas[0] belongs to the initial vector, not to the bidiagonal matrix.
        if materialize:
            B = _todense_bidiag(alphas, betas[1:])
            return uk_all.T, B, vk_all, (beta, vk)

        return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)

    def _gram_schmidt_classical(vec, vectors):  # Gram-Schmidt
        # The scan carries the partially orthogonalised vector, so each
        # coefficient is computed against the already-updated vector.
        # Zero rows of the pre-allocated storage contribute nothing.
        vec, coeffs = control_flow.scan(_gram_schmidt_classical_step, vec, xs=vectors)
        vec, length = _normalise(vec)
        return vec, length, coeffs

    def _gram_schmidt_classical_step(vec1, vec2):
        # Remove the component of vec1 along vec2.
        coeff = linalg.inner(vec1, vec2)
        vec_ortho = vec1 - coeff * vec2
        return vec_ortho, coeff

    def _normalise(vec):
        # Return the unit vector and the original length.
        length = linalg.vector_norm(vec)
        return vec / length, length

    def _todense_bidiag(d, e):
        # d on the main diagonal, e on the diagonal with offset 1.
        diag = linalg.diagonal_matrix(d)
        offdiag = linalg.diagonal_matrix(e, 1)
        return diag + offdiag

    return estimate

matfree.decomp.hessenberg(num_matvecs, /, *, reortho: str, custom_vjp: bool = True, reortho_vjp: str = 'match')

Construct a Hessenberg-factorisation via the Arnoldi iteration.

Uses pre-allocation, and full reorthogonalisation if reortho is set to "full". It tends to be a good idea to use full reorthogonalisation.

This algorithm works for arbitrary matrices.

Setting custom_vjp to True implies using efficient, numerically stable gradients of the Arnoldi iteration according to what has been proposed by Krämer et al. (2024). These gradients are exact, so there is little reason not to use them. If you use this configuration, please consider citing Krämer et al. (2024; bibtex below).

BibTex for Krämer et al. (2024)
@article{kraemer2024gradients,
    title={Gradients of functions of large matrices},
    author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and
    Roy, Hrittik and Hauberg, S{\o}ren},
    journal={arXiv preprint arXiv:2405.17277},
    year={2024}
}
Source code in matfree/decomp.py (lines 337–405)
def hessenberg(
    num_matvecs, /, *, reortho: str, custom_vjp: bool = True, reortho_vjp: str = "match"
):
    r"""Construct a **Hessenberg-factorisation** via the Arnoldi iteration.

    Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`.
    It tends to be a good idea to use full reorthogonalisation.

    This algorithm works for **arbitrary matrices**.

    Setting `custom_vjp` to `True` implies using efficient, numerically stable
    gradients of the Arnoldi iteration according to what has been proposed by
    Krämer et al. (2024).
    These gradients are exact, so there is little reason not to use them.
    If you use this configuration,
    please consider citing Krämer et al. (2024; bibtex below).

    ??? note "BibTex for Krämer et al. (2024)"
        ```bibtex
        @article{kraemer2024gradients,
            title={Gradients of functions of large matrices},
            author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and
            Roy, Hrittik and Hauberg, S{\o}ren},
            journal={arXiv preprint arXiv:2405.17277},
            year={2024}
        }
        ```

    Parameters
    ----------
    num_matvecs
        The number of matrix-vector products aka the depth of the Krylov space.
    reortho
        Reorthogonalisation strategy of the forward pass.
        Either `"none"` or `"full"`.
    custom_vjp
        Whether to use the custom vector-Jacobian product of
        Krämer et al. (2024).
    reortho_vjp
        Reorthogonalisation strategy of the backward pass
        (only relevant if `custom_vjp` is `True`).
        Either `"none"`, `"full"`, or `"match"`.
        The default, `"match"`, reuses the value of `reortho`.

    Returns
    -------
    estimate
        A function that maps `(matvec, v, *params)` to the decomposition
        `(Q, H, r, c)`.

    Raises
    ------
    TypeError
        If `reortho` or `reortho_vjp` is not one of the supported values.
    """
    reortho_expected = ["none", "full"]
    if reortho not in reortho_expected:
        msg = (
            f"Unexpected input for reortho={reortho!r}: "
            f"either of {reortho_expected} expected."
        )
        raise TypeError(msg)

    reortho_vjp_expected = ["match", *reortho_expected]
    if reortho_vjp not in reortho_vjp_expected:
        msg = (
            f"Unexpected input for reortho_vjp={reortho_vjp!r}: "
            f"either of {reortho_vjp_expected} expected."
        )
        raise TypeError(msg)

    # "match" means: the backward pass reorthogonalises like the forward pass.
    reortho_bwd = reortho if reortho_vjp == "match" else reortho_vjp

    def estimate(matvec, v, *params):
        # Hoist values that `matvec` closes over into explicit arguments
        # (presumably a wrapper around `jax.closure_convert`), so that
        # gradients with respect to them flow through the custom VJP.
        matvec_convert, aux_args = func.closure_convert(matvec, v, *params)
        return _estimate(matvec_convert, v, *params, *aux_args)

    def _estimate(matvec_convert: Callable, v, *params):
        # The forward pass always follows the user's `reortho` choice.
        return _hessenberg_forward(
            matvec_convert, num_matvecs, v, *params, reortho=reortho
        )

    def estimate_fwd(matvec_convert: Callable, v, *params):
        outputs = _estimate(matvec_convert, v, *params)
        return outputs, (outputs, params)

    def estimate_bwd(matvec_convert: Callable, cache, vjp_incoming):
        (Q, H, r, c), params = cache
        dQ, dH, dr, dc = vjp_incoming

        return _hessenberg_adjoint(
            matvec_convert,
            *params,
            Q=Q,
            H=H,
            r=r,
            c=c,
            dQ=dQ,
            dH=dH,
            dr=dr,
            dc=dc,
            reortho=reortho_bwd,
        )

    if custom_vjp:
        _estimate = func.custom_vjp(_estimate, nondiff_argnums=(0,))
        _estimate.defvjp(estimate_fwd, estimate_bwd)  # type: ignore
    return estimate

matfree.decomp.tridiag_sym(num_matvecs: int, /, *, materialize: bool = True, reortho: str = 'full', custom_vjp: bool = True)

Construct an implementation of tridiagonalisation.

Uses pre-allocation, and full reorthogonalisation if reortho is set to "full". It tends to be a good idea to use full reorthogonalisation.

This algorithm assumes a symmetric matrix.

Decompose a matrix into a product of orthogonal-tridiagonal-orthogonal matrices. Use this algorithm for approximate eigenvalue decompositions.

Setting custom_vjp to True implies using efficient, numerically stable gradients of the Lanczos iteration according to what has been proposed by Krämer et al. (2024). These gradients are exact, so there is little reason not to use them. If you use this configuration, please consider citing Krämer et al. (2024; bibtex below).

BibTex for Krämer et al. (2024)
@article{kraemer2024gradients,
    title={Gradients of functions of large matrices},
    author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and
    Roy, Hrittik and Hauberg, S{\o}ren},
    journal={arXiv preprint arXiv:2405.17277},
    year={2024}
}

Parameters:

Name Type Description Default
num_matvecs int

The number of matrix-vector products aka the depth of the Krylov space. The deeper the Krylov space, the more accurate the factorisation tends to be. However, the computational complexity increases linearly with the number of matrix-vector products.

required
materialize bool

The value of this flag indicates whether the tridiagonal matrix should be returned in a sparse format (which means, as a tuple of diagonals) or as a dense matrix. The dense matrix is helpful if different decompositions should be used interchangeably. The sparse representation requires less memory.

True
reortho str

The value of this parameter indicates whether to reorthogonalise the basis vectors during the forward pass. Reorthogonalisation makes the forward pass more expensive, but helps (significantly) with numerical stability.

'full'
custom_vjp bool

The value of this flag indicates whether to use a custom vector-Jacobian product as proposed by Krämer et al. (2024; bibtex above). Generally, using a custom VJP tends to be a good idea. However, due to JAX's mechanics, a custom VJP precludes the use of forward-mode differentiation (see the JAX documentation of jax.custom_vjp), so don't use a custom VJP if you need forward-mode differentiation.

True

Returns:

Type Description
decompose

A decomposition function that maps (matvec, vector, *params) to the decomposition. The decomposition is a tuple of (nested) arrays. The first element is the Krylov basis, the second element represents the tridiagonal matrix (how it is represented depends on the value of `materialize`), the third element is the residual, and the fourth element is the (inverse of the) length of the initial vector.

Source code in matfree/decomp.py (lines 30–116)
def tridiag_sym(
    num_matvecs: int,
    /,
    *,
    materialize: bool = True,
    reortho: str = "full",
    custom_vjp: bool = True,
):
    r"""Construct an implementation of **tridiagonalisation**.

    Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`.
    It tends to be a good idea to use full reorthogonalisation.

    This algorithm assumes a **symmetric matrix**.

    Decompose a matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices.
    Use this algorithm for approximate **eigenvalue** decompositions.

    Setting `custom_vjp` to `True` implies using efficient, numerically stable
    gradients of the Lanczos iteration according to what has been proposed by
    Krämer et al. (2024).
    These gradients are exact, so there is little reason not to use them.
    If you use this configuration, please consider
    citing Krämer et al. (2024; bibtex below).

    ??? note "BibTex for Krämer et al. (2024)"
        ```bibtex
        @article{kraemer2024gradients,
            title={Gradients of functions of large matrices},
            author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and
            Roy, Hrittik and Hauberg, S{\o}ren},
            journal={arXiv preprint arXiv:2405.17277},
            year={2024}
        }
        ```

    Parameters
    ----------
    num_matvecs
        The number of matrix-vector products aka the depth of the Krylov space.
        The deeper the Krylov space, the more accurate the factorisation tends to be.
        However, the computational complexity increases linearly
        with the number of matrix-vector products.
    materialize
        The value of this flag indicates whether the tridiagonal matrix
        should be returned in a sparse format (which means, as a tuple of diagonals)
        or as a dense matrix.
        The dense matrix is helpful if different decompositions should be used
        interchangeably. The sparse representation requires less memory.
    reortho
        The value of this parameter indicates whether to reorthogonalise the
        basis vectors during the forward pass.
        Either `"full"` or `"none"`.
        Reorthogonalisation makes the forward pass more expensive, but helps
        (significantly) with numerical stability.
    custom_vjp
        The value of this flag indicates whether to use a custom vector-Jacobian
        product as proposed by Krämer et al. (2024; bibtex above).
        Generally, using a custom VJP tends to be a good idea.
        However, due to JAX's mechanics, a custom VJP precludes the use of forward-mode
        differentiation
        ([see here](https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html)),
        so don't use a custom VJP if you need forward-mode differentiation.

    Returns
    -------
    decompose
        A decomposition function that maps
        ``(matvec, vector, *params)`` to the decomposition.
        The decomposition is a tuple of (nested) arrays.
        The first element is the Krylov basis,
        the second element represents the tridiagonal matrix
        (how it is represented depends on the value of ``materialize``),
        the third element is
        the residual, and the fourth element is
        the (inverse of the) length of the initial vector.

    Raises
    ------
    ValueError
        If `reortho` is neither `"full"` nor `"none"`.
    """
    if reortho == "full":
        return _tridiag_reortho_full(
            num_matvecs, custom_vjp=custom_vjp, materialize=materialize
        )
    if reortho == "none":
        return _tridiag_reortho_none(
            num_matvecs, custom_vjp=custom_vjp, materialize=materialize
        )

    # Spell out the options literally: formatting a set literal inside an
    # f-string would print the options in hash-dependent order.
    msg = f"reortho={reortho} unsupported. Choose either 'full' or 'none'."
    raise ValueError(msg)