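latentmi.lmi
============

.. py:module:: latentmi.lmi


Classes
-------

.. autoapisummary::

   latentmi.lmi.EarlyStopper


Functions
---------

.. autoapisummary::

   latentmi.lmi.train
   latentmi.lmi.learn_representation
   latentmi.lmi.estimate


Module Contents
---------------

.. py:class:: EarlyStopper(patience=1)

   Early stopping that returns the best weights, modeled on the Keras ``EarlyStopping`` callback.


   .. py:method:: early_stop(validation_loss, model)


.. py:function:: train(model, X_train, Y_train, X_test, Y_test, batch_size=512, lr=0.0001, epochs=300, patience=30, quiet=True)

   Training loop for LMI models.

   :param model: LMI model
   :param X_train: training samples of X, shape (N_samples, N_features)
   :param Y_train: training samples of Y, shape (N_samples, N_features)
   :param X_test: test samples of X, shape (N_samples, N_features)
   :param Y_test: test samples of Y, shape (N_samples, N_features)
   :param batch_size: samples per batch, defaults to 512
   :param lr: learning rate for the Adam optimizer, defaults to 1e-4
   :param epochs: maximum number of epochs, defaults to 300
   :param patience: epochs without a decrease in validation loss before early stopping, defaults to 30
   :param quiet: suppress the training progress display, defaults to True


.. py:function:: learn_representation(Xs, Ys, train_indices, test_indices, regularizer='models.AECross', alpha=1, lam=1, N_dims=8, batch_size=512, lr=0.0001, epochs=300, validation_split=0.3, patience=30, quiet=True, device='cpu')

   Train a paired autoencoder (AE) model and embed the data.

   :param Xs: samples of X, shape (N_samples, N_features)
   :param Ys: samples of Y, shape (N_samples, N_features)
   :param train_indices: indices of the samples used for training
   :param test_indices: indices of the samples held out for testing
   :param regularizer:
   :param alpha:
   :param lam:
   :param N_dims: dimensionality of the learned embedding, defaults to 8
   :param batch_size: samples per batch, defaults to 512
   :param lr: learning rate for the Adam optimizer, defaults to 1e-4
   :param epochs: maximum number of epochs, defaults to 300
   :param validation_split: fraction of samples held out for validation, defaults to 0.3
   :param patience: epochs without a decrease in validation loss before early stopping, defaults to 30
   :param quiet: suppress the training progress display, defaults to True
   :param device: device to run the model on, defaults to ``'cpu'``


.. py:function:: estimate(Xs, Ys, regularizer='models.AECross', alpha=1, lam=1, N_dims=8, validation_split=0.5, estimate_on_val=True, batch_size=512, lr=0.0001, epochs=300, patience=30, quiet=True, device=None)

   Estimate pointwise mutual information (pMI) for each sample. Returns the pMIs, with NaN for points not included in the KSG estimate.

   :param Xs: samples of X, shape (N_samples, N_features)
   :param Ys: samples of Y, shape (N_samples, N_features)
   :param regularizer:
   :param alpha:
   :param lam:
   :param N_dims: dimensionality of the learned embedding, defaults to 8
   :param validation_split: fraction of samples held out for validation, defaults to 0.5
   :param estimate_on_val:
   :param batch_size: samples per batch, defaults to 512
   :param lr: learning rate for the Adam optimizer, defaults to 1e-4
   :param epochs: maximum number of epochs, defaults to 300
   :param patience: epochs without a decrease in validation loss before early stopping, defaults to 30
   :param quiet: suppress the training progress display, defaults to True
   :param device: device to run the model on, defaults to None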
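
``EarlyStopper.early_stop(validation_loss, model)`` is documented only as returning the best weights in the style of the Keras callback. The sketch below illustrates that behaviour under stated assumptions: the model is assumed to expose PyTorch-style ``state_dict()`` and ``load_state_dict()`` methods, and ``SketchEarlyStopper`` and ``ToyModel`` are hypothetical names, not part of ``latentmi``. It snapshots the weights whenever the validation loss improves and rolls back to that snapshot once ``patience`` epochs pass without improvement.

.. code-block:: python

   import copy


   class SketchEarlyStopper:
       """Hypothetical early stopper that restores the best weights seen so far.

       Assumes ``model`` has PyTorch-style ``state_dict()`` / ``load_state_dict()``
       methods. This is an illustration, not latentmi's implementation.
       """

       def __init__(self, patience=1):
           self.patience = patience      # epochs tolerated without improvement
           self.counter = 0              # epochs since the last improvement
           self.best_loss = float("inf")
           self.best_state = None

       def early_stop(self, validation_loss, model):
           """Return True when training should stop, after restoring the best weights."""
           if validation_loss < self.best_loss:
               # New best: record the loss and snapshot a copy of the weights.
               self.best_loss = validation_loss
               self.best_state = copy.deepcopy(model.state_dict())
               self.counter = 0
               return False
           self.counter += 1
           if self.counter >= self.patience:
               # Patience exhausted: roll the model back to the best snapshot.
               if self.best_state is not None:
                   model.load_state_dict(self.best_state)
               return True
           return False


   class ToyModel:
       """Stand-in exposing only the two methods the sketch relies on."""

       def __init__(self):
           self.weights = {"w": 0.0}

       def state_dict(self):
           return self.weights

       def load_state_dict(self, state):
           self.weights = dict(state)


   model = ToyModel()
   stopper = SketchEarlyStopper(patience=2)
   losses = [1.0, 0.5, 0.7, 0.8, 0.2]
   stopped_at = None
   for epoch, loss in enumerate(losses):
       model.weights = {"w": float(epoch)}  # pretend training changed the weights
       if stopper.early_stop(loss, model):
           stopped_at = epoch
           break
   # Stops at epoch 3 with the weights from epoch 1 (the lowest loss so far).

Keeping a deep copy rather than a reference matters with real models, because an optimizer updates the live parameter tensors in place.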
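
``estimate`` reports per-sample pMIs from a KSG estimate. For readers unfamiliar with KSG (the Kraskov, Stögbauer and Grassberger k-nearest-neighbour estimator), the sketch below is a generic brute-force version of the first KSG estimator that returns one pointwise term per sample. Whether ``latentmi`` uses this variant, this neighbour count ``k`` or the max-norm is an assumption, and ``ksg_pointwise_mi`` and ``_digamma_int`` are hypothetical names. Averaging the pointwise terms gives the MI estimate in nats.

.. code-block:: python

   import numpy as np


   def _digamma_int(n):
       """Digamma for positive integers: psi(n) = -euler_gamma + H_(n-1)."""
       n = np.asarray(n, dtype=int)
       euler_gamma = 0.5772156649015329
       # harmonic[m] holds the m-th harmonic number H_m, with H_0 = 0.
       harmonic = np.concatenate(([0.0], np.cumsum(1.0 / np.arange(1, n.max()))))
       return -euler_gamma + harmonic[n - 1]


   def ksg_pointwise_mi(x, y, k=3):
       """Hypothetical brute-force KSG (algorithm 1) returning one pMI per sample."""
       x = np.asarray(x, dtype=float).reshape(len(x), -1)
       y = np.asarray(y, dtype=float).reshape(len(y), -1)
       n = x.shape[0]
       # Pairwise max-norm (Chebyshev) distances in each marginal space.
       dx = np.abs(x[:, None, :] - x[None, :, :]).max(axis=-1)
       dy = np.abs(y[:, None, :] - y[None, :, :]).max(axis=-1)
       # The joint-space max-norm is the larger of the two marginal distances.
       dz = np.maximum(dx, dy)
       # Exclude each point from its own neighbour search.
       for d in (dx, dy, dz):
           np.fill_diagonal(d, np.inf)
       # eps_i: distance from point i to its k-th nearest joint-space neighbour.
       eps = np.sort(dz, axis=1)[:, k - 1]
       # Count marginal neighbours strictly closer than eps_i.
       nx = (dx < eps[:, None]).sum(axis=1)
       ny = (dy < eps[:, None]).sum(axis=1)
       return (_digamma_int(k) + _digamma_int(n)
               - _digamma_int(nx + 1) - _digamma_int(ny + 1))


   # Usage: correlated Gaussians, whose true MI is -0.5 * log(1 - rho**2) nats.
   rng = np.random.default_rng(0)
   rho = 0.8
   xy = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=800)
   pmi = ksg_pointwise_mi(xy[:, :1], xy[:, 1:], k=3)
   mi_estimate = pmi.mean()

The brute-force distance matrices cost O(N^2) memory, so this only suits small illustrations; a KD-tree neighbour search is the usual choice at scale.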
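
Because ``estimate`` marks points left out of the KSG estimate with NaN, summaries over its output should skip those entries. Since MI is the expectation of pMI, the mean over the non-NaN entries gives a point estimate of MI. The sketch below illustrates that convention with toy values; it assumes the pMIs form a 1-D NumPy array aligned with the rows of ``Xs``, and ``scatter_pmis`` and ``estimated_idx`` are hypothetical names.

.. code-block:: python

   import numpy as np


   def scatter_pmis(n_samples, estimated_idx, pmi_values):
       """Hypothetical helper: place per-point pMIs into a full-length array.

       Points that were not part of the estimate keep NaN, mirroring the
       convention described for ``estimate``.
       """
       full = np.full(n_samples, np.nan)
       full[estimated_idx] = pmi_values
       return full


   # Toy values: only samples 1, 3 and 4 (e.g. a validation split) were estimated.
   pmis = scatter_pmis(6, [1, 3, 4], [0.2, 0.5, 0.8])
   mi = np.nanmean(pmis)                           # mean over estimated points only
   n_estimated = np.count_nonzero(~np.isnan(pmis))

Using ``np.mean`` instead of ``np.nanmean`` here would return NaN, since any NaN entry propagates through the ordinary mean.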