{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quick start guide\n", "\n", "Here, we will test-drive `latentmi` by estimating the mutual information (MI) between two multivariate Gaussians.\n", "\n", "First, we'll import the necessary packages." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from latentmi import lmi\n", "\n", "import torch\n", "\n", "# Seed both random number generators for reproducibility\n", "torch.manual_seed(2121)\n", "np.random.seed(2121)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generating synthetic multivariate Gaussian data\n", "\n", "To generate two high-dimensional multivariate Gaussians with known MI, we'll sample from one two-dimensional multivariate Gaussian and project each component into 100 dimensions. Then, using the correlation between the two \"intrinsic\" components, we can analytically determine the MI between the intrinsic dimensions, which is equal to the MI between the high-dimensional projections. We'll use 100 dimensions per variable and a generous $10^4$ samples." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10000, 100)\n", "(10000, 100)\n" ] } ], "source": [ "# Sample a correlated 2D Gaussian: column 0 drives X, column 1 drives Y\n", "intrinsic = np.random.multivariate_normal([0, 0], cov=[[6, 3], [3, 3.5]], size=10**4)\n", "X_intrinsic = intrinsic[:, [0]]\n", "Y_intrinsic = intrinsic[:, [1]]\n", "\n", "# Random linear projections from 1 intrinsic dimension to 100 dimensions\n", "X_proj = np.random.normal(size=(1, 100))\n", "Y_proj = np.random.normal(size=(1, 100))\n", "\n", "Xs = X_intrinsic @ X_proj\n", "Ys = Y_intrinsic @ Y_proj\n", "\n", "print(Xs.shape)\n", "print(Ys.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimating MI with the LMI approximation\n", "\n", "Next, we'll estimate the MI between the two high-dimensional variables from the $10^4$ samples. 
The latent MI approximation involves first learning a low-dimensional representation using neural networks, then applying the [Kraskov, Stoegbauer, Grassberger](https://arxiv.org/abs/cond-mat/0305641) estimator to that learned representation.\n", "\n", "The `lmi.estimate` function wraps the whole process into one function call and returns three things:\n", "1. Pointwise mutual information estimates (which can be averaged to obtain an MI estimate)\n", "2. Coordinates of each sample in the low-dimensional representation space\n", "3. PyTorch object for the representation learning model\n", "\n", "By default, the learned representation has 8 dimensions, though this can be increased or decreased as desired. The function also runs quietly by default, so training progress is not displayed; pass `quiet=False` to show it. Many other parameters of the representation learning network and training can be adjusted in the function call, though in practice we find extensive parameter tuning to be unnecessary.\n", "\n", "\n", "If we only care about the MI estimate, we can ignore the second and third outputs and simply average the array returned as the first output. By default, `lmi.estimate` only estimates MI using validation samples and not training samples, so the pointwise mutual information array will contain `NaN` for each sample in the training set. We therefore have to take a mean that excludes `NaN` to get the MI estimate. NumPy has a helpful `nanmean` function which does this.\n", "\n", "**Note about progress bar**\n", "\n", "Because training can be stopped early (i.e. before max epochs), we do not use a conventional progress bar. Instead, we display a new sunflower for every 10% of max epochs. When training is stopped early, we display some celebration emojis and print the final validation loss.\n", "\n", "If you find that training is not being stopped early, it could be worth increasing the `epochs` argument." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 187 (of max 300) 🌻🌻🌻🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 187\n", "final validation loss: 1.2238998651504516\n" ] } ], "source": [ "pmis, embedding, model = lmi.estimate(Xs, Ys, quiet=False,\n", " # N_dims=8, validation_split=0.5,...\n", " )" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ nan 0.80080795 nan ... nan 0.11582631\n", " -0.20801842]\n" ] } ], "source": [ "print(pmis)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the pointwise mutual information array contains some `NaN` values, one for each training sample. If we take the mean excluding `NaN`, we get our estimate." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LMI estimate: 0.433\n" ] } ], "source": [ "mi_estimate = np.nanmean(pmis)\n", "print(\"LMI estimate: %.3f\" % mi_estimate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then compare this to the analytically determined ground truth (computed as described [here](https://stats.stackexchange.com/questions/438607/mutual-information-between-subsets-of-variables-in-the-multivariate-normal-distr))." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4036774610288021" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Gaussian MI in bits: -0.5 * log2(1 - rho^2), with rho = cov / sqrt(var_x * var_y)\n", "-0.5*np.log2((1-(3/(np.sqrt(6*3.5)))**2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Not too bad!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Building confidence in an estimate using shuffle tests\n", "\n", "It is useful to compare estimates to \"negative control\" estimates, where data is shuffled to break any real dependence. 
To be clear, by shuffling data, we mean permuting the samples of $X$ so that they are no longer paired with their original samples of $Y$, which makes the two independent. This can tell us about the significance of the estimated value, in the sense of how likely it is to be the result of random chance if the variables were independent.\n", "\n", "If the actual LMI estimate is clearly greater than the shuffled estimates, we can have confidence that there is true dependence in the data. There are a number of ways one could quantify this (e.g. a Z-score that approximates the shuffle test results as normally distributed)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 128 (of max 300) 🌻🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 128\n", "final validation loss: 1.9928849935531616\n", "epoch 133 (of max 300) 🌻🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 133\n", "final validation loss: 1.9621137261390686\n", "epoch 90 (of max 300) 🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 90\n", "final validation loss: 1.993351674079895\n", "epoch 75 (of max 300) 🌻🌻 🎉🎉\n", "success! training stopped at epoch 75\n", "final validation loss: 2.009299099445343\n", "epoch 114 (of max 300) 🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 114\n", "final validation loss: 1.983246123790741\n", "epoch 138 (of max 300) 🌻🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 138\n", "final validation loss: 1.9786343097686767\n", "epoch 105 (of max 300) 🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 105\n", "final validation loss: 1.9989253401756286\n", "epoch 73 (of max 300) 🌻🌻 🎉🎉\n", "success! training stopped at epoch 73\n", "final validation loss: 2.0429690957069395\n", "epoch 103 (of max 300) 🌻🌻🌻 🎉🎉\n", "success! training stopped at epoch 103\n", "final validation loss: 2.0082654118537904\n", "epoch 87 (of max 300) 🌻🌻 🎉🎉\n", "success! 
training stopped at epoch 87\n", "final validation loss: 1.9687854886054992\n", "\n", "Shuffle test results:\n", "[0.003509357472832474, -0.009113338161662298, 0.004744177031334124, -0.002546823743237671, -0.02288156076957172, 0.0004103898407831636, 0.00835833717483701, -0.0034016581464100283, 0.008932102414160346, 0.01168538375038173]\n" ] } ], "source": [ "N_shuffles = 10\n", "\n", "shuffle_test_results = []\n", "\n", "for _ in range(N_shuffles):\n", "\n", "    # Shuffles rows (first axis only) in place, so Xs stays shuffled after this cell\n", "    np.random.shuffle(Xs)\n", "\n", "    # Use a separate name so the unshuffled pmis array is not overwritten\n", "    shuffled_pmis, _, _ = lmi.estimate(Xs, Ys, quiet=False)\n", "\n", "    shuffle_test_results.append(np.nanmean(shuffled_pmis))\n", "\n", "print()\n", "print(\"Shuffle test results:\")\n", "print(shuffle_test_results)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Z score of estimate, from normal approximation of shuffled estimates:\n", "44.56818373761494\n" ] } ], "source": [ "# Mean and standard deviation of the shuffled (negative control) estimates\n", "shuffle_mu = np.mean(shuffle_test_results)\n", "shuffle_sigma = np.std(shuffle_test_results)\n", "\n", "print(\"Z score of estimate, from normal approximation of shuffled estimates:\")\n", "print((mi_estimate-shuffle_mu)/shuffle_sigma)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }