Combine trace estimation with control variates¶
Here is how to implement control variates.
In [1]:
Copied!
import jax
import jax.numpy as jnp
import jax
import jax.numpy as jnp
In [2]:
Copied!
from matfree import stochtrace
from matfree import stochtrace
Create a matrix whose trace and diagonal we want to approximate.
In [3]:
Copied!
# Dense (nrows x ncols) test matrix filled row by row with 1.0, 2.0, ..., 16.0,
# so its exact diagonal is [1, 6, 11, 16] and its exact trace is 34.
nrows, ncols = 4, 4
A = jnp.arange(1.0, 1.0 + nrows * ncols).reshape((nrows, ncols))
Set up the sampler.
In [4]:
Copied!
# Template vector of length ncols, handed to the sampler as an example of the
# vectors that A is multiplied with.
x_like = jnp.ones((ncols,), dtype=float)
# Sampler requesting 10,000 samples per estimate.
# NOTE(review): presumably draws random +/-1 (sign) vectors shaped like
# x_like -- confirm against the matfree documentation.
sample_fun = stochtrace.sampler_signs(x_like, num=10_000)
First, compute the diagonal.
In [5]:
Copied!
# Stochastic estimate of the diagonal of A; it is reused further down as the
# control variate. Evaluated once only: the estimate uses 10,000 samples, so
# repeating the cell body would just redo the same computation.
integrand = stochtrace.monte_carlo_diagonal()
estimate = stochtrace.estimator_monte_carlo(integrand, sample_fun)
diagonal_ctrl = estimate(lambda v: A @ v, jax.random.PRNGKey(1))
Then, compute trace and diagonal jointly using the estimate of the diagonal as a control variate.
In [6]:
Copied!
def matvec_ctrl(v):
    """Evaluate a matrix-vector product with a control variate.

    Returns ``A @ v - diagonal_ctrl * v``: the product of ``A`` with ``v``,
    minus the elementwise product of the diagonal estimate ``diagonal_ctrl``
    with ``v``. Both ``A`` and ``diagonal_ctrl`` are read from the notebook's
    global scope, so they must be defined before this function is called.
    """
    return A @ v - diagonal_ctrl * v
In [7]:
Copied!
# Jointly estimate trace and diagonal of the control-variate-corrected operator
# (matvec_ctrl). The estimator returns a dict with "trace" and "diagonal"
# entries; the control variate is added back when the results are compared.
integrand = stochtrace.monte_carlo_trace_and_diagonal()
estimate = stochtrace.estimator_monte_carlo(integrand, sample_fun)
trace_and_diagonal = estimate(matvec_ctrl, jax.random.PRNGKey(2))
trace, diagonal = trace_and_diagonal["trace"], trace_and_diagonal["diagonal"]
We can, of course, compute both quantities without a control variate as well (using the same random key as above, so that the two results are directly comparable).
In [8]:
Copied!
# Reference run: the same joint trace/diagonal estimator and the same random
# key as the control-variate run, but applied to the plain matrix-vector
# product with A.
integrand = stochtrace.monte_carlo_trace_and_diagonal()
estimate = stochtrace.estimator_monte_carlo(integrand, sample_fun)
trace_and_diagonal = estimate(lambda v: A @ v, jax.random.PRNGKey(2))
trace_ref, diagonal_ref = trace_and_diagonal["trace"], trace_and_diagonal["diagonal"]
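Compare the results. First, the diagonal. Note that the plain approximation and the control-variate approximation below agree up to rounding error: both runs use the same random key, and for sign samples the elementwise product v * v equals one, so the subtracted control variate cancels exactly when it is added back. To observe a variance reduction, one would presumably need samples without this property (for example, Gaussian samples).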
In [9]:
Copied!
# Each line prints: label, estimate of the diagonal, Euclidean norm of the error.
print("True value:", jnp.diag(A))
print("Control variate:", diagonal_ctrl, jnp.linalg.norm(jnp.diag(A) - diagonal_ctrl))
print("Approximation:", diagonal_ref, jnp.linalg.norm(jnp.diag(A) - diagonal_ref))
print(
    "Control-variate approximation:",
    # Add the control variate back to undo the subtraction in matvec_ctrl.
    diagonal + diagonal_ctrl,
    jnp.linalg.norm(jnp.diag(A) - diagonal - diagonal_ctrl),
)
True value: [ 1. 6. 11. 16.] Control variate: [ 0.9912 6.0267997 10.912 16.0336 ] 0.09832937 Approximation: [ 0.92399997 5.7023997 10.6608 15.613199 ] 0.5991773 Control-variate approximation: [ 0.9240003 5.7023997 10.6608 15.6132 ] 0.59917724
Then, the trace.
In [10]:
Copied!
# Each line prints: label, estimate of the trace, absolute error.
print("True value:", jnp.trace(A))
print(
    "Control variate:",
    # The sum of the estimated diagonal is itself an estimate of the trace.
    jnp.sum(diagonal_ctrl),
    jnp.abs(jnp.trace(A) - jnp.sum(diagonal_ctrl)),
)
print("Approximation:", trace_ref, jnp.abs(jnp.trace(A) - trace_ref))
print(
    "Control variate approximation:",
    # Add the trace of the control variate back to undo the subtraction
    # in matvec_ctrl.
    trace + jnp.sum(diagonal_ctrl),
    jnp.abs(jnp.trace(A) - trace - jnp.sum(diagonal_ctrl)),
)
True value: 34.0 Control variate: 33.9636 0.03639984 Approximation: 32.9004 1.0996017
Control variate approximation: 32.900402 1.0995979