Skip to content

matfree.eig

matfree.eig

Matrix-free eigenvalue and singular-value analysis.

Examples:

>>> import jax.random
>>> from matfree import decomp
>>>
>>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10))
>>> A = M + M.T
>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,))
>>>
>>> # Replace tridiagonalisation with bidiagonalisation and eigh with svd
>>> # to compute a partial SVD instead of a partial eigendecomposition.
>>> tridiag = decomp.tridiag_sym(4)
>>> eigh_fun = eigh_partial(tridiag)
>>> eigvals, eigvecs = eigh_fun(lambda s: A @ s, v)
>>> print(eigvals.shape)
(4,)
>>> print(eigvecs.shape)
(4, 10)

matfree.eig.eig_partial(hessenberg: Callable) -> Callable

Compute a partial eigendecomposition of an arbitrary square matrix via Hessenberg factorisation.

Supports complex-valued matrices if the Hessenberg factorisation does.

Parameters:

Name Type Description Default
hessenberg Callable

An implementation of Hessenberg factorisation. For example, the output of decomp.hessenberg.

required
Source code in matfree/eig.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def eig_partial(hessenberg: Callable) -> Callable:
    """Compute a partial eigendecomposition of an arbitrary square matrix via Hessenberg factorisation.

    Supports complex-valued matrices if the Hessenberg factorisation does.

    Parameters
    ----------
    hessenberg:
        An implementation of Hessenberg factorisation.
        For example, the output of
        [decomp.hessenberg][matfree.decomp.hessenberg].

    """

    def eig(Av: Callable, v0: Array, *parameters):
        # Flatten in- and outputs
        Av_flat, flattened = _partial_and_flatten_matvec(Av, v0, *parameters)
        _, (v0_flat, v_unravel) = flattened

        # Call the flattened eig
        vals, vecs = eig_flat(Av_flat, v0_flat)

        # Unravel the eigenvectors
        vecs = func.vmap(v_unravel)(vecs)
        return vals, vecs

    def eig_flat(Av: Callable, v0: Array):
        # Factorise the matrix
        Q, H, *_ = hessenberg(Av, v0)

        # Compute eig of factorisation (Q is (k, n) -> rows are Krylov vectors)
        vals, vecs = linalg.eig(H)
        vecs = vecs.T @ Q
        return vals, vecs

    return eig

matfree.eig.eigh_partial(tridiag_sym: Callable) -> Callable

Compute a partial eigendecomposition \(A \approx V \Lambda V^\top\) for symmetric/Hermitian matrices.

Supports complex-valued (Hermitian) matrices if the tridiagonalisation does.

Parameters:

Name Type Description Default
tridiag_sym Callable

An implementation of tridiagonalization. For example, the output of decomp.tridiag_sym.

required
Source code in matfree/eig.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def eigh_partial(tridiag_sym: Callable) -> Callable:
    r"""Compute a partial eigendecomposition $A \approx V \Lambda V^\top$ for symmetric/Hermitian matrices.

    Supports complex-valued (Hermitian) matrices if the tridiagonalisation does.

    Parameters
    ----------
    tridiag_sym:
        An implementation of tridiagonalization.
        For example, the output of
        [decomp.tridiag_sym][matfree.decomp.tridiag_sym].

    """

    def eigh(Av: Callable, v0: Array, *parameters):
        # Flatten in- and outputs
        Av_flat, flattened = _partial_and_flatten_matvec(Av, v0, *parameters)
        _, (v0_flat, v_unravel) = flattened

        # Call the flattened eigh
        vals, vecs = eigh_flat(Av_flat, v0_flat)

        # Unravel the eigenvectors
        vecs = func.vmap(v_unravel)(vecs)
        return vals, vecs

    def eigh_flat(Av: Callable, v0: Array):
        # Factorise the matrix
        Q, H, *_ = tridiag_sym(Av, v0)

        # Compute eigh of factorisation (Q is (k, n) -> rows are Krylov vectors)
        vals, vecs = linalg.eigh(H)
        vecs = vecs.T @ Q
        return vals, vecs

    return eigh

matfree.eig.svd_partial(bidiag: Callable) -> Callable

Compute a partial SVD \(A \approx U \Sigma V^\top\) via bidiagonalisation.

Supports complex-valued matrices if the bidiagonalisation does.

Parameters:

Name Type Description Default
bidiag Callable

An implementation of bidiagonalisation. For example, the output of decomp.bidiag. Note how this function assumes that the bidiagonalisation materialises the bidiagonal matrix.

required
Source code in matfree/eig.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def svd_partial(bidiag: Callable) -> Callable:
    r"""Compute a partial SVD $A \approx U \Sigma V^\top$ via bidiagonalisation.

    Supports complex-valued matrices if the bidiagonalisation does.

    Parameters
    ----------
    bidiag:
        An implementation of bidiagonalisation.
        For example, the output of
        [decomp.bidiag][matfree.decomp.bidiag].
        Note how this function assumes that the bidiagonalisation
        materialises the bidiagonal matrix.

    """

    def svd(Av: Callable, v0: Array, *parameters):
        # Flatten in- and outputs
        Av_flat, flattened = _partial_and_flatten_matvec(Av, v0, *parameters)
        (_u0_flat, u_unravel), (v0_flat, v_unravel) = flattened

        # Call the flattened SVD
        ut, s, vt = svd_flat(Av_flat, v0_flat)

        # Unravel the singular vectors
        ut_tree = func.vmap(u_unravel)(ut)
        vt_tree = func.vmap(v_unravel)(vt)
        return ut_tree, s, vt_tree

    def svd_flat(Av: Callable, v0: Array):
        # Factorise the matrix
        (u, v), B, *_ = bidiag(Av, v0)

        # Compute SVD of factorisation
        U, S, Vt = linalg.svd(B, full_matrices=False)

        # Combine orthogonal transformations (u, v are (k, n) -> rows are Krylov vectors)
        return U.T @ u, S, Vt @ v

    return svd