
matfree.decomp


Matrix-free matrix decompositions.

This module includes various Lanczos-decompositions of matrices (tridiagonal, bidiagonal, etc.).

For stochastic Lanczos quadrature and matrix-function-vector products, see matfree.funm.

matfree.decomp.bidiag(num_matvecs: int, /, materialize: bool = True, reortho: str = 'full')

Construct an implementation of bidiagonalisation via the Golub-Kahan algorithm.

Factorise \(A \approx U B V^\top\), where \(U\), \(V\) are orthogonal and \(B\) is bidiagonal. Works for arbitrary real matrices (rectangular, no symmetry required). Does not support complex-valued matrices.
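A minimal NumPy sketch of the Golub-Kahan recurrence, written to mirror the `step` function in the source listing (full reorthogonalisation; \(B\) is upper bidiagonal with the \(\alpha\)s on the diagonal and the \(\beta\)s above it). The function `golub_kahan` is a hypothetical illustration and not part of matfree's API. When `k` equals the number of columns, the singular values of \(B\) should match those of \(A\) up to rounding.

```python
import numpy as np


def golub_kahan(A, v0, k):
    """Golub-Kahan bidiagonalisation with full reorthogonalisation (sketch).

    Returns U (m x k) and V (n x k) with orthonormal columns and an upper
    bidiagonal B (k x k) such that A @ V = U @ B.
    """
    m, n = A.shape
    U, V = np.zeros((k, m)), np.zeros((k, n))  # rows hold the basis vectors
    alphas, betas = np.zeros(k), np.zeros(k)  # betas[0] stays zero
    v = v0 / np.linalg.norm(v0)
    u_prev, beta = np.zeros(m), 0.0
    for i in range(k):
        V[i], betas[i] = v, beta
        # Left vector: A v minus beta times the previous left vector,
        # then reorthogonalise (twice) against all stored left vectors
        u = A @ v - beta * u_prev
        for _ in range(2):
            u = u - U.T @ (U @ u)
        alphas[i] = np.linalg.norm(u)
        u = u / alphas[i]
        U[i] = u
        # Right vector: A^T u minus alpha times the current right vector,
        # then reorthogonalise (twice) against all stored right vectors
        w = A.T @ u - alphas[i] * v
        for _ in range(2):
            w = w - V.T @ (V @ w)
        beta = np.linalg.norm(w)
        # Skip the normalisation once the Krylov space is exhausted (beta ~ 0)
        v = w / beta if beta > 1e-12 else w
        u_prev = u
    B = np.diag(alphas) + np.diag(betas[1:], 1)
    return U.T, B, V.T


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))  # rectangular, no symmetry
U, B, V = golub_kahan(A, rng.standard_normal(4), k=4)
sv_B = np.linalg.svd(B, compute_uv=False)
sv_A = np.linalg.svd(A, compute_uv=False)
```

With a smaller `k`, the largest singular values of \(B\) approximate the largest singular values of \(A\), which is how the decomposition serves approximate SVDs.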

Use this algorithm for approximate singular value decompositions. Internally, Matfree uses JAX's vjp to turn the matrix-vector product into a vector-matrix product, so only the matrix-vector product is required.
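The vjp trick can be illustrated in isolation with plain JAX. In this sketch, `matvec` is a hypothetical stand-in for a user's operator; only \(v \mapsto Av\) is implemented, and `jax.vjp` supplies the product with \(A^\top\).

```python
import jax
import jax.numpy as jnp


def matvec(v, A):
    # User-supplied matrix-vector product: only v -> A @ v is implemented
    return A @ v


A = jax.random.normal(jax.random.PRNGKey(0), (5, 3))  # rectangular matrix
v = jnp.ones((3,))
u = jnp.arange(5.0)

# jax.vjp evaluates A @ v and returns a function that applies the transpose
Av, vjp_fun = jax.vjp(lambda x: matvec(x, A), v)
(uA,) = vjp_fun(u)  # u @ A (equivalently A.T @ u), without forming A.T
```

Because a linear map's VJP is its transpose, one forward evaluation yields both products needed per Golub-Kahan step.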

BibTex for Golub and Kahan (1965)
@article{golub1965calculating,
    title={Calculating the singular values and pseudo-inverse of a matrix},
    author={Golub, Gene and Kahan, William},
    journal={Journal of the Society for Industrial and Applied Mathematics, Series B: Numerical Analysis},
    volume={2},
    number={2},
    pages={205--224},
    year={1965},
    publisher={SIAM}
}
A note about differentiability

Unlike tridiag_sym or hessenberg, this function's reverse-mode derivatives are not efficient. Custom gradients for bidiagonalisation are a work in progress. In the meantime, if you need to differentiate the decompositions, consider using tridiag_sym instead (if possible).

Source code in matfree/decomp.py
def bidiag(num_matvecs: int, /, materialize: bool = True, reortho: str = "full"):
    r"""Construct an implementation of **bidiagonalisation** via the Golub-Kahan algorithm.

    Factorise $A \approx U B V^\top$, where $U$, $V$ are orthogonal and $B$ is bidiagonal.
    Works for **arbitrary real matrices** (rectangular, no symmetry required).
    Does not support complex-valued matrices.

    Use this algorithm for approximate **singular value** decompositions.
    Internally, Matfree uses JAX's `vjp` to turn the matrix-vector product
    into a vector-matrix product, so only the matrix-vector product is required.

    ??? note "BibTex for Golub and Kahan (1965)"
        ```bibtex
        @article{golub1965calculating,
            title={Calculating the singular values and pseudo-inverse of a matrix},
            author={Golub, Gene and Kahan, William},
            journal={Journal of the Society for Industrial and Applied Mathematics, Series B: Numerical Analysis},
            volume={2},
            number={2},
            pages={205--224},
            year={1965},
            publisher={SIAM}
        }
        ```

    ??? note "A note about differentiability"
        Unlike [tridiag_sym][matfree.decomp.tridiag_sym] or
        [hessenberg][matfree.decomp.hessenberg], this function's reverse-mode
        derivatives are not efficient. Custom gradients for bidiagonalisation
        are a work in progress. In the meantime,
        if you need to differentiate the decompositions, consider using
        [tridiag_sym][matfree.decomp.tridiag_sym] instead (if possible).

    """

    def estimate(Av: Callable, v0, *parameters):
        # Flatten v0 and infer the shape of the Av output
        v0_flat, v_unravel = tree.ravel_pytree(v0)
        ncols = v0_flat.shape[0]
        u0_like = func.eval_shape(lambda z: Av(z, *parameters), v0)
        u0_flat_like, u_unravel = func.eval_shape(tree.ravel_pytree, u0_like)
        nrows = u0_flat_like.shape[0]

        def Av_flat(v_f, *p):
            result = Av(v_unravel(v_f), *p)
            result_flat, _ = tree.ravel_pytree(result)
            return result_flat

        # Complain if the shapes don't match
        max_num_matvecs = min(nrows, ncols)
        if num_matvecs > max_num_matvecs or num_matvecs < 0:
            msg = _error_num_matvecs(num_matvecs, maxval=min(nrows, ncols), minval=0)
            raise ValueError(msg)

        v0_norm, length = _normalise(v0_flat)
        init_val = init(v0_norm, nrows=nrows, ncols=ncols)

        if num_matvecs == 0:
            uk_all_T, J, vk_all, (beta, vk) = extract(init_val)
            return _DecompResult(
                Q_tall=(func.vmap(u_unravel)(uk_all_T.T), func.vmap(v_unravel)(vk_all)),
                J_small=J,
                residual=v_unravel(beta * vk),
                init_length_inv=1 / length,
            )

        def body_fun(_, s):
            return step(Av_flat, s, *parameters)

        result = control_flow.fori_loop(
            0, num_matvecs, body_fun=body_fun, init_val=init_val
        )
        uk_all_T, J, vk_all, (beta, vk) = extract(result)
        return _DecompResult(
            Q_tall=(func.vmap(u_unravel)(uk_all_T.T), func.vmap(v_unravel)(vk_all)),
            J_small=J,
            residual=v_unravel(beta * vk),
            init_length_inv=1 / length,
        )

    class State(containers.NamedTuple):
        i: int
        Us: Array
        Vs: Array
        alphas: Array
        betas: Array
        beta: Array
        vk: Array

    def init(init_vec: Array, *, nrows, ncols) -> State:
        alphas = np.zeros((num_matvecs,))
        betas = np.zeros((num_matvecs,))
        Us = np.zeros((num_matvecs, nrows))
        Vs = np.zeros((num_matvecs, ncols))
        v0, _ = _normalise(init_vec)
        return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)

    def step(Av, state: State, *parameters) -> State:
        i, Us, Vs, alphas, betas, beta, vk = state
        Vs = Vs.at[i].set(vk)
        betas = betas.at[i].set(beta)

        # Use jax.vjp to evaluate the vector-matrix product
        Av_eval, vA = func.vjp(lambda v: Av(v, *parameters), vk)
        uk = Av_eval - beta * Us[i - 1]
        if reortho == "full":
            # Reorthogonalise twice: a single Gram-Schmidt pass usually
            # does not restore orthogonality in floating-point arithmetic
            uk = uk - Us.T @ (Us @ uk)
            uk = uk - Us.T @ (Us @ uk)

        uk, alpha = _normalise(uk)
        Us = Us.at[i].set(uk)
        alphas = alphas.at[i].set(alpha)

        (vA_eval,) = vA(uk)
        vk = vA_eval - alpha * vk
        if reortho == "full":
            # Reorthogonalise twice: a single Gram-Schmidt pass usually
            # does not restore orthogonality in floating-point arithmetic
            vk = vk - Vs.T @ (Vs @ vk)
            vk = vk - Vs.T @ (Vs @ vk)

        vk, beta = _normalise(vk)

        return State(i + 1, Us, Vs, alphas, betas, beta, vk)

    def extract(state: State, /):
        _, uk_all, vk_all, alphas, betas, beta, vk = state

        if materialize:
            B = _todense_bidiag(alphas, betas[1:])
            return uk_all.T, B, vk_all, (beta, vk)

        return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)

    def _normalise(vec):
        length = linalg.vector_norm(vec)
        return vec / length, length

    def _todense_bidiag(d, e):
        diag = linalg.diagonal_matrix(d)
        offdiag = linalg.diagonal_matrix(e, 1)
        return diag + offdiag

    return estimate

matfree.decomp.hessenberg(num_matvecs, /, *, reortho: str, custom_vjp: bool = True, reortho_vjp: str = 'match')

Construct a Hessenberg-factorisation via the Arnoldi iteration.

Factorise \(A \approx Q H Q^\top\), where \(Q\) is orthogonal and \(H\) is upper Hessenberg. Works for arbitrary square matrices. Does not support complex-valued matrices.
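The Arnoldi iteration behind this factorisation can be sketched in NumPy. The function `arnoldi` is a hypothetical illustration, not matfree's implementation (which additionally flattens PyTrees and supports custom gradients). It orthogonalises each new vector with classical Gram-Schmidt applied twice and returns a residual satisfying the Arnoldi relation \(AQ = QH + r e_k^\top\).

```python
import numpy as np


def arnoldi(A, v0, k):
    """Arnoldi iteration with full reorthogonalisation (sketch).

    Returns Q (n x k, orthonormal columns), upper-Hessenberg H (k x k), and a
    residual r such that A @ Q = Q @ H + outer(r, e_k).
    """
    n = A.shape[0]
    Q = np.zeros((n, k))
    H = np.zeros((k, k))
    q = v0 / np.linalg.norm(v0)
    for j in range(k):
        Q[:, j] = q
        w = A @ q
        # Classical Gram-Schmidt, applied twice; coefficients of both passes
        # accumulate in column j of H
        for _ in range(2):
            h = Q[:, : j + 1].T @ w
            w = w - Q[:, : j + 1] @ h
            H[: j + 1, j] += h
        if j + 1 < k:
            H[j + 1, j] = np.linalg.norm(w)
            q = w / H[j + 1, j]
    return Q, H, w  # after the last step, w is the residual r


rng = np.random.default_rng(1)
A = rng.standard_normal((8, 8))  # square, non-symmetric
Q, H, r = arnoldi(A, rng.standard_normal(8), k=5)
e_k = np.eye(5)[-1]
```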

Setting custom_vjp to True enables the efficient, numerically stable gradients of the Arnoldi iteration proposed by Krämer et al. (2024). These gradients are exact, so there is little reason not to use them. If you use this configuration, please cite Krämer et al. (2024):

BibTex for Krämer et al. (2024)
@article{kraemer2024gradients,
    title={Gradients of functions of large matrices},
    author={Kr{\"a}mer, Nicholas and Moreno-Mu{\~n}oz, Pablo and Roy, Hrittik and Hauberg, S{\o}ren},
    journal={Advances in Neural Information Processing Systems},
    volume={37},
    pages={49484--49518},
    year={2024}
}
Source code in matfree/decomp.py
def hessenberg(
    num_matvecs, /, *, reortho: str, custom_vjp: bool = True, reortho_vjp: str = "match"
):
    r"""Construct a **Hessenberg-factorisation** via the Arnoldi iteration.

    Factorise $A \approx Q H Q^\top$, where $Q$ is orthogonal and $H$ is upper Hessenberg.
    Works for **arbitrary square matrices**. Does not support complex-valued matrices.

    Setting `custom_vjp` to `True` enables the efficient, numerically stable
    gradients of the Arnoldi iteration proposed by Krämer et al. (2024).
    These gradients are exact, so there is little reason not to use them.
    If you use this configuration, please cite Krämer et al. (2024):

    ??? note "BibTex for Krämer et al. (2024)"
        ```bibtex
        @article{kraemer2024gradients,
            title={Gradients of functions of large matrices},
            author={Kr{\"a}mer, Nicholas and Moreno-Mu{\~n}oz, Pablo and Roy, Hrittik and Hauberg, S{\o}ren},
            journal={Advances in Neural Information Processing Systems},
            volume={37},
            pages={49484--49518},
            year={2024}
        }
        ```
    """
    reortho_expected = ["none", "full"]
    if reortho not in reortho_expected:
        msg = f"Unexpected input for reortho={reortho!r}: either of {reortho_expected} expected."
        raise TypeError(msg)

    def estimate(matvec, v, *params):
        v_flat, v_unravel = tree.ravel_pytree(v)

        def matvec_flat(v_f, *p):
            return tree.ravel_pytree(matvec(v_unravel(v_f), *p))[0]

        matvec_convert, aux_args = func.closure_convert(matvec_flat, v_flat, *params)
        Q_flat, H, r_flat, c = _estimate(matvec_convert, v_flat, *params, *aux_args)
        Q_tall = func.vmap(v_unravel)(Q_flat.T)
        return _DecompResult(
            Q_tall=Q_tall, J_small=H, residual=v_unravel(r_flat), init_length_inv=c
        )

    def _estimate(matvec_convert: Callable, v, *params):
        # The forward pass reorthogonalises as requested by `reortho`
        return _hessenberg_forward(
            matvec_convert, num_matvecs, v, *params, reortho=reortho
        )

    def estimate_fwd(matvec_convert: Callable, v, *params):
        outputs = _estimate(matvec_convert, v, *params)
        return outputs, (outputs, params)

    def estimate_bwd(matvec_convert: Callable, cache, vjp_incoming):
        (Q, H, r, c), params = cache
        dQ, dH, dr, dc = vjp_incoming

        # The adjoint pass follows `reortho_vjp`; "match" reuses `reortho`
        reortho_adjoint = reortho if reortho_vjp == "match" else reortho_vjp
        return _hessenberg_adjoint(
            matvec_convert,
            *params,
            Q=Q,
            H=H,
            r=r,
            c=c,
            dQ=dQ,
            dH=dH,
            dr=dr,
            dc=dc,
            reortho=reortho_adjoint,
        )

    if custom_vjp:
        _estimate = func.custom_vjp(_estimate, nondiff_argnums=(0,))
        _estimate.defvjp(estimate_fwd, estimate_bwd)  # type: ignore
    return estimate
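The listing above wraps the matvec with `closure_convert` and registers the adjoint through `custom_vjp` with the matvec marked non-differentiable (`nondiff_argnums=(0,)`). A toy sketch of that pattern in plain JAX follows. The names `apply_twice` and `loss` are hypothetical, and the backward pass here is simply the VJP of the composed map, not the adjoint recursion of Krämer et al. (2024).

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def apply_twice(matvec, v, *params):
    """Toy stand-in for a Krylov routine: returns A(A v) for the operator `matvec`."""
    return matvec(matvec(v, *params), *params)


def apply_twice_fwd(matvec, v, *params):
    # Forward pass: compute the output and cache what the backward pass needs
    return apply_twice(matvec, v, *params), (v, params)


def apply_twice_bwd(matvec, cache, dy):
    # Backward pass: here just the VJP of the composed map; a Krylov routine
    # would run its own adjoint recursion instead
    v, params = cache
    _, vjp_fun = jax.vjp(lambda x, *p: matvec(matvec(x, *p), *p), v, *params)
    return vjp_fun(dy)  # one cotangent per differentiable argument (v, *params)


apply_twice.defvjp(apply_twice_fwd, apply_twice_bwd)


def loss(A, v):
    matvec = lambda x: A @ x  # closes over A, which is a tracer under jax.grad
    # Hoist closed-over traced arrays into explicit trailing arguments so that
    # the custom VJP can return cotangents for them
    matvec_convert, consts = jax.closure_convert(matvec, v)
    return jnp.sum(apply_twice(matvec_convert, v, *consts) ** 2)


def loss_reference(A, v):
    return jnp.sum((A @ (A @ v)) ** 2)


A = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
v = jnp.arange(1.0, 5.0)
grad_A = jax.grad(loss)(A, v)
grad_A_ref = jax.grad(loss_reference)(A, v)
grad_v = jax.grad(loss, argnums=1)(A, v)
grad_v_ref = jax.grad(loss_reference, argnums=1)(A, v)
```

Without `closure_convert`, the closed-over `A` would be invisible to the custom VJP. Note that JAX rejects forward-mode differentiation (`jax.jvp`) through a `jax.custom_vjp` function, which is the restriction mentioned for the `custom_vjp` flag.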

matfree.decomp.tridiag_sym(num_matvecs: int, /, *, materialize: bool = True, reortho: str = 'full', custom_vjp: bool = True)

Construct an implementation of tridiagonalisation.

Decompose a real symmetric matrix into a product of orthogonal-tridiagonal-orthogonal matrices. Use this algorithm for approximate eigenvalue decompositions. Does not support complex-valued matrices.

The present implementation allocates all Lanczos vectors before running the algorithm. If reortho is set to "full", it also uses full reorthogonalisation, which is usually advisable: in floating-point arithmetic, Lanczos vectors tend to lose orthogonality as the iteration proceeds. Matrix-free tridiagonalisation uses Lanczos' (1950) algorithm:

BibTex for Lanczos (1950)
@article{lanczos1950iteration,
    title={An iteration method for the solution of the eigenvalue problem of linear differential and integral operators},
    author={Lanczos, Cornelius},
    journal={Journal of research of the National Bureau of Standards},
    volume={45},
    number={4},
    pages={255--282},
    year={1950}
}
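The Lanczos recurrence with full reorthogonalisation can be sketched in NumPy as follows. The function `lanczos` is a hypothetical illustration, not matfree's implementation; it returns the two diagonals of the symmetric tridiagonal matrix \(T\) together with a residual satisfying \(AQ = QT + r e_k^\top\).

```python
import numpy as np


def lanczos(A, v0, k):
    """Lanczos tridiagonalisation with full reorthogonalisation (sketch).

    Returns Q (n x k, orthonormal columns), the diagonal and off-diagonal of the
    symmetric tridiagonal T, and a residual r with A @ Q = Q @ T + outer(r, e_k).
    """
    n = A.shape[0]
    Q = np.zeros((n, k))
    diag, offdiag = np.zeros(k), np.zeros(k - 1)
    q, q_prev, beta = v0 / np.linalg.norm(v0), np.zeros(n), 0.0
    for j in range(k):
        Q[:, j] = q
        # Three-term recurrence: w = A q_j - beta_{j-1} q_{j-1} - alpha_j q_j
        w = A @ q - beta * q_prev
        diag[j] = q @ w
        w = w - diag[j] * q
        # Full reorthogonalisation (applied twice) against all Lanczos vectors so far
        for _ in range(2):
            w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j + 1 < k:
            beta = np.linalg.norm(w)
            offdiag[j] = beta
            q_prev, q = q, w / beta
    return Q, diag, offdiag, w  # after the last step, w is the residual


rng = np.random.default_rng(2)
M = rng.standard_normal((10, 10))
A = M + M.T  # real symmetric
Q, d, e, r = lanczos(A, rng.standard_normal(10), k=6)
T = np.diag(d) + np.diag(e, 1) + np.diag(e, -1)
```

Dropping the reorthogonalisation loop gives the plain three-term recurrence, which is cheaper per step but loses orthogonality of `Q` in floating point.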

Setting custom_vjp to True enables the efficient, numerically stable gradients of the Lanczos iteration proposed by Krämer et al. (2024). These gradients are exact, so there is little reason not to use them. If you use this configuration, please cite Krämer et al. (2024):

BibTex for Krämer et al. (2024)
@article{kraemer2024gradients,
    title={Gradients of functions of large matrices},
    author={Kr{\"a}mer, Nicholas and Moreno-Mu{\~n}oz, Pablo and Roy, Hrittik and Hauberg, S{\o}ren},
    journal={Advances in Neural Information Processing Systems},
    volume={37},
    pages={49484--49518},
    year={2024}
}

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_matvecs | int | The number of matrix-vector products, i.e., the depth of the Krylov space. The deeper the Krylov space, the more accurate the factorisation tends to be. However, the computational cost grows with the number of matrix-vector products: linearly for the products themselves, and quadratically once full reorthogonalisation is enabled. | required |
| materialize | bool | Whether to return the tridiagonal matrix as a dense matrix (`True`) or in a sparse format, i.e., as a tuple of diagonals (`False`). The dense matrix is helpful if different decompositions should be used interchangeably; the sparse representation requires less memory. | True |
| reortho | str | Whether to reorthogonalise the basis vectors during the forward pass: either `"full"` or `"none"`. Reorthogonalisation makes the forward pass more expensive, but helps (significantly) with numerical stability. | 'full' |
| custom_vjp | bool | Whether to use the custom vector-Jacobian product proposed by Krämer et al. (2024; BibTeX above). Using a custom VJP is generally a good idea. However, JAX does not support forward-mode differentiation through a custom VJP (see the `jax.custom_vjp` documentation), so do not use a custom VJP if you need forward-mode differentiation. | True |
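With `materialize=False`, the tridiagonal matrix comes back as a tuple of diagonals. Assuming that tuple holds the main diagonal followed by the (symmetric) off-diagonal, the hypothetical helper below, which is not part of matfree, assembles the corresponding dense matrix. The sparse form stores \(2k-1\) numbers instead of \(k^2\).

```python
import jax.numpy as jnp


def tridiag_todense(diag, offdiag):
    """Hypothetical helper: assemble a dense symmetric tridiagonal matrix."""
    return jnp.diag(diag) + jnp.diag(offdiag, k=1) + jnp.diag(offdiag, k=-1)


diag = jnp.array([2.0, 3.0, 4.0])
offdiag = jnp.array([0.5, -1.0])
T = tridiag_todense(diag, offdiag)  # 3 x 3 dense matrix from 5 stored numbers
```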

Returns:

| Name | Description |
|---|---|
| decompose | A function (matvec, vector, *params) returning a four-element result (Q_tall, J_small, residual, init_length_inv). |
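For orientation, the four outputs can be read against the standard Lanczos relations. This is a sketch in textbook notation, assuming that Q_tall holds the \(k\) Lanczos vectors as the columns of \(Q\), J_small is the \(k \times k\) tridiagonal matrix \(J\), residual is \(r\), and init_length_inv is \(c = 1/\lVert v \rVert\) for the initial vector \(v\); the exact array layout in matfree may differ.

```latex
A Q = Q J + r\, e_k^\top, \qquad Q^\top Q = I_k, \qquad Q e_1 = c\, v,
\qquad v^\top f(A)\, v \approx c^{-2}\, e_1^\top f(J)\, e_1 .
```

The last relation, the core of stochastic Lanczos quadrature (see matfree.funm), shows how \(c\) enters when results are rescaled back to the unnormalised vector \(v\).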

Source code in matfree/decomp.py
def tridiag_sym(
    num_matvecs: int,
    /,
    *,
    materialize: bool = True,
    reortho: str = "full",
    custom_vjp: bool = True,
):
    r"""Construct an implementation of **tridiagonalisation**.

    Decompose a real **symmetric** matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices.
    Use this algorithm for approximate **eigenvalue** decompositions.
    Does not support complex-valued matrices.

    The present implementation allocates all Lanczos vectors before running the
    algorithm. If `reortho` is set to `"full"`, it also uses full reorthogonalisation,
    which is usually advisable: in floating-point arithmetic, Lanczos vectors
    tend to lose orthogonality as the iteration proceeds.
    Matrix-free tridiagonalisation uses Lanczos' (1950) algorithm:

    ??? note "BibTex for Lanczos (1950)"
        ```bibtex
        @article{lanczos1950iteration,
            title={An iteration method for the solution of the eigenvalue problem of linear differential and integral operators},
            author={Lanczos, Cornelius},
            journal={Journal of research of the National Bureau of Standards},
            volume={45},
            number={4},
            pages={255--282},
            year={1950}
        }
        ```

    Setting `custom_vjp` to `True` enables the efficient, numerically stable
    gradients of the Lanczos iteration proposed by Krämer et al. (2024).
    These gradients are exact, so there is little reason not to use them.
    If you use this configuration, please cite Krämer et al. (2024):

    ??? note "BibTex for Krämer et al. (2024)"
        ```bibtex
        @article{kraemer2024gradients,
            title={Gradients of functions of large matrices},
            author={Kr{\"a}mer, Nicholas and Moreno-Mu{\~n}oz, Pablo and Roy, Hrittik and Hauberg, S{\o}ren},
            journal={Advances in Neural Information Processing Systems},
            volume={37},
            pages={49484--49518},
            year={2024}
        }
        ```

    Parameters
    ----------
    num_matvecs
        The number of matrix-vector products, i.e., the depth of the Krylov space.
        The deeper the Krylov space, the more accurate the factorisation tends to be.
        However, the computational cost grows with the number of matrix-vector
        products: linearly for the products themselves, and quadratically
        once full reorthogonalisation is enabled.
    materialize
        Whether to return the tridiagonal matrix as a dense matrix (`True`)
        or in a sparse format, i.e., as a tuple of diagonals (`False`).
        The dense matrix is helpful if different decompositions should be used
        interchangeably; the sparse representation requires less memory.
    reortho
        Whether to reorthogonalise the basis vectors during the forward pass:
        either `"full"` or `"none"`.
        Reorthogonalisation makes the forward pass more expensive, but helps
        (significantly) with numerical stability.
    custom_vjp
        The value of this flag indicates whether to use a custom vector-Jacobian
        product as proposed by Krämer et al. (2024; bibtex above).
        Generally, using a custom VJP tends to be a good idea.
        However, due to JAX's mechanics, a custom VJP precludes the use of forward-mode
        differentiation
        ([see here](https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html)),
        so don't use a custom VJP if you need forward-mode differentiation.

    Returns
    -------
    decompose
        A function ``(matvec, vector, *params)`` returning a four-element result
        ``(Q_tall, J_small, residual, init_length_inv)``.
    """
    if reortho == "full":
        return _tridiag_reortho_full(
            num_matvecs, custom_vjp=custom_vjp, materialize=materialize
        )
    if reortho == "none":
        return _tridiag_reortho_none(
            num_matvecs, custom_vjp=custom_vjp, materialize=materialize
        )

    msg = f"reortho={reortho!r} unsupported. Choose either 'full' or 'none'."
    raise ValueError(msg)