Combine trace estimation with control variates
This example shows how to use a control variate to reduce the error of stochastic trace and diagonal estimation.
In [1]:
Copied!
import jax
import jax.numpy as jnp
import jax
import jax.numpy as jnp
In [2]:
Copied!
from matfree import stochtrace
from matfree import stochtrace
Create a matrix whose trace and diagonal we want to approximate.
In [3]:
Copied!
# A small dense test matrix with entries 1, 2, ..., 16 laid out row by row.
# Its exact diagonal [1, 6, 11, 16] and trace 34 are known, so the stochastic
# estimates below can be checked against them.
nrows, ncols = 4, 4
A = jnp.reshape(jnp.arange(1.0, 1.0 + nrows * ncols), (nrows, ncols))
Set up the sampler.
In [4]:
Copied!
# `x_like` is a template vector: the sampler draws random probe vectors
# with the same shape and dtype (one entry per column of A).
# `num=10_000` is the number of probe vectors each estimate averages over.
x_like = jnp.ones((ncols,), dtype=float)
sample_fun = stochtrace.sampler_normal(x_like, num=10_000)
First, compute the diagonal.
In [5]:
Copied!
# First pass: a stochastic estimate of diag(A) that only uses matrix-vector
# products with A. The result is reused below as the control variate.
problem = stochtrace.integrand_diagonal()
estimate = stochtrace.estimator(problem, sample_fun)
diagonal_ctrl = estimate(lambda v: A @ v, jax.random.PRNGKey(1))
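Then, compute the trace and diagonal jointly, using the first diagonal estimate as a control variate. Subtracting `diag(diagonal_ctrl)` from the matrix leaves a residual. We estimate the trace and diagonal of that residual, then add the control variate back to recover estimates for `A`.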
In [6]:
Copied!
def matvec_ctrl(v):
    """Evaluate a matrix-vector product with a control variate.

    Computes ``(A - diag(diagonal_ctrl)) @ v``. The elementwise product
    ``diagonal_ctrl * v`` equals ``diag(diagonal_ctrl) @ v``, so this is the
    product with ``A`` after subtracting the first-pass diagonal estimate.
    The stochastic estimator applied to this operator therefore only has to
    recover the residual on top of ``diagonal_ctrl``.

    Args:
        v: Probe vector with one entry per column of ``A``.

    Returns:
        The residual matrix-vector product, same shape as ``v``.
    """
    return A @ v - diagonal_ctrl * v
In [7]:
Copied!
# Estimate the trace and diagonal of the residual operator A - diag(diagonal_ctrl).
# Both results are corrections: adding back sum(diagonal_ctrl) and diagonal_ctrl,
# respectively, gives the control-variate estimates for A itself.
problem = stochtrace.integrand_trace_and_diagonal()
estimate = stochtrace.estimator(problem, sample_fun)
trace_and_diagonal = estimate(matvec_ctrl, jax.random.PRNGKey(2))
trace, diagonal = trace_and_diagonal["trace"], trace_and_diagonal["diagonal"]
We can, of course, compute both quantities without a control variate as well. We reuse the same random key so that the comparison is fair.
In [8]:
Copied!
# Reference run without a control variate: estimate trace and diagonal of A
# directly. PRNGKey(2) matches the control-variate run above, so both runs
# use the same random probes and differ only in the operator.
problem = stochtrace.integrand_trace_and_diagonal()
estimate = stochtrace.estimator(problem, sample_fun)
trace_and_diagonal = estimate(lambda v: A @ v, jax.random.PRNGKey(2))
trace_ref, diagonal_ref = trace_and_diagonal["trace"], trace_and_diagonal["diagonal"]
Compare the results. First, the diagonal.
In [9]:
Copied!
# Print the exact diagonal, then each estimate followed by its 2-norm error.
# The control-variate approximation is the residual estimate `diagonal`
# plus the control variate `diagonal_ctrl` that was subtracted in matvec_ctrl.
print("True value:", jnp.diag(A))
print("Control variate:", diagonal_ctrl, jnp.linalg.norm(jnp.diag(A) - diagonal_ctrl))
print("Approximation:", diagonal_ref, jnp.linalg.norm(jnp.diag(A) - diagonal_ref))
print(
    "Control-variate approximation:",
    diagonal + diagonal_ctrl,
    jnp.linalg.norm(jnp.diag(A) - diagonal - diagonal_ctrl),
)
True value: [ 1. 6. 11. 16.] Control variate: [ 1.0441655 5.7610655 10.704792 15.640959 ] 0.52449846 Approximation: [ 1.0695375 5.773019 11.284297 15.853068 ] 0.39845905 Control-variate approximation: [ 1.0738611 5.8862205 11.102717 15.987889 ] 0.17058367
Then, the trace.
In [10]:
Copied!
# Print the exact trace, then each estimate followed by its absolute error.
# The trace of the control variate diag(diagonal_ctrl) is sum(diagonal_ctrl).
# Adding it to the residual trace estimate gives the control-variate
# approximation of trace(A).
print("True value:", jnp.trace(A))
print(
    "Control variate:",
    jnp.sum(diagonal_ctrl),
    jnp.abs(jnp.trace(A) - jnp.sum(diagonal_ctrl)),
)
print("Approximation:", trace_ref, jnp.abs(jnp.trace(A) - trace_ref))
print(
    "Control variate approximation:",
    trace + jnp.sum(diagonal_ctrl),
    jnp.abs(jnp.trace(A) - trace - jnp.sum(diagonal_ctrl)),
)
True value: 34.0 Control variate: 33.15098 0.8490181 Approximation: 33.97992 0.020080566 Control variate approximation: 34.05069 0.050689697