Quick start guide
Here, we will test drive latentmi by estimating the mutual information between two multivariate Gaussians.
First, we’ll import the necessary packages.
import numpy as np
from latentmi import lmi
import torch
torch.manual_seed(2121)
np.random.seed(2121)
Generating synthetic multivariate Gaussian data
To generate two high-dimensional multivariate Gaussians with known MI, we’ll sample from a two-dimensional multivariate Gaussian and linearly project each component into 100 dimensions. Each projection is a one-to-one map of its “intrinsic” component, so the MI between the high-dimensional projections equals the MI between the intrinsic components, which we can determine analytically from their correlation. We’ll use a generous $10^4$ samples.
intrinsic = np.random.multivariate_normal([0, 0], cov=[[6, 3], [3, 3.5]], size=10**4)
X_intrinsic = intrinsic[:, [0]]
Y_intrinsic = intrinsic[:, [1]]
X_proj = np.random.normal(size=(1, 100))
Y_proj = np.random.normal(size=(1, 100))
Xs = X_intrinsic @ X_proj
Ys = Y_intrinsic @ Y_proj
print(Xs.shape)
print(Ys.shape)
(10000, 100)
(10000, 100)
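As a quick sanity check on this construction, the sketch below regenerates the data with a local random generator (an assumption made so that the global seeds set earlier are untouched; the seed value is arbitrary). It confirms two things: the projection can be undone with a pseudoinverse, so no information is lost, and the empirical correlation of the intrinsic components is close to the value implied by the covariance matrix.

```python
import numpy as np

# Assumption: a local Generator with an arbitrary seed, independent of the
# global seeding used in the rest of this guide.
rng = np.random.default_rng(0)

intrinsic = rng.multivariate_normal([0, 0], cov=[[6, 3], [3, 3.5]], size=10**4)
X_intrinsic = intrinsic[:, [0]]
Y_intrinsic = intrinsic[:, [1]]

# Project the 1-D intrinsic variable into 100 dimensions, as in the guide.
X_proj = rng.normal(size=(1, 100))
Xs = X_intrinsic @ X_proj

# The pseudoinverse of the projection maps the 100-D samples back to 1-D,
# which shows that the projection is one-to-one.
X_recovered = Xs @ np.linalg.pinv(X_proj)
recovered_ok = np.allclose(X_recovered, X_intrinsic)

# Compare the empirical correlation with the one implied by the covariance.
rho_empirical = np.corrcoef(X_intrinsic[:, 0], Y_intrinsic[:, 0])[0, 1]
rho_true = 3 / np.sqrt(6 * 3.5)

print("projection invertible:", recovered_ok)
print("empirical rho: %.3f, true rho: %.3f" % (rho_empirical, rho_true))
```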
Estimating MI with the LMI approximation
Next, we’ll estimate the MI between the two high-dimensional variables from the $10^4$ samples. The latent MI approximation involves first learning a low-dimensional representation using neural networks, then applying the Kraskov-Stögbauer-Grassberger (KSG) estimator to that learned representation.
The lmi.estimate function wraps the whole process into one function call, and returns three things:
1. Pointwise mutual information estimates (which can be averaged to obtain an MI estimate)
2. Coordinates of each sample in the low-dimensional representation space
3. PyTorch object for the representation learning model
By default, the learned representation has 8 dimensions, though this can be increased or decreased as desired. Also, the function defaults to quiet so training progress is not displayed. Many other parameters of the representation learning network and training can be adjusted in the function call – though in practice, we find extensive parameter tuning to be unnecessary.
If we only care about the MI estimate, we can ignore the second and third outputs and simply average the array returned as the first. By default, lmi.estimate estimates MI using only validation samples, not training samples, so the pointwise mutual information array contains NaN for every sample in the training set. We therefore have to take the mean excluding NaN values to get the MI estimate. NumPy has a helpful nanmean function which does this.
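To make the NaN handling concrete, here is a minimal sketch on a toy array. The values are made up for illustration and are not output from latentmi; NaN entries stand in for training-set samples.

```python
import numpy as np

# Hypothetical pointwise MI values: NaN marks samples from the training split.
toy_pmis = np.array([np.nan, 0.8, np.nan, 0.1, -0.2])

# A plain mean propagates NaN, so it cannot be used directly.
plain_mean = np.mean(toy_pmis)

# nanmean averages only the non-NaN (validation) entries.
validation_mean = np.nanmean(toy_pmis)

# Number of samples that actually contribute to the estimate.
n_validation = np.count_nonzero(~np.isnan(toy_pmis))

print("plain mean:", plain_mean)
print("nanmean over %d validation samples: %.4f" % (n_validation, validation_mean))
```

Counting the non-NaN entries is a cheap way to see how many samples the estimate is actually based on.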
Note about progress bar
Because training can be stopped early (i.e. before max epochs), we do not use a conventional progress bar. Instead, we display a new sunflower for every 10% of max epochs. When training is stopped early, we display some celebration emojis and print the final validation loss.
If you are finding that training is not being stopped early, it could be worth increasing the epochs argument.
pmis, embedding, model = lmi.estimate(Xs, Ys, quiet=False,
    # N_dims=8, validation_split=0.5,...
)
epoch 187 (of max 300) 🌻🌻🌻🌻🌻🌻 🎉🎉
success! training stopped at epoch 187
final validation loss: 1.2238998651504516
print(pmis)
[ nan 0.80080795 nan ... nan 0.11582631
-0.20801842]
As you can see, there are some NaN in the pointwise mutual information array. If we take the mean excluding NaN, we get our estimate.
mi_estimate = np.nanmean(pmis)
print("LMI estimate: %.3f" % mi_estimate)
LMI estimate: 0.433
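We can then compare this to the analytically determined ground truth. For a bivariate Gaussian with correlation $\rho$, the MI in bits is $-\frac{1}{2}\log_2(1-\rho^2)$, and here $\rho = 3/\sqrt{6 \times 3.5}$.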
-0.5*np.log2((1-(3/(np.sqrt(6*3.5)))**2))
0.4036774610288021
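The same calculation can be wrapped in a small helper that reads the correlation directly from a 2x2 covariance matrix. This is a sketch for illustration: gaussian_mi_bits is a hypothetical name, not part of latentmi, and it assumes a bivariate Gaussian.

```python
import numpy as np

def gaussian_mi_bits(cov):
    """MI in bits between the two components of a bivariate Gaussian.

    Hypothetical helper for illustration; not part of latentmi.
    """
    cov = np.asarray(cov, dtype=float)
    # Correlation coefficient from the covariance and the two variances.
    rho = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
    return -0.5 * np.log2(1.0 - rho**2)

ground_truth = gaussian_mi_bits([[6, 3], [3, 3.5]])

# The same quantity in nats, for comparison with estimators that use
# natural logarithms.
ground_truth_nats = ground_truth * np.log(2)

print("ground truth: %.4f bits (%.4f nats)" % (ground_truth, ground_truth_nats))
```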
Not too bad!
Building confidence in an estimate using shuffle tests
It is useful to compare estimates to “negative control” estimates, where the data is shuffled to break any real dependence. To be clear, by shuffling data, we mean permuting the samples of $X$ so that they are no longer paired with their original samples of $Y$, which makes the two independent. This tells us about the significance of the estimated value, in the sense of how likely it is to arise by random chance if the variables are independent.
If the actual LMI estimate is clearly greater than the shuffled estimates, we can have confidence that there is true dependence in the data. There are a number of ways one could quantify this (e.g. a Z-score, approximating the shuffled estimates as normally distributed).
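To illustrate what shuffling does, the sketch below permutes the samples of one intrinsic variable and checks that its correlation with the other drops to roughly zero. Correlation is used here only as a cheap proxy for dependence. The local random generator and its seed are assumptions made for this sketch.

```python
import numpy as np

# Assumption: a local Generator with an arbitrary seed.
rng = np.random.default_rng(0)

intrinsic = rng.multivariate_normal([0, 0], cov=[[6, 3], [3, 3.5]], size=10**4)
x = intrinsic[:, 0]
y = intrinsic[:, 1]

# permutation returns a shuffled copy, so the original pairing in x is kept.
x_shuffled = rng.permutation(x)

corr_paired = np.corrcoef(x, y)[0, 1]
corr_shuffled = np.corrcoef(x_shuffled, y)[0, 1]

print("correlation with original pairing: %.3f" % corr_paired)
print("correlation after shuffling:       %.3f" % corr_shuffled)
```

Unlike np.random.shuffle, which modifies its argument in place, Generator.permutation leaves the original array untouched. That is convenient if the unshuffled data is needed again later.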
N_shuffles = 10
shuffle_test_results = []
for _ in range(N_shuffles):
    np.random.shuffle(Xs) # by default shuffles only first axis
    pmis, _, _ = lmi.estimate(Xs, Ys, quiet=False)
    shuffle_test_results.append(np.nanmean(pmis))
print()
print("Shuffle test results:")
print(shuffle_test_results)
epoch 128 (of max 300) 🌻🌻🌻🌻 🎉🎉
success! training stopped at epoch 128
final validation loss: 1.9928849935531616
epoch 133 (of max 300) 🌻🌻🌻🌻 🎉🎉
success! training stopped at epoch 133
final validation loss: 1.9621137261390686
epoch 90 (of max 300) 🌻🌻🌻 🎉🎉
success! training stopped at epoch 90
final validation loss: 1.993351674079895
epoch 75 (of max 300) 🌻🌻 🎉🎉
success! training stopped at epoch 75
final validation loss: 2.009299099445343
epoch 114 (of max 300) 🌻🌻🌻 🎉🎉
success! training stopped at epoch 114
final validation loss: 1.983246123790741
epoch 138 (of max 300) 🌻🌻🌻🌻 🎉🎉
success! training stopped at epoch 138
final validation loss: 1.9786343097686767
epoch 105 (of max 300) 🌻🌻🌻 🎉🎉
success! training stopped at epoch 105
final validation loss: 1.9989253401756286
epoch 73 (of max 300) 🌻🌻 🎉🎉
success! training stopped at epoch 73
final validation loss: 2.0429690957069395
epoch 103 (of max 300) 🌻🌻🌻 🎉🎉
success! training stopped at epoch 103
final validation loss: 2.0082654118537904
epoch 87 (of max 300) 🌻🌻 🎉🎉
success! training stopped at epoch 87
final validation loss: 1.9687854886054992
Shuffle test results:
[0.003509357472832474, -0.009113338161662298, 0.004744177031334124, -0.002546823743237671, -0.02288156076957172, 0.0004103898407831636, 0.00835833717483701, -0.0034016581464100283, 0.008932102414160346, 0.01168538375038173]
shuffle_mu = np.mean(shuffle_test_results)
shuffle_sigma = np.std(shuffle_test_results)
print("Z score of estimate, from normal approximation of shuffled estimates:")
print((mi_estimate-shuffle_mu)/shuffle_sigma)
Z score of estimate, from normal approximation of shuffled estimates:
44.56818373761494
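An alternative that avoids the normal approximation is an empirical, permutation-style p-value. This is a sketch of a standard technique, not something latentmi provides. The shuffled estimates are copied from the output above, and the observed estimate uses the rounded printed value 0.433.

```python
import numpy as np

# Values copied from the printed output above (observed estimate is rounded).
mi_estimate = 0.433
shuffle_test_results = np.array([
    0.003509357472832474, -0.009113338161662298, 0.004744177031334124,
    -0.002546823743237671, -0.02288156076957172, 0.0004103898407831636,
    0.00835833717483701, -0.0034016581464100283, 0.008932102414160346,
    0.01168538375038173,
])

# Empirical p-value with the usual add-one correction: the fraction of
# shuffled estimates at least as large as the observed one, counting the
# observed estimate itself as one draw.
n_as_large = np.sum(shuffle_test_results >= mi_estimate)
p_value = (1 + n_as_large) / (1 + len(shuffle_test_results))

# Z-score using the sample standard deviation (ddof=1), which is slightly
# more conservative than np.std's default (ddof=0) for only 10 shuffles.
z_sample = (mi_estimate - shuffle_test_results.mean()) / shuffle_test_results.std(ddof=1)

print("empirical p-value:", p_value)
print("Z-score with ddof=1:", z_sample)
```

With only 10 shuffles, the smallest p-value this approach can report is 1/11, so more shuffles are needed if a finer resolution is required.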