matfree: Matrix-free linear algebra in JAX
Randomised and deterministic matrix-free methods for trace estimation, matrix functions, and/or matrix factorisations. Builds on JAX.
- ⚡ Stochastic trace estimation including batching, control variates, and uncertainty quantification
- ⚡ A stand-alone implementation of stochastic Lanczos quadrature
- ⚡ Matrix-decomposition algorithms for large sparse eigenvalue problems
- ⚡ Polynomial methods for approximating functions of large matrices
and many other things. Everything is natively compatible with the rest of JAX: JIT compilation, automatic differentiation, vectorisation, and PyTrees. Let us know what you think about matfree!
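To illustrate what "matrix-free" means in practice, the sketch below uses plain JAX (no matfree calls) to apply the operator J^T J, where J is the Jacobian of a function, without ever building J. The function `fun`, the point `x0`, and the vector `v` are made up for illustration and are not part of matfree.

```python
import jax
import jax.numpy as jnp


def fun(x):
    # Hypothetical nonlinear map from R^3 to R^4, for illustration only
    W = jnp.reshape(jnp.arange(12.0), (4, 3)) / 10.0
    return jnp.tanh(W @ x)


x0 = jnp.asarray([0.1, -0.2, 0.3])


@jax.jit
def matvec(v):
    # Forward mode: u = J v, where J is the Jacobian of `fun` at x0
    _, u = jax.jvp(fun, (x0,), (v,))
    # Reverse mode: J^T u
    _, vjp_fun = jax.vjp(fun, x0)
    (out,) = vjp_fun(u)
    return out


v = jnp.asarray([1.0, 0.0, -1.0])
result = matvec(v)

# Dense reference, only feasible because the example is tiny
J = jax.jacfwd(fun)(x0)
reference = J.T @ (J @ v)
```

A function like `matvec` is the kind of input that matrix-free algorithms consume. The dense Jacobian is formed here only to check the result on a small problem.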
Installation
To install the package, run
pip install matfree
Important: This assumes you already have a working installation of JAX.
To install JAX, follow these instructions.
To combine matfree with a CPU version of JAX, run
pip install matfree[cpu]
which is equivalent to combining pip install jax[cpu] with pip install matfree.
(But do not only use matfree on CPU!)
Minimal example
Import matfree and JAX, and set up a test problem.
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson
>>>
>>> jnp.set_printoptions(1)
>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
>>> def matvec(x):
...     return A.T @ (A @ x)
...
Estimate the trace of the matrix:
>>> # Determine the shape of the base-samples
>>> input_like = jnp.zeros((2,), dtype=float)
>>> sampler = hutchinson.sampler_rademacher(input_like, num=10_000)
>>>
>>> # Set Hutchinson's method up to compute the traces
>>> # (instead of, e.g., diagonals)
>>> integrand = hutchinson.integrand_trace(matvec)
>>>
>>> # Construct the estimator, a function of a random key
>>> estimate = hutchinson.hutchinson(integrand, sampler)
>>>
>>> # Evaluate the estimator
>>> key = jax.random.PRNGKey(1)
>>> trace = jax.jit(estimate)(key)
>>>
>>> print(trace)
508.9
>>>
>>> # for comparison:
>>> print(jnp.trace(A.T @ A))
506.0
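For intuition, the following sketch re-creates the same kind of estimate in plain JAX. It is not matfree's implementation, and `trace_estimate` is a hypothetical helper: it only assumes that the estimator averages quadratic forms v^T (A^T A) v over Rademacher samples v. It also returns a standard error as a simple form of uncertainty quantification.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    return A.T @ (A @ x)


def trace_estimate(key, num_samples=10_000):
    # Rademacher samples: entries are +1 or -1 with equal probability
    samples = jax.random.rademacher(key, (num_samples, 2), dtype=jnp.float32)
    # One quadratic form per sample; each has expectation trace(A^T A)
    quadforms = jax.vmap(lambda v: v @ matvec(v))(samples)
    mean = jnp.mean(quadforms)
    # Standard error of the Monte Carlo mean
    stderr = jnp.std(quadforms, ddof=1) / jnp.sqrt(num_samples)
    return mean, stderr


mean, stderr = jax.jit(trace_estimate)(jax.random.PRNGKey(1))
```

For this particular A, each quadratic form equals either 6 or 1006 (that is, 506 ± 500), so with 10,000 samples the standard error is roughly 500 / sqrt(10,000) = 5. That is consistent with the gap between 508.9 and 506.0 above.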
Tutorials
Find many more tutorials in Matfree's documentation.
These tutorials include, among other things:
- Log-determinants: Use stochastic Lanczos quadrature to compute matrix functions.
- Pytree-valued states: Combine neural-network Jacobians with stochastic Lanczos quadrature.
- Control variates: Use control variates and multilevel schemes to reduce variances.
- Higher moments and UQ: Compute means, variances, and other moments simultaneously.
- Vector calculus: Use matrix-free linear algebra to implement vector calculus.
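As a rough idea of what the log-determinant tutorial covers, here is a minimal sketch of stochastic Lanczos quadrature in plain JAX. It is not matfree's implementation: the test matrix and the helpers `lanczos_tridiag`, `quadform_log`, and `logdet_estimate` are hypothetical, and the Lanczos loop is unrolled in Python with no safeguard against breakdown (a zero off-diagonal entry).

```python
import jax
import jax.numpy as jnp

# Hypothetical test matrix: symmetric positive definite, well conditioned
B = jnp.reshape(jnp.sin(jnp.arange(25.0)), (5, 5))
A = 3.0 * jnp.eye(5) + 0.1 * (B + B.T)


def matvec(x):
    return A @ x


def lanczos_tridiag(v0, num_steps):
    # Lanczos iteration with full reorthogonalisation
    basis = [v0 / jnp.linalg.norm(v0)]
    alphas, betas = [], []
    for i in range(num_steps):
        w = matvec(basis[-1])
        alphas.append(basis[-1] @ w)
        if i < num_steps - 1:
            # Remove components along all previous basis vectors
            Q = jnp.stack(basis)
            w = w - Q.T @ (Q @ w)
            beta = jnp.linalg.norm(w)
            betas.append(beta)
            basis.append(w / beta)
    alphas, betas = jnp.stack(alphas), jnp.stack(betas)
    # Symmetric tridiagonal matrix T
    return jnp.diag(alphas) + jnp.diag(betas, 1) + jnp.diag(betas, -1)


def quadform_log(v0, num_steps=3):
    # Gauss quadrature: v^T log(A) v is approximately |v|^2 * e_1^T log(T) e_1
    T = lanczos_tridiag(v0, num_steps)
    eigvals, eigvecs = jnp.linalg.eigh(T)
    return (v0 @ v0) * jnp.sum(eigvecs[0, :] ** 2 * jnp.log(eigvals))


def logdet_estimate(key, num_samples=1_000):
    # Hutchinson average of the quadratic forms: log det(A) = trace(log(A))
    samples = jax.random.rademacher(key, (num_samples, 5), dtype=jnp.float32)
    return jnp.mean(jax.vmap(quadform_log)(samples))


estimate = jax.jit(logdet_estimate)(jax.random.PRNGKey(0))

# Dense reference for comparison
_, reference = jnp.linalg.slogdet(A)
```

Only `matvec` touches the matrix, so the same structure applies when A is available only as a function. Three Lanczos steps are enough here because the test matrix is well conditioned; harder problems generally need more steps.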
Let us know what you use matfree for!