Compute log-determinants with stochastic Lanczos quadrature
Log-determinant estimation can be implemented with stochastic Lanczos quadrature, which can be loosely interpreted as an extension of Hutchinson's trace estimator.
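As a reminder of the underlying idea: Hutchinson's estimator approximates $\mathrm{tr}(M)$ by averaging quadratic forms $v^\top M v$ over random vectors $v$ with $\mathbb{E}[v v^\top] = I$. A minimal sketch in plain JAX (the matrix $M$ here is purely illustrative):

```python
import jax
import jax.numpy as jnp

# Illustrative symmetric matrix; any square matrix works for trace estimation.
M = jnp.eye(4) + 0.1 * jnp.ones((4, 4))

# Standard-normal probes satisfy E[v v^T] = I, so E[v^T M v] = tr(M).
key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, shape=(10_000, 4))
estimate = jnp.mean(jax.vmap(lambda v: v @ M @ v)(samples))
print(estimate, jnp.trace(M))
```

Stochastic Lanczos quadrature extends this idea from $\mathrm{tr}(M)$ to $\mathrm{tr}(f(M))$, with $f = \log$ yielding the log-determinant.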
In [1]:
import jax
import jax.numpy as jnp
In [2]:
from matfree import hutchinson, lanczos
Set up a matrix.
In [3]:
nhidden, nrows = 6, 5
A = jnp.reshape(jnp.arange(1.0, 1.0 + nhidden * nrows), (nhidden, nrows))
A /= nhidden * nrows
In [4]:
def matvec(x):
    """Compute a matrix-vector product."""
    return A.T @ (A @ x) + x
In [5]:
x_like = jnp.ones((nrows,), dtype=float)  # used to determine shapes of vectors
Estimate log-determinants with stochastic Lanczos quadrature.
In [6]:
order = 3
problem = lanczos.integrand_spd_logdet(order, matvec)
sampler = hutchinson.sampler_normal(x_like, num=1_000)
estimator = hutchinson.hutchinson(problem, sample_fun=sampler)
logdet = estimator(jax.random.PRNGKey(1))
print(logdet)
2.3622568
For comparison:
In [7]:
print(jnp.linalg.slogdet(A.T @ A + jnp.eye(nrows)))
(Array(1., dtype=float32), Array(2.4568148, dtype=float32))
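The matvec above encodes $M = A^\top A + I$, whose exact log-determinant can also be written through the singular values of $A$ as $\log\det(A^\top A + I) = \sum_i \log(1 + \sigma_i^2)$. A minimal cross-check in plain `jnp` (redefining $A$ so the snippet is self-contained):

```python
import jax.numpy as jnp

# Same matrix as in the example above.
nhidden, nrows = 6, 5
A = jnp.reshape(jnp.arange(1.0, 1.0 + nhidden * nrows), (nhidden, nrows))
A /= nhidden * nrows

# log det(A^T A + I) = sum_i log(1 + sigma_i^2) over singular values of A.
sigmas = jnp.linalg.svd(A, compute_uv=False)
print(jnp.sum(jnp.log1p(sigmas**2)))
```

This agrees with the `slogdet` reference value and avoids forming $A^\top A$ explicitly.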
We can also compute the log-determinant of a matrix of the form $M = B^\top B$ purely through arithmetic with $B$; there is no need to assemble $M$:
In [8]:
A = jnp.reshape(jnp.arange(1.0, 1.0 + nrows**2), (nrows, nrows))
A += jnp.eye(nrows)
A /= nrows**2
In [9]:
def matvec_right(x):
    """Compute a matrix-vector product."""
    return A @ x
In [10]:
def vecmat_left(x):
    """Compute a vector-matrix product."""
    return x @ A
In [11]:
order = 3
problem = lanczos.integrand_product_logdet(order, matvec_right, vecmat_left)
sampler = hutchinson.sampler_normal(x_like, num=1_000)
estimator = hutchinson.hutchinson(problem, sample_fun=sampler)
logdet = estimator(jax.random.PRNGKey(1))
print(logdet)
-22.779821
For comparison:
In [12]:
print(jnp.linalg.slogdet(A.T @ A))
(Array(1., dtype=float32), Array(-21.758816, dtype=float32))
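Since this $A$ is square, the reference value also follows from the identity $\det(B^\top B) = \det(B)^2$, i.e. the log-determinant of the product is twice the log-absolute-determinant of $B$ itself. A minimal sanity check (redefining $A$ so the snippet is self-contained):

```python
import jax.numpy as jnp

# Same square matrix as in the example above.
nrows = 5
A = jnp.reshape(jnp.arange(1.0, 1.0 + nrows**2), (nrows, nrows))
A += jnp.eye(nrows)
A /= nrows**2

# det(A^T A) = det(A)^2, so log det(A^T A) = 2 * log|det(A)|.
sign, logabsdet = jnp.linalg.slogdet(A)
print(2.0 * logabsdet)
```

This matches the `slogdet(A.T @ A)` value above; the identity only holds for square $B$, which is why the stochastic estimator is useful when $B$ is rectangular or only available as a matvec.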