Compute log-determinants with stochastic Lanczos quadrature
Log-determinant estimation can be implemented with stochastic Lanczos quadrature, which can be loosely interpreted as an extension of Hutchinson's trace estimator.
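Concretely, for a symmetric positive definite matrix $M$, the identity $\log\det(M) = \mathrm{tr}(\log M)$ reduces log-determinant estimation to trace estimation: sample probe vectors $v_i$ with $\mathbb{E}[v_i v_i^\top] = I$, average the quadratic forms $v_i^\top (\log M) v_i$ as in Hutchinson's estimator, and approximate each quadratic form with a Gauss quadrature rule derived from a few Lanczos iterations with $M$.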
In [1]:
import jax
import jax.numpy as jnp
In [2]:
from matfree import decomp, funm, stochtrace
Set up a matrix.
In [3]:
nhidden, nrows = 6, 5
A = jnp.reshape(jnp.arange(1.0, 1.0 + nhidden * nrows), (nhidden, nrows))
A /= nhidden * nrows
In [4]:
def matvec(x):
    """Multiply a vector with M = A.T @ A + I without assembling M."""
    return A.T @ (A @ x) + x
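This matvec represents the symmetric positive definite matrix $M = A^\top A + I$ without ever assembling it. A quick dense sanity check (illustrative only; not part of the estimator pipeline):

M = A.T @ A + jnp.eye(nrows)  # the matrix that matvec encodes implicitly
print(jnp.allclose(matvec(jnp.ones(nrows)), M @ jnp.ones(nrows)))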
In [5]:
x_like = jnp.ones((nrows,), dtype=float)  # used to infer the shape and dtype of the probe vectors
Estimate log-determinants with stochastic Lanczos quadrature.
In [6]:
num_matvecs = 3
tridiag_sym = decomp.tridiag_sym(num_matvecs)  # Lanczos tridiagonalization
problem = funm.integrand_funm_sym_logdet(tridiag_sym)  # quadrature for v^T log(M) v
sampler = stochtrace.sampler_normal(x_like, num=1_000)  # Gaussian probe vectors
estimator = stochtrace.estimator(problem, sampler=sampler)
logdet = estimator(matvec, jax.random.PRNGKey(1))
print(logdet)
2.3622565
For comparison:
In [7]:
print(jnp.linalg.slogdet(A.T @ A + jnp.eye(nrows)))
SlogdetResult(sign=Array(1., dtype=float32), logabsdet=Array(2.4568148, dtype=float32))
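The estimate deviates from the reference because it uses only num_matvecs = 3 Lanczos iterations and 1,000 probe vectors. A sketch of a more accurate setup (same API as above; the parameter choices are illustrative):

num_matvecs = 5  # with nrows = 5, Lanczos quadrature is exact per probe (in exact arithmetic)
tridiag_sym = decomp.tridiag_sym(num_matvecs)
problem = funm.integrand_funm_sym_logdet(tridiag_sym)
sampler = stochtrace.sampler_normal(x_like, num=10_000)  # more samples: lower variance
estimator = stochtrace.estimator(problem, sampler=sampler)
print(estimator(matvec, jax.random.PRNGKey(1)))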
We can also compute the log-determinant of a matrix of the form $M = B^\top B$ purely based on arithmetic with $B$; there is no need to assemble $M$.
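This works because $\log\det(B^\top B) = 2 \sum_i \log \sigma_i(B)$, where the $\sigma_i(B)$ are the singular values of $B$; a bidiagonalization of $B$, which only requires products with $B$ (and, via automatic differentiation, with $B^\top$), approximates precisely these singular values. Set up a new matrix: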
In [8]:
A = jnp.reshape(jnp.arange(1.0, 1.0 + nrows**2), (nrows, nrows))
A += jnp.eye(nrows)
A /= nrows**2
In [9]:
def matvec_half(x):
    """Multiply a vector with B = A; the target matrix is M = B.T @ B."""
    return A @ x
In [10]:
num_matvecs = 3
bidiag = decomp.bidiag(num_matvecs)  # bidiagonalization of B
problem = funm.integrand_funm_product_logdet(bidiag)  # log-determinant of B.T @ B
sampler = stochtrace.sampler_normal(x_like, num=1_000)
estimator = stochtrace.estimator(problem, sampler=sampler)
logdet = estimator(matvec_half, jax.random.PRNGKey(1))
print(logdet)
-22.779821
Internally, Matfree uses JAX's vector-Jacobian products to turn the matrix-vector product $x \mapsto Bx$ into the vector-matrix product $y \mapsto B^\top y$ that the bidiagonalization also requires, so only the matvec with $B$ must be provided.
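A minimal sketch of that mechanism (simplified; the actual internals may differ):

# jax.vjp linearizes matvec_half; for a linear map, the pullback is exact.
_, vjp = jax.vjp(matvec_half, x_like)

def vecmat(y):
    """Compute A.T @ y using only the matvec, via reverse-mode AD."""
    return vjp(y)[0]

print(jnp.allclose(vecmat(x_like), A.T @ x_like))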
For comparison:
In [11]:
print(jnp.linalg.slogdet(A.T @ A))
SlogdetResult(sign=Array(1., dtype=float32), logabsdet=Array(-21.758816, dtype=float32))
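As before, the gap between the stochastic estimate and the reference shrinks as num_matvecs and the number of probe vectors grow.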