{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example usage\n", "\n", "Here, we will test-drive `latentmi` by estimating the mutual information (MI) between two multivariate Gaussians.\n", "\n", "First, we'll import the necessary packages." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from latentmi import lmi\n", "\n", "import torch\n", "\n", "# Fix the random seeds so the results are reproducible\n", "torch.manual_seed(2121)\n", "np.random.seed(2121)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generating synthetic multivariate Gaussian data\n", "\n", "To generate two high-dimensional multivariate Gaussians with known MI, we'll sample from one two-dimensional multivariate Gaussian and project each component into 100 dimensions. Then, using the correlation between the two \"intrinsic\" components, we can analytically determine the MI between the intrinsic dimensions. Because each projection is an injective linear map, this equals the MI between the high-dimensional projections. We'll use 100 dimensions per variable and a generous $10^4$ samples." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10000, 100)\n", "(10000, 100)\n" ] } ], "source": [ "# Sample the two correlated intrinsic variables\n", "intrinsic = np.random.multivariate_normal([0, 0], cov=[[6, 3], [3, 3.5]], size=10**4)\n", "X_intrinsic = intrinsic[:, [0]]\n", "Y_intrinsic = intrinsic[:, [1]]\n", "\n", "# Draw a random linear projection from 1 to 100 dimensions for each variable\n", "X_proj = np.random.normal(size=(1, 100))\n", "Y_proj = np.random.normal(size=(1, 100))\n", "\n", "# Project each intrinsic variable into 100 dimensions\n", "Xs = X_intrinsic @ X_proj\n", "Ys = Y_intrinsic @ Y_proj\n", "\n", "print(Xs.shape)\n", "print(Ys.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimating MI with the LMI approximation\n", "\n", "Next, we'll estimate the MI between the two high-dimensional variables from the $10^4$ samples. 
The latent MI approximation involves first learning a low-dimensional representation using neural networks, then applying the [Kraskov, Stoegbauer, Grassberger](https://arxiv.org/abs/cond-mat/0305641) estimator to that learned representation.\n", "\n", "The `lmi.estimate` function wraps the whole process into one function call and returns three things:\n", "1. Pointwise mutual information estimates (which can be averaged to obtain an MI estimate)\n", "2. Coordinates of each sample in the low-dimensional representation space\n", "3. PyTorch object for the representation learning model\n", "\n", "By default, the learned representation has 8 dimensions, though this can be increased or decreased as desired. The function also runs quietly by default, so training progress is not displayed unless we pass `quiet=False`. Many other parameters of the representation learning network and training can be adjusted in the function call, though in practice we find extensive parameter tuning to be unnecessary.\n", "\n", "If we only care about the MI estimate, we can ignore the second and third outputs and simply average the array returned as the first output. By default, `lmi.estimate` only estimates MI using validation samples and not training samples, so the pointwise mutual information array will have `NaN` for each of the samples in the training set. We therefore have to take a mean that excludes `NaN` to get the MI estimate. NumPy has a helpful `nanmean` function which does this." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 187 (of max 300) 🌻🌻🌻🌻🌻🌻 🎉🎉\n", "success! 
training stopped at epoch 187\n", "final validation loss: 1.2238998651504516\n" ] } ], "source": [ "pmis, embedding, model = lmi.estimate(Xs, Ys, quiet=False,\n", " # N_dims=8, validation_split=0.5,...\n", " )" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ nan 0.80080795 nan ... nan 0.11582631\n", " -0.20801842]\n" ] } ], "source": [ "print(pmis)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the pointwise mutual information array contains some `NaN` entries. If we take the mean excluding `NaN`, we get our estimate." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LMI estimate: 0.433\n" ] } ], "source": [ "print(\"LMI estimate: %.3f\" % np.nanmean(pmis))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then compare this to the analytically determined ground truth (computed like [this](https://stats.stackexchange.com/questions/438607/mutual-information-between-subsets-of-variables-in-the-multivariate-normal-distr))." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4036774610288021" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# MI of a bivariate Gaussian: -0.5 * log2(1 - rho^2), with rho = cov / sqrt(var_x * var_y)\n", "-0.5*np.log2((1-(3/(np.sqrt(6*3.5)))**2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Not too bad! The estimate (0.433) is within about 0.03 of the ground truth (0.404)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }