Compute matrix functions without materializing large matrices
Sometimes we need to compute matrix exponentials, log-determinants, or similar matrix functions, but the matrices are too large for the dense routines in scipy.linalg or jax.scipy.linalg. Matrix-free linear algebra only requires matrix-vector products, so it scales to matrices that are far too large to store densely. Here is how to use Matfree to compute functions of large matrices.
import functools
import jax
from matfree import decomp, funm
n = 7 # imagine n = 10^5 or larger
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key, num=2)
large_matrix = jax.random.normal(subkey, shape=(n, n))
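In this toy example, large_matrix is materialized only so that we can compare against jax.scipy.linalg below. In a genuinely matrix-free setting, the matrix-vector product is written so that the dense matrix never exists in memory. A minimal sketch, assuming a hypothetical diagonal-plus-rank-one matrix; diag_values, u, and structured_matvec are illustrative names, not part of Matfree:

import jax.numpy as jnp

diag_values = jax.random.normal(jax.random.PRNGKey(2), shape=(n,))
u = jax.random.normal(jax.random.PRNGKey(3), shape=(n,))

def structured_matvec(v):
    """Evaluate (diag(diag_values) + u u^T) @ v without forming the matrix."""
    return diag_values * v + u * jnp.dot(u, v)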
For reference, the expected result is computed with jax.scipy.linalg, which materializes the dense matrix exponential.
key, subkey = jax.random.split(key, num=2)
vector = jax.random.normal(subkey, shape=(n,))
expected = jax.scipy.linalg.expm(large_matrix) @ vector
print(expected)
[ 0.5121861 1.0731273 -1.1475035 -1.6931866 0.06646963 -1.1467085 0.66265297]
Instead of using jax.scipy.linalg, we can use matrix-vector products in combination with the Arnoldi iteration to approximate the matrix-function-vector product.
def large_matvec(v):
    """Evaluate a matrix-vector product."""
    return large_matrix @ v
num_matvecs = 5
arnoldi = decomp.hessenberg(num_matvecs, reortho="full")  # Arnoldi iteration with full reorthogonalization
dense_funm = funm.dense_funm_pade_exp()  # dense matrix exponential via a Padé approximation
matfun_vec = funm.funm_arnoldi(dense_funm, arnoldi)  # combine both into a matrix-function-vector product
received = matfun_vec(large_matvec, vector)
print(received)
[ 0.5136445 1.0897965 -1.1209555 -1.7069302 0.03098169 -1.1719893 0.67968863]
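The accuracy is governed by num_matvecs: more matrix-vector products enlarge the Krylov space and bring the approximation closer to the reference. As a quick sanity check (an addition to the original example, assuming num_matvecs may be chosen as large as n), using n matrix-vector products should reproduce the dense result up to floating-point error:

arnoldi_full = decomp.hessenberg(n, reortho="full")
matfun_vec_full = funm.funm_arnoldi(dense_funm, arnoldi_full)
approx_full = matfun_vec_full(large_matvec, vector)
print(jax.numpy.linalg.norm(approx_full - expected))  # expect a value near floating-point round-off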
The matrix-function-vector product can be combined with all the usual JAX transformations. For example, after fixing the matvec function as the first argument, we can vectorize the matrix function with jax.vmap and compile it with jax.jit.
matfun_vec = functools.partial(matfun_vec, large_matvec)
key, subkey = jax.random.split(key, num=2)
vector_batch = jax.random.normal(subkey, shape=(5, n)) # a batch of 5 vectors
received = jax.jit(jax.vmap(matfun_vec))(vector_batch)
print(received.shape)
(5, 7)
Speaking of function transformations: we can also reverse-mode-differentiate the matrix function efficiently.
jac = jax.jacrev(matfun_vec)(vector)
print(jac)
[[ 3.68775666e-01  3.48348975e-01 -1.14449523e-01 -3.22446883e-01  3.28712702e-01 -6.60334349e-01  3.08125526e-01]
 [ 8.88347626e-04  9.77235258e-01  2.68623352e+00 -5.51655173e-01 -1.45154142e+00 -1.11724639e+00  7.45091677e-01]
 [ 4.17882234e-01 -9.98488367e-01 -3.91192406e-01  8.76782537e-01 -9.65307474e-01  5.19365370e-01 -6.68987870e-01]
 [ 2.65466452e-01 -8.89071941e-01 -2.17203140e+00  7.52809644e-01  4.79240775e-01  8.03415000e-01 -8.45992625e-01]
 [-4.26323414e-01 -8.46019328e-01 -2.89584970e+00  1.10395364e-01  2.57722950e+00  1.75358319e+00 -3.07614803e-01]
 [-1.35615468e-01 -5.94067991e-01 -1.90474641e+00  1.77025393e-01  1.02040839e+00  7.22389579e-01 -3.67944658e-01]
 [-3.23790073e-01  1.21016252e+00  1.78035736e+00 -1.12524259e+00 -1.80692703e-01 -1.32690465e+00  1.32771575e+00]]
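Computing a full Jacobian is mainly illustrative; in applications, one typically differentiates a scalar loss built on top of the matrix-function-vector product. A minimal sketch, where the quadratic loss below is made up purely for illustration:

import jax.numpy as jnp

def loss(v):
    """A made-up scalar loss of the matrix-function-vector product."""
    return jnp.sum(matfun_vec(v) ** 2)

gradient = jax.grad(loss)(vector)  # same shape as vector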
Under the hood, reverse-mode derivatives of Arnoldi- and Lanczos-based matrix functions use the fast algorithm for gradients of the Lanczos and Arnoldi iterations from this paper. Please consider citing it if you use reverse-mode derivatives of functions of matrices (a BibTeX entry is here).