matfree: Matrix-free linear algebra in JAX

Randomised and deterministic matrix-free methods for trace estimation, matrix functions, and/or matrix factorisations. Builds on JAX.

  • ⚡ Stochastic trace estimation including batching, control variates, and uncertainty quantification
  • ⚡ A stand-alone implementation of stochastic Lanczos quadrature
  • ⚡ Matrix-decomposition algorithms for large sparse eigenvalue problems
  • ⚡ Polynomial methods for approximating functions of large matrices

and many other things. Everything is natively compatible with the rest of JAX: JIT compilation, automatic differentiation, vectorisation, and PyTrees. Let us know what you think about matfree!

Installation

To install the package, run

pip install matfree

Important: This assumes you already have a working installation of JAX. To install JAX, follow these instructions. To combine matfree with a CPU version of JAX, run

pip install matfree[cpu]

This is equivalent to combining pip install jax[cpu] with pip install matfree. (But do not use matfree only on CPU!)

Minimal example

Import matfree and JAX, and set up a test problem.

>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson
>>>
>>> jnp.set_printoptions(1)

>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
>>> def matvec(x):
...     return A.T @ (A @ x)
...

Estimate the trace of the matrix:

>>> # Determine the shape of the base-samples
>>> input_like = jnp.zeros((2,), dtype=float)
>>> sampler = hutchinson.sampler_rademacher(input_like, num=10_000)
>>>
>>> # Set Hutchinson's method up to compute the traces
>>> # (instead of, e.g., diagonals)
>>> integrand = hutchinson.integrand_trace(matvec)
>>>
>>> # Construct the estimator (a function of a random key)
>>> estimate = hutchinson.hutchinson(integrand, sampler)

>>> # Evaluate the estimator with a fixed random key
>>> key = jax.random.PRNGKey(1)
>>> trace = jax.jit(estimate)(key)
>>>
>>> print(trace)
508.9
>>>
>>> # For comparison, the exact trace:
>>> print(jnp.trace(A.T @ A))
506.0
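
To make clear what the estimator above is doing, here is a minimal sketch of Hutchinson's method written in plain JAX, without matfree. It is not matfree's implementation; the helper name `hutchinson_trace` and its signature are hypothetical and chosen for illustration only. The sketch averages the quadratic forms v^T (A^T A) v over Rademacher vectors v and also reports a standard error of that average.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    # Apply A^T A to x without forming the matrix A^T A.
    return A.T @ (A @ x)


def hutchinson_trace(matvec_fn, key, shape, num):
    # Hypothetical helper (not part of matfree): plain Monte Carlo
    # estimate of the trace using Rademacher (+1/-1) sample vectors.
    samples = jax.random.rademacher(key, (num, *shape), dtype=jnp.float32)
    # One quadratic form v^T M v per sample, vectorised with vmap.
    quadforms = jax.vmap(lambda v: v @ matvec_fn(v))(samples)
    mean = jnp.mean(quadforms)
    # Standard error of the sample mean.
    stderr = jnp.std(quadforms) / jnp.sqrt(num)
    return mean, stderr


key = jax.random.PRNGKey(1)
estimate, stderr = hutchinson_trace(matvec, key, shape=(2,), num=10_000)
exact = jnp.trace(A.T @ A)
print(estimate, stderr, exact)
```

The estimate should agree with the exact trace to within a few standard errors. The numbers will generally differ from the matfree output above because the samples are drawn differently.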

Tutorials

Find many more tutorials in Matfree's documentation.

These tutorials include, among other things:

  • Log-determinants: Use stochastic Lanczos quadrature to compute matrix functions.
  • Pytree-valued states: Combine neural-network Jacobians with stochastic Lanczos quadrature.
  • Control variates: Use control variates and multilevel schemes to reduce variances.
  • Higher moments and UQ: Compute means, variances, and other moments simultaneously.
  • Vector calculus: Use matrix-free linear algebra to implement vector calculus.

Let us know what you use matfree for!