Implement uncertainty quantification for trace estimation
Higher moments of stochastic trace estimates can easily be turned into uncertainty quantification: the first two moments give a variance, and therefore an error bar, for the trace estimate.
```python
import jax
import jax.numpy as jnp

from matfree import stochtrace
```
```python
# A small 6x6 test matrix with entries 0/36, 1/36, ..., 35/36.
A = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
```
```python
def matvec(x):
    """Evaluate a matrix-vector product with M = A^T A + I."""
    # M is never formed; only products with A and A^T are needed.
    return A.T @ (A @ x) + x
```
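Because `matvec` is linear, a small sketch can recover the dense matrix it represents by applying it to every column of the identity. For a tiny example like this 6x6 one, that gives an exact reference trace against which the estimates below can be checked. `dense_from_matvec` is a hypothetical helper written for illustration, not part of matfree.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36


def matvec(x):
    """Evaluate x -> (A^T A + I) x."""
    return A.T @ (A @ x) + x


def dense_from_matvec(mv, n):
    """Hypothetical helper: materialise a linear operator from its matvec."""
    # vmap maps over the rows of the identity, so row i of the result
    # is mv(e_i), which is column i of the operator.
    rows = jax.vmap(mv)(jnp.eye(n))
    # Transpose so that mv(e_i) becomes column i.
    return rows.T


M = dense_from_matvec(matvec, 6)
print(jnp.trace(M))
```

This only makes sense for small dimensions; the point of `matvec` is that large operators never need to be formed.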
```python
# A template vector that fixes the shape and dtype of the random probe vectors.
x_like = jnp.ones((6,))
# Number of random probe vectors in the Monte Carlo estimate.
num_samples = 10_000
```
Higher moments
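Stochastic trace estimation (Hutchinson's method) estimates tr(M) as the expected value of the random variable vᵀMv, where the random vector v has zero mean and identity covariance. The second and higher moments of this random variable are often of interest as well, for example to quantify how much a single sample fluctuates around the trace.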
```python
# Draw num_samples standard-normal probe vectors shaped like x_like.
normal = stochtrace.sampler_normal(x_like, num=num_samples)
# Hutchinson integrand v -> v^T M v, whose expected value is tr(M).
integrand = stochtrace.integrand_trace()
# Wrap the integrand so that the estimator returns its first and second moments.
integrand = stochtrace.integrand_wrap_moments(integrand, [1, 2])
estimator = stochtrace.estimator(integrand, sampler=normal)
first, second = estimator(matvec, jax.random.PRNGKey(1))
```
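To make these two numbers concrete, here is a plain-JAX sketch of what the wrapped estimator computes, assuming it averages the integrand vᵀMv and its square over i.i.d. standard-normal probes. This is an illustration under that assumption, not matfree's implementation; `trace_moments` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36


def matvec(x):
    """Evaluate x -> (A^T A + I) x without forming the matrix."""
    return A.T @ (A @ x) + x


def trace_moments(mv, key, dim, num):
    """Monte Carlo estimates of E[v^T M v] and E[(v^T M v)^2], v ~ N(0, I)."""
    # One standard-normal probe vector per row.
    vs = jax.random.normal(key, (num, dim))
    # Hutchinson integrand v^T M v for every probe, using only matvecs.
    samples = jax.vmap(lambda v: jnp.dot(v, mv(v)))(vs)
    # Sample means of the integrand and of its square.
    return jnp.mean(samples), jnp.mean(samples**2)


first, second = trace_moments(matvec, jax.random.PRNGKey(1), dim=6, num=10_000)
print(first, second - first**2)
```

The random draws differ from matfree's sampler, so the printed values will be close to, but not identical with, the notebook's.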
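For standard-normal samples v and a symmetric matrix M (here M = AᵀA + I), the variance of vᵀMv equals twice the squared Frobenius norm of M, so the estimated second central moment should be close to 2‖M‖_F².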
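A short derivation, assuming v ~ N(0, I) and M symmetric, uses Isserlis' theorem for the fourth moments of a Gaussian vector:

```latex
\begin{aligned}
\mathbb{E}[v_i v_j v_k v_l] &= \delta_{ij}\delta_{kl} + \delta_{ik}\delta_{jl} + \delta_{il}\delta_{jk}, \\
\mathbb{E}\big[(v^\top M v)^2\big] &= \sum_{i,j,k,l} M_{ij} M_{kl}\, \mathbb{E}[v_i v_j v_k v_l]
  = \operatorname{tr}(M)^2 + \operatorname{tr}(M M^\top) + \operatorname{tr}(M^2), \\
\operatorname{Var}\big[v^\top M v\big] &= \mathbb{E}\big[(v^\top M v)^2\big] - \operatorname{tr}(M)^2
  = 2\,\|M\|_F^2 \quad \text{(since } \operatorname{tr}(M M^\top) = \operatorname{tr}(M^2) = \|M\|_F^2 \text{ for symmetric } M\text{)}.
\end{aligned}
```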
```python
# Estimated variance of v^T M v (second central moment) ...
print(second - first**2)
# ... versus the exact value 2 * ||M||_F^2 for M = A^T A + I.
print(2 * jnp.linalg.norm(A.T @ A + jnp.eye(6), ord="fro") ** 2)
```
322.09515
321.78638
Uncertainty quantification
Variance estimation leads directly to uncertainty quantification: the trace estimate is the average of num_samples independent samples of vᵀMv, so its variance equals the variance of a single sample divided by the number of samples.
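Written out for N independent samples X_1, ..., X_N of vᵀMv, each with variance σ², the independence assumption removes all cross terms:

```latex
\operatorname{Var}\Big[\frac{1}{N}\sum_{n=1}^{N} X_n\Big]
  = \frac{1}{N^2}\sum_{n=1}^{N}\operatorname{Var}[X_n]
  = \frac{\sigma^2}{N},
\qquad
\text{standard error} = \frac{\sigma}{\sqrt{N}}.
```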
```python
# Variance of the trace estimate: per-sample variance divided by the sample count.
variance = (second - first**2) / num_samples
print(variance)
```
0.032209516
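The estimator variance can be turned into an error bar. The sketch below assumes the sample size is large enough for a normal approximation of the estimator (central limit theorem) and builds an approximate 95% interval first ± 1.96·√variance, printed next to the exact trace of AᵀA + I, which is cheap to form for this 6x6 example. The Monte Carlo estimate is recomputed in plain JAX so that the block runs on its own; `confidence_interval` is a hypothetical helper, not a matfree function.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
M = A.T @ A + jnp.eye(6)  # dense form of matvec, fine for a 6x6 example
num_samples = 10_000

# Hutchinson samples v^T M v with standard-normal probes v.
vs = jax.random.normal(jax.random.PRNGKey(1), (num_samples, 6))
samples = jnp.einsum("ni,ij,nj->n", vs, M, vs)
first, second = jnp.mean(samples), jnp.mean(samples**2)


def confidence_interval(mean, second_moment, num, z=1.96):
    """Hypothetical helper: approximate CI from the first two sample moments."""
    # Variance of the trace estimator = variance of one sample / num.
    variance = (second_moment - mean**2) / num
    half_width = z * jnp.sqrt(variance)
    return mean - half_width, mean + half_width


lower, upper = confidence_interval(first, second, num_samples)
print(lower, jnp.trace(M), upper)
```

A 95% interval misses the true trace in roughly one run out of twenty, so a single run landing outside it is not by itself a sign of a bug.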