
Matfree: Matrix-free linear algebra in JAX


Randomised and deterministic matrix-free methods for trace estimation, functions of matrices, and matrix factorisations. Matfree builds on JAX.

  • ⚡ Stochastic trace estimation including batching, control variates, and uncertainty quantification
  • ⚡ A stand-alone implementation of stochastic Lanczos quadrature for traces of functions of matrices
  • ⚡ Matrix-decomposition algorithms for large sparse eigenvalue problems: tridiagonalisation, bidiagonalisation, and Hessenberg factorisation via Lanczos and Arnoldi iterations
  • ⚡ Chebyshev, Lanczos, and Arnoldi-based methods for approximating functions of large matrices
  • Gradients of functions of large matrices (like in this paper) via differentiable Lanczos and Arnoldi iterations
  • ⚡ Partial Cholesky preconditioners with and without pivoting
  • ⚡ And many other things
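The idea behind stochastic Lanczos quadrature can be sketched in plain JAX. The sketch below is not Matfree's implementation: `lanczos_tridiag` and `slq_logdet` are hypothetical names. It assumes a symmetric positive-definite matrix that is accessed only through `matvec`, uses full reorthogonalisation, and uses Rademacher probe vectors. Each probe yields a Lanczos tridiagonal matrix whose eigen-decomposition gives a Gauss quadrature rule for the quadratic form v^T log(M) v. Averaging over probes then estimates log det(M) = trace(log(M)).

```python
import jax
import jax.numpy as jnp


def lanczos_tridiag(matvec, v0, num_steps):
    """Lanczos tridiagonalisation with full reorthogonalisation (num_steps >= 2).

    Returns the diagonal `alphas` and off-diagonal `betas` of T = Q^T M Q,
    where the matrix M is only accessed through `matvec`.
    """
    basis = [v0 / jnp.linalg.norm(v0)]
    alphas, betas = [], []
    for i in range(num_steps):
        q = basis[-1]
        w = matvec(q)
        alphas.append(jnp.dot(q, w))
        Q = jnp.stack(basis)
        # Orthogonalise against all previous Lanczos vectors (twice, for stability)
        for _ in range(2):
            w = w - Q.T @ (Q @ w)
        if i < num_steps - 1:
            beta = jnp.linalg.norm(w)
            betas.append(beta)
            basis.append(w / beta)
    return jnp.stack(alphas), jnp.stack(betas)


def slq_logdet(matvec, dim, num_steps, num_probes, key):
    """Stochastic Lanczos quadrature estimate of log det(M) = trace(log(M))."""
    probes = jax.random.rademacher(key, (num_probes, dim)).astype(jnp.float32)

    def quadform_log(v):
        alphas, betas = lanczos_tridiag(matvec, v, num_steps)
        T = jnp.diag(alphas) + jnp.diag(betas, 1) + jnp.diag(betas, -1)
        nodes, vecs = jnp.linalg.eigh(T)
        # Gauss quadrature: v^T log(M) v ~ ||v||^2 * sum_j (first entry of u_j)^2 * log(node_j)
        weights = vecs[0, :] ** 2
        return jnp.dot(v, v) * jnp.sum(weights * jnp.log(nodes))

    # Monte Carlo average over probes (Hutchinson's estimator applied to log(M))
    return jnp.mean(jax.vmap(quadform_log)(probes))


n = 8
M = jnp.diag(jnp.arange(1.0, n + 1.0)) + 0.1 * jnp.ones((n, n))
est = slq_logdet(lambda x: M @ x, dim=n, num_steps=n, num_probes=200, key=jax.random.PRNGKey(0))
exact = jnp.linalg.slogdet(M)[1]
```

Here `num_steps` equals the matrix size only to keep the toy example simple. For large matrices one would use far fewer Lanczos steps than the dimension.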

Everything is natively compatible with the rest of JAX: JIT compilation, automatic differentiation, vectorisation, and PyTrees. Let us know what you think about Matfree!
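To illustrate what this compatibility means for matrix-free code in general, here is a plain-JAX sketch. It does not use Matfree, and `trace_via_matvecs` and `trace_gram` are hypothetical names. The matrix A^T A is only accessed through matrix-vector products, `jax.vmap` batches those products, `jax.jit` compiles the function, and `jax.grad` differentiates through it.

```python
import jax
import jax.numpy as jnp


def trace_via_matvecs(matvec, dim):
    """Exact trace using only matrix-vector products: sum_i e_i^T M e_i."""
    basis = jnp.eye(dim)
    # vmap applies matvec to every unit vector in one batched call
    diag = jax.vmap(lambda e: jnp.dot(e, matvec(e)))(basis)
    return jnp.sum(diag)


@jax.jit
def trace_gram(A):
    # Matrix-free access to A^T A: the Gram matrix is never formed explicitly
    def matvec(x):
        return A.T @ (A @ x)

    return trace_via_matvecs(matvec, A.shape[1])


A = jnp.reshape(jnp.arange(12.0), (6, 2))
value = trace_gram(A)               # trace(A^T A)
grad_A = jax.grad(trace_gram)(A)    # d/dA trace(A^T A) = 2 A
```

This exact variant costs one matrix-vector product per dimension, so it only suits small problems. Randomised estimators such as Hutchinson's replace the unit vectors with a modest number of random probes.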

Installation

To install the package, run

pip install matfree

Important: This assumes you already have a working installation of JAX. To install JAX, follow these instructions. To combine Matfree with a CPU version of JAX, run

pip install matfree[cpu]

which is equivalent to combining pip install jax[cpu] with pip install matfree. (But do not only use Matfree on CPU!)

Minimal example

Import Matfree and JAX, and set up a test problem.

>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import stochtrace
>>>
>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
>>> def matvec(x):
...     return A.T @ (A @ x)
...

Estimate the trace of the matrix:

>>> # Determine the shape of the base samples
>>> input_like = jnp.zeros((2,), dtype=float)
>>> sampler = stochtrace.sampler_rademacher(input_like, num=10_000)
>>>
>>> # Set Hutchinson's method up to compute the traces
>>> # (instead of, e.g., diagonals)
>>> integrand = stochtrace.integrand_trace()
>>>
>>> # Compute an estimator
>>> estimate = stochtrace.estimator(integrand, sampler)
>>>
>>> # Estimate
>>> key = jax.random.PRNGKey(1)
>>> trace = estimate(matvec, key)
>>>
>>> print(trace)
508.9
>>>
>>> # for comparison:
>>> print(jnp.trace(A.T @ A))
506.0
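For intuition, Hutchinson's method averages quadratic forms v^T M v over random vectors v with E[v v^T] = I, because then E[v^T M v] = trace(M). The plain-JAX sketch below reproduces this idea for the same test problem without Matfree. It is not Matfree's implementation, `hutchinson_trace` is a hypothetical name, and its output will generally differ from the value printed above because the random numbers are drawn differently. As a simple form of uncertainty quantification, it also returns the standard error of the Monte Carlo mean.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(matvec, dim, num_samples, key):
    """Hutchinson trace estimate and its standard error (Rademacher probes)."""
    # Rademacher probes: entries are +1 or -1 with equal probability
    probes = jax.random.rademacher(key, (num_samples, dim)).astype(jnp.float32)
    # One quadratic form v^T M v per probe, batched with vmap
    quadforms = jax.vmap(lambda v: jnp.dot(v, matvec(v)))(probes)
    estimate = jnp.mean(quadforms)
    stderr = jnp.std(quadforms) / jnp.sqrt(num_samples)
    return estimate, stderr


A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    return A.T @ (A @ x)


estimate, stderr = hutchinson_trace(matvec, dim=2, num_samples=10_000, key=jax.random.PRNGKey(1))
```

For this 2x2 matrix, each quadratic form equals 506 plus or minus 500, so the standard error with 10,000 samples is roughly 5.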

Tutorials

Find many more tutorials in Matfree's documentation.

These tutorials include, among other things:

  • Log-determinants: Use stochastic Lanczos quadrature to compute matrix functions.
  • Pytree-valued states: Combining neural-network Jacobians with stochastic Lanczos quadrature.
  • Control variates: Use control variates and multilevel schemes to reduce variances.
  • Higher moments and UQ: Compute means, variances, and other moments simultaneously.
  • Vector calculus: Use matrix-free linear algebra to implement vector calculus.
  • Low-memory trace estimation: Combine Matfree's API with JAX's function transformations for low-memory stochastic trace estimation.
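As a taste of the vector-calculus use case, the Laplacian of a scalar function is the trace of its Hessian. This trace can be computed from Hessian-vector products without ever forming the Hessian. The sketch below uses plain JAX and exact unit-vector probes. It is not taken from the tutorial, and `laplacian` is a hypothetical name. Swapping the unit vectors for random probes turns it into a stochastic trace estimate.

```python
import jax
import jax.numpy as jnp


def laplacian(f, x):
    """Laplacian of scalar-valued f at x, as the trace of the Hessian via Hessian-vector products."""

    def hvp(v):
        # Forward-over-reverse: directional derivative of grad(f) along v
        return jax.jvp(jax.grad(f), (x,), (v,))[1]

    basis = jnp.eye(x.shape[0])
    return jnp.sum(jax.vmap(lambda e: jnp.dot(e, hvp(e)))(basis))


def f(x):
    return jnp.sum(jnp.sin(x))


x = jnp.array([0.1, 0.5, 1.0])
lap = laplacian(f, x)  # for this f, the Laplacian is -sum(sin(x))
```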

Let us know what you use Matfree for!

Citation

Thank you for using Matfree! If you are using Matfree's differentiable Lanczos or Arnoldi iterations, then you are using the algorithms from this paper. We would appreciate it if you cited the paper as follows:

@article{kraemer2024gradients,
    title={Gradients of functions of large matrices},
    author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and Roy, Hrittik and Hauberg, S{\o}ren},
    journal={arXiv preprint arXiv:2405.17277},
    year={2024}
}

Some of Matfree's docstrings contain additional bibliographic information. For example, the matfree.bounds functions link to BibTeX entries for the articles associated with each bound. Go check out the API documentation.