import numpy as np
import torch
import torch.nn as nn
class AECross(nn.Module):
    """Paired autoencoders with cross-predictive regularization.

    Two encoders map ``x`` and ``y`` to latent codes of size ``latent_size``.
    Each code is decoded twice: by a plain decoder that reconstructs its own
    input (``xx_decoder`` / ``yy_decoder``) and by a dropout-regularized
    decoder that predicts the *other* input (``xy_decoder`` maps the x-code
    to ``y``; ``yx_decoder`` maps the y-code to ``x``).

    Args:
        x_dim: Feature size of ``x`` (input width of ``x_encoder``).
        y_dim: Feature size of ``y`` (input width of ``y_encoder``).
        latent_size: Width of both latent codes.
        alpha: Weight of the auto-reconstruction term in ``learning_loss``.
        lam: Weight of the cross-reconstruction term in ``learning_loss``.
    """

    def __init__(self, x_dim, y_dim, latent_size, alpha=1, lam=1):
        super().__init__()

        # Widest hidden layer of each branch.
        Lx = self._hidden_width(x_dim)
        Ly = self._hidden_width(y_dim)

        # NOTE: creation order is kept stable on purpose; it fixes the
        # parameter / state_dict ordering of the module.
        self.x_encoder = self._mlp(
            [x_dim, Lx, Lx // 2, Lx // 4, latent_size], final_activation=True)
        self.y_encoder = self._mlp(
            [y_dim, Ly, Ly // 2, Ly // 4, latent_size], final_activation=True)

        # Decoders end with a bare linear layer (no output non-linearity).
        self.xx_decoder = self._mlp([latent_size, Lx // 4, Lx // 2, Lx, x_dim])
        self.yy_decoder = self._mlp([latent_size, Ly // 4, Ly // 2, Ly, y_dim])

        # Cross decoders: same shape, with Dropout(0.5) before every linear.
        self.yx_decoder = self._mlp(
            [latent_size, Lx // 4, Lx // 2, Lx, x_dim], dropout=0.5)
        self.xy_decoder = self._mlp(
            [latent_size, Ly // 4, Ly // 2, Ly, y_dim], dropout=0.5)

        self.alpha = alpha
        self.lam = lam

        for net in (self.x_encoder, self.y_encoder,
                    self.xx_decoder, self.yy_decoder,
                    self.xy_decoder, self.yx_decoder):
            net.apply(self._init_weights)

    @staticmethod
    def _hidden_width(dim):
        """Return 1024, or the largest power of two <= ``dim`` if dim > 2048."""
        if dim > 2048:
            return int(2**np.floor(np.log2(dim)))
        return 1024

    @staticmethod
    def _mlp(widths, dropout=None, final_activation=False):
        """Build Linear/LeakyReLU(0.2) layers through consecutive ``widths``.

        When ``dropout`` is given, a ``Dropout(p=dropout)`` precedes every
        linear layer. The last linear layer is followed by an activation only
        if ``final_activation`` is true.
        """
        layers = []
        last = len(widths) - 2
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            if dropout is not None:
                layers.append(nn.Dropout(p=dropout))
            layers.append(nn.Linear(n_in, n_out))
            if i < last or final_activation:
                layers.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*layers)

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights and zero biases for every linear layer."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def encode(self, x_samples, y_samples):
        """Return the latent codes ``(Zx, Zy)`` of the two inputs."""
        Zx = self.x_encoder(x_samples)
        Zy = self.y_encoder(y_samples)
        return Zx, Zy

    def cross_decode(self, Zx, Zy):
        """Predict each input from the *other* code; returns ``(x_hat, y_hat)``.

        ``x_hat`` is decoded from ``Zy`` and ``y_hat`` from ``Zx``.
        """
        y_hat = self.xy_decoder(Zx)
        x_hat = self.yx_decoder(Zy)
        return x_hat, y_hat

    def decode(self, Zx, Zy):
        """Reconstruct each input from its own code; returns ``(x_hat, y_hat)``."""
        x_hat = self.xx_decoder(Zx)
        y_hat = self.yy_decoder(Zy)
        return x_hat, y_hat

    def rec_loss(self, hat, samples):
        """Mean squared error between a reconstruction and its target."""
        return torch.nn.functional.mse_loss(hat, samples, reduction='mean')

    def learning_loss(self, x_samples, y_samples):
        """Return ``lam * cross_loss + alpha * auto_loss`` for one batch.

        ``auto_loss`` sums the MSE of both self-reconstructions and
        ``cross_loss`` sums the MSE of both cross-predictions.
        """
        Zx, Zy = self.encode(x_samples, y_samples)
        Xh, Yh = self.decode(Zx, Zy)
        cXh, cYh = self.cross_decode(Zx, Zy)
        auto_loss = self.rec_loss(Xh, x_samples) + self.rec_loss(Yh, y_samples)
        cross_loss = self.rec_loss(cXh, x_samples) + self.rec_loss(cYh, y_samples)
        return self.lam * cross_loss + self.alpha * auto_loss
class AEMINE(nn.Module):
    """Paired autoencoders with MINE regularization.

    Two encoders map ``x`` and ``y`` to latent codes of size ``latent_size``
    and two decoders reconstruct each input from its own code. A small critic
    network ``T_func`` scores concatenated code pairs and is used by
    ``MINELoss`` to form a MINE lower-bound estimate between the two codes.

    Args:
        x_dim: Feature size of ``x`` (input width of ``x_encoder``).
        y_dim: Feature size of ``y`` (input width of ``y_encoder``).
        latent_size: Width of both latent codes.
        alpha: Weight of the auto-reconstruction term in ``learning_loss``.
        lam: Weight of the MINE term in ``learning_loss``.
    """

    def __init__(self, x_dim, y_dim, latent_size, alpha=1, lam=1):
        super().__init__()

        # Fixed width of the widest hidden layer in each branch.
        Lx, Ly = 1024, 1024

        # NOTE: creation order is kept stable on purpose; it fixes the
        # parameter / state_dict ordering of the module.
        self.x_encoder = self._mlp(
            [x_dim, Lx, Lx // 2, Lx // 4, latent_size], final_activation=True)
        self.y_encoder = self._mlp(
            [y_dim, Ly, Ly // 2, Ly // 4, latent_size], final_activation=True)

        # Decoders end with a bare linear layer (no output non-linearity).
        self.xx_decoder = self._mlp([latent_size, Lx // 4, Lx // 2, Lx, x_dim])
        self.yy_decoder = self._mlp([latent_size, Ly // 4, Ly // 2, Ly, y_dim])

        # Critic scoring a concatenated (Zx, Zy) pair with a single scalar.
        self.T_func = nn.Sequential(nn.Linear(latent_size * 2, latent_size),
                                    nn.ReLU(),
                                    nn.Linear(latent_size, 1))

        self.alpha = alpha
        self.lam = lam

        for net in (self.x_encoder, self.y_encoder,
                    self.xx_decoder, self.yy_decoder, self.T_func):
            net.apply(self._init_weights)

    @staticmethod
    def _mlp(widths, final_activation=False):
        """Build Linear/LeakyReLU(0.2) layers through consecutive ``widths``.

        The last linear layer is followed by an activation only if
        ``final_activation`` is true.
        """
        layers = []
        last = len(widths) - 2
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if i < last or final_activation:
                layers.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*layers)

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights and zero biases for every linear layer."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def encode(self, x_samples, y_samples):
        """Return the latent codes ``(Zx, Zy)`` of the two inputs."""
        Zx = self.x_encoder(x_samples)
        Zy = self.y_encoder(y_samples)
        return Zx, Zy

    def MINELoss(self, x_samples, y_samples):
        """Return the negated MINE lower bound for a batch of code pairs.

        The bound is ``mean(T(x, y)) - log(mean(exp(T(x, y'))))`` where ``y'``
        is ``y_samples`` re-drawn row-wise with replacement. Both inputs are
        concatenated along the last axis and must match ``T_func``'s input
        width.
        """
        sample_size = y_samples.shape[0]
        # Rows of y drawn with replacement serve as the "unpaired" samples.
        random_index = torch.randint(sample_size, (sample_size,))
        y_shuffle = y_samples[random_index]

        T0 = self.T_func(torch.cat([x_samples, y_samples], dim=-1))
        T1 = self.T_func(torch.cat([x_samples, y_shuffle], dim=-1))

        # log(mean(exp(T1))) evaluated as logsumexp - log(count): identical in
        # exact arithmetic, but does not overflow to inf for large scores.
        log_mean_exp = (torch.logsumexp(T1.flatten(), dim=0)
                        - float(np.log(T1.numel())))
        lower_bound = T0.mean() - log_mean_exp
        return -lower_bound

    def decode(self, Zx, Zy):
        """Reconstruct each input from its own code; returns ``(x_hat, y_hat)``."""
        x_hat = self.xx_decoder(Zx)
        y_hat = self.yy_decoder(Zy)
        return x_hat, y_hat

    def rec_loss(self, hat, samples):
        """Mean squared error between a reconstruction and its target."""
        return torch.nn.functional.mse_loss(hat, samples, reduction='mean')

    def learning_loss(self, x_samples, y_samples):
        """Return ``lam * MINELoss(Zx, Zy) + alpha * auto_loss`` for one batch.

        ``auto_loss`` sums the MSE of both self-reconstructions.
        """
        Zx, Zy = self.encode(x_samples, y_samples)
        Xh, Yh = self.decode(Zx, Zy)
        auto_loss = self.rec_loss(Xh, x_samples) + self.rec_loss(Yh, y_samples)
        cross_loss = self.MINELoss(Zx, Zy)
        return self.lam * cross_loss + self.alpha * auto_loss
class AEInfoNCE(nn.Module):
    """Paired autoencoders with InfoNCE regularization.

    Two encoders map ``x`` and ``y`` to latent codes of size ``latent_size``
    and two decoders reconstruct each input from its own code. A small critic
    network ``F_func`` (with a Softplus output) scores concatenated code pairs
    and is used by ``InfoNCELoss``.

    Args:
        x_dim: Feature size of ``x`` (input width of ``x_encoder``).
        y_dim: Feature size of ``y`` (input width of ``y_encoder``).
        latent_size: Width of both latent codes.
        alpha: Weight of the auto-reconstruction term in ``learning_loss``.
        lam: Weight of the InfoNCE term in ``learning_loss``.
    """

    def __init__(self, x_dim, y_dim, latent_size, alpha=1, lam=1):
        super().__init__()

        # Fixed width of the widest hidden layer in each branch.
        Lx, Ly = 1024, 1024

        # NOTE: creation order is kept stable on purpose; it fixes the
        # parameter / state_dict ordering of the module.
        self.x_encoder = self._mlp(
            [x_dim, Lx, Lx // 2, Lx // 4, latent_size], final_activation=True)
        self.y_encoder = self._mlp(
            [y_dim, Ly, Ly // 2, Ly // 4, latent_size], final_activation=True)

        # Decoders end with a bare linear layer (no output non-linearity).
        self.xx_decoder = self._mlp([latent_size, Lx // 4, Lx // 2, Lx, x_dim])
        self.yy_decoder = self._mlp([latent_size, Ly // 4, Ly // 2, Ly, y_dim])

        # Critic scoring a concatenated (Zx, Zy) pair with a positive scalar.
        self.F_func = nn.Sequential(
            nn.Linear(latent_size + latent_size, latent_size),
            nn.ReLU(),
            nn.Linear(latent_size, 1),
            nn.Softplus())

        self.alpha = alpha
        self.lam = lam

        for net in (self.x_encoder, self.y_encoder,
                    self.xx_decoder, self.yy_decoder, self.F_func):
            net.apply(self._init_weights)

    @staticmethod
    def _mlp(widths, final_activation=False):
        """Build Linear/LeakyReLU(0.2) layers through consecutive ``widths``.

        The last linear layer is followed by an activation only if
        ``final_activation`` is true.
        """
        layers = []
        last = len(widths) - 2
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if i < last or final_activation:
                layers.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*layers)

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights and zero biases for every linear layer."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def encode(self, x_samples, y_samples):
        """Return the latent codes ``(Zx, Zy)`` of the two inputs."""
        Zx = self.x_encoder(x_samples)
        Zy = self.y_encoder(y_samples)
        return Zx, Zy

    def InfoNCELoss(self, x_samples, y_samples):
        """Return the negated InfoNCE lower bound for a batch of code pairs.

        Expects 2-D inputs with the same number of rows. Row ``i`` of
        ``x_samples`` is paired with row ``i`` of ``y_samples``; every other
        combination in the batch is used as a contrast pair.
        """
        sample_size = y_samples.shape[0]

        # All-pairs grid built from views (no copy until the concatenation):
        # x_tile[i, j] = x_samples[j], y_tile[i, j] = y_samples[i].
        x_tile = x_samples.unsqueeze(0).expand(sample_size, -1, -1)
        y_tile = y_samples.unsqueeze(1).expand(-1, sample_size, -1)

        T0 = self.F_func(torch.cat([x_samples, y_samples], dim=-1))
        # T1 has shape [sample_size, sample_size, 1].
        T1 = self.F_func(torch.cat([x_tile, y_tile], dim=-1))

        log_mean_exp = T1.logsumexp(dim=1).mean() - float(np.log(sample_size))
        lower_bound = T0.mean() - log_mean_exp
        return -lower_bound

    def decode(self, Zx, Zy):
        """Reconstruct each input from its own code; returns ``(x_hat, y_hat)``."""
        x_hat = self.xx_decoder(Zx)
        y_hat = self.yy_decoder(Zy)
        return x_hat, y_hat

    def rec_loss(self, hat, samples):
        """Mean squared error between a reconstruction and its target."""
        return torch.nn.functional.mse_loss(hat, samples, reduction='mean')

    def learning_loss(self, x_samples, y_samples):
        """Return ``lam * InfoNCELoss(Zx, Zy) + alpha * auto_loss`` for one batch.

        ``auto_loss`` sums the MSE of both self-reconstructions.
        """
        Zx, Zy = self.encode(x_samples, y_samples)
        Xh, Yh = self.decode(Zx, Zy)
        auto_loss = self.rec_loss(Xh, x_samples) + self.rec_loss(Yh, y_samples)
        cross_loss = self.InfoNCELoss(Zx, Zy)
        return self.lam * cross_loss + self.alpha * auto_loss