import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from domainlab.utils.utils_classif import logit2preds_vpic
from domid.compos.cnn_VAE import ConvolutionalDecoder, ConvolutionalEncoder
from domid.compos.linear_VAE import LinearDecoder, LinearEncoder
from domid.models.a_model_cluster import AModelCluster
def mk_vade(parent_class=AModelCluster):
    """Return a VaDE model class that inherits from ``parent_class``."""

    class ModelVaDE(parent_class):
def __init__(
self,
zd_dim,
d_dim,
device,
i_c,
i_h,
i_w,
bs,
L=5,
random_batching=False,
model_method="cnn",
prior="Bern",
dim_inject_y=0,
pre_tr_weight_path=None,
feat_extract="vae",
):
"""
VaDE model (Jiang et al. 2017 "Variational Deep Embedding:
An Unsupervised and Generative Approach to Clustering") with
fully connected encoder and decoder.
:param zd_dim: dimension of the latent space
:param d_dim: number of clusters for the clustering task
:param device: device to use, e.g., "cuda" or "cpu"
:param i_c: number of channels of the input image
:param i_h: height of the input image
:param i_w: width of the input image
:param args: command line arguments
"""
            super().__init__()
self.zd_dim = zd_dim
self.d_dim = d_dim
self.device = device
self.L = L
self.loss_epoch = 0
self.prior = prior
self.dim_inject_y = dim_inject_y
self.model_method = model_method
self.model = "vade"
self.random_batching = random_batching
self.bs = bs
self.pre_tr_weight_path = pre_tr_weight_path
self.feat_extract = feat_extract
if self.model_method == "linear":
self.encoder = LinearEncoder(zd_dim=zd_dim, input_dim=(i_c, i_h, i_w)).to(device)
self.decoder = LinearDecoder(prior=prior, zd_dim=zd_dim, input_dim=(i_c, i_h, i_w)).to(device)
if self.dim_inject_y > 0:
warnings.warn("linear model decoder does not support label injection")
else:
self.encoder = ConvolutionalEncoder(zd_dim=zd_dim, num_channels=i_c, i_w=i_w, i_h=i_h).to(device)
self.decoder = ConvolutionalDecoder(
prior=prior,
                    zd_dim=zd_dim,
domain_dim=self.dim_inject_y,
h_dim=self.encoder.h_dim,
num_channels=i_c,
).to(device)
print(self.encoder)
print(self.decoder)
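            # Trainable parameters of the GMM prior p(z, c): mixture weights pi
            # (stored as log pi, initialized uniformly over the d_dim clusters)
            # and per-cluster means and log-variances (initialized to zero).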
self.log_pi = nn.Parameter(
torch.FloatTensor(
self.d_dim,
)
.fill_(1.0 / self.d_dim)
.log(),
requires_grad=True,
)
self.mu_c = nn.Parameter(torch.FloatTensor(self.d_dim, self.zd_dim).fill_(0), requires_grad=True)
self.log_sigma2_c = nn.Parameter(torch.FloatTensor(self.d_dim, self.zd_dim).fill_(0), requires_grad=True)
def _inference(self, x):
"""Auxiliary function for inference
:param tensor x: Input tensor of a shape [batchsize, 3, horzintal dim, vertical dim].
:return tensor preds_c: One hot encoded tensor of the predicted cluster assignment (shape: [batch_size, self.d_dim]).
:return tensor probs_c: Tensor of the predicted cluster probabilities; this is q(c|x) per eq. (16) or gamma_c in eq. (12) (shape: [batch_size, self.d_dim]).
:return tensor z: Tensor of the latent space representation (shape: [batch_size, self.zd_dim])
:return tensor z_mu: Tensor of the mean of the latent space representation (shape: [batch_size, self.zd_dim])
:return tensor z_sigma2_log: Tensor of the log of the variance of the latent space representation (shape: [batch_size, self.zd_dim])
:return tensor mu_c: Tensor of the estimated cluster means (shape: [self.d_dim, self.zd_dim])
:return tensor log_sigma2_c: Tensor of the estimated cluster variances (shape: [self.d_dim, self.zd_dim])
:return tensor pi: Tensor of the estimated cluster prevalences, p(c) (shape: [self.d_dim])
:return tensor logits: Tensor where each column contains the log-probability p(c)p(z|c) for cluster c=0,...,self.d_dim-1 (shape: [batch_size, self.d_dim]).
"""
z_mu, z_sigma2_log = self.encoder(x)
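            # sample z with the reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)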
z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
pi = F.softmax(self.log_pi, dim=0)
mu_c = self.mu_c
log_sigma2_c = self.log_sigma2_c
logits = torch.log(pi.unsqueeze(0)) + self.gaussian_pdfs_log(z, mu_c, log_sigma2_c)
            # logits[i, c] = log(p(c) p(z_i | c)) for cluster c = 0, ..., self.d_dim - 1; shape [batch_size, self.d_dim]
preds_c, probs_c, *_ = logit2preds_vpic(logits)
return preds_c, probs_c, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits
def infer_d_v_2(self, x, inject_domain):
"""
Used for tensorboard visualizations only.
"""
results = self._inference(x)
if len(inject_domain) > 0:
zy = torch.cat((results[2], inject_domain), 1)
else:
zy = results[2]
            x_pro, *_ = self.decoder(zy)
            x_pro = x_pro.cpu().detach()
            preds, probs, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits = (r.cpu().detach() for r in results)
            return preds, z_mu, z, log_sigma2_c, probs, x_pro
def cal_loss(self, x, inject_domain, warmup_beta):
"""Function that is called in trainer_vade to calculate loss
:param x: tensor with input data
:return: ELBO loss
"""
return self._cal_ELBO_loss(x, inject_domain, warmup_beta)
        def _cal_reconstruction_loss(self, x, inject_tensor=[]):
            """Reconstruction loss of a deterministic pass through the
            autoencoder, using the posterior mean z = z_mu without sampling."""
            z_mu = self.encoder.get_z(x)
            z = z_mu
            if len(inject_tensor) > 0:
                zy = torch.cat((z, inject_tensor), 1)
            else:
                zy = z
            x_pro, *_ = self.decoder(zy)
            return self._cal_reconstruction_loss_helper(x, x_pro, log_sigma=None)
        def _cal_reconstruction_loss_helper(self, x, x_pro, log_sigma):
            # FIXME: log_sigma predicted by the decoder is currently unused;
            # a fixed sigma is assumed for the Gaussian likelihood instead.
            if self.prior == "Bern":
                L_rec = F.binary_cross_entropy(x_pro, x)
            else:
                sigma = torch.tensor([0.9], device=self.device)  # mean sigma of all images
                L_rec = torch.mean(torch.sum(0.5 * (x - x_pro) ** 2, dim=(1, 2, 3))) / sigma**2
            return L_rec
def _cal_ELBO_loss(self, x, inject_domain, warmup_beta):
"""ELBO loss function
Using SGVB estimator and the reparametrization trick calculates ELBO loss.
Calculates loss between encoded input and input using ELBO equation (12) in the paper.
:param tensor x: Input tensor of a shape [batchsize, 3, horzintal dim, vertical dim].
:param int L: Number of Monte Carlo samples in the SGVB
"""
preds, probs, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits = self._inference(x)
# mu, sigma from the decoder
eps = 1e-10
L_rec = 0.0
for l in range(self.L):
                z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu  # shape [batch_size, self.zd_dim]
if len(inject_domain) > 0:
zy = torch.cat((z, inject_domain), 1)
else:
zy = z
                x_pro, log_sigma = self.decoder(zy)
                L_rec += self._cal_reconstruction_loss_helper(x, x_pro, log_sigma)  # FIXME: log_sigma is unused in the helper
L_rec /= self.L
Loss = L_rec * x.size(1)
# --> this is the -"first line" of eq (12) in the paper with additional averaging over the batch.
Loss += (
0.5
* warmup_beta
* torch.mean(
torch.sum(
probs
* torch.sum(
log_sigma2_c.unsqueeze(0)
+ torch.exp(z_sigma2_log.unsqueeze(1) - log_sigma2_c.unsqueeze(0))
+ (z_mu.unsqueeze(1) - mu_c.unsqueeze(0)).pow(2) / torch.exp(log_sigma2_c.unsqueeze(0)),
2,
),
1,
)
)
)
            # inner sum dimensions:
# [1, d_dim, zd_dim] + exp([batch_size, 1, zd_dim] - [1, d_dim, zd_dim]) + ([batch_size, 1, zd_dim] - [1, d_dim, zd_dim])^2 / exp([1, d_dim, zd_dim])
# = [batch_size, d_dim, zd_dim] -> sum of zd_dim dimensions
# the next sum is over d_dim dimensions
# the mean is over the batch
# --> overall, this is -"second line of eq. (12)" with additional mean over the batch
Loss -= warmup_beta * torch.mean(
torch.sum(probs * torch.log(pi.unsqueeze(0) / (probs + eps)), 1)
) # FIXME: (+eps) is a hack to avoid NaN. Is there a better way?
# dimensions: [batch_size, d_dim] * log([1, d_dim] / [batch_size, d_dim]), where the sum is over d_dim dimensions --> [batch_size] --> mean over the batch --> a scalar
Loss -= 0.5 * warmup_beta * torch.mean(torch.sum(1.0 + z_sigma2_log, 1))
# dimensions: mean( sum( [batch_size, zd_dim], 1 ) ) where the sum is over zd_dim dimensions and mean over the batch
# --> overall, this is -"third line of eq. (12)" with additional mean over the batch
return Loss
def gaussian_pdfs_log(self, x, mus, log_sigma2s):
"""helper function"""
loglik = []
for c in range(self.d_dim):
loglik.append(self.gaussian_pdf_log(x, mus[c, :], log_sigma2s[c, :]).view(-1, 1))
return torch.cat(loglik, 1)
@staticmethod
def gaussian_pdf_log(x, mu, log_sigma2):
"""
subhelper function just one gausian pdf log calculation, used as a basis for gaussian_pdfs_log function
:param x: tensor of shape [batch_size, self.zd_dim]
:param mu: mean for the cluster distribution
:param log_sigma2: variance parameters of the cluster distribtion
:return: tensor with the Gaussian log probabilities of the shape of [batch_size, 1]
"""
return -0.5 * (
torch.sum(
np.log(np.pi * 2) + log_sigma2 + (x - mu).pow(2) / torch.exp(log_sigma2),
1,
)
)
return ModelVaDE
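
# Example usage (a minimal sketch; the hyperparameter values and the
# 28x28 RGB input shape below are illustrative only):
#
#     ModelVaDE = mk_vade()
#     model = ModelVaDE(zd_dim=50, d_dim=10, device=torch.device("cpu"),
#                       i_c=3, i_h=28, i_w=28, bs=2)
#     x = torch.rand(2, 3, 28, 28)
#     loss = model.cal_loss(x, inject_domain=[], warmup_beta=0.1)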
def test_fun(d_dim, zd_dim, device):
    """Construct a small VaDE model and run a forward pass and the loss calculation as a smoke test."""
    device = torch.device("cpu")  # the smoke test always runs on CPU
    model_cls = mk_vade()
    model = model_cls(
        zd_dim=zd_dim,
        d_dim=d_dim,
        device=device,
        i_c=3,
        i_h=28,
        i_w=28,
        bs=2,
    )
    x = torch.rand(2, 3, 28, 28)
    # one-hot labels for a batch of two samples
    a = np.zeros((2, 10))
    a[0, 1] = 1.0
    a[1, 8] = 1.0
    y = torch.tensor(a, dtype=torch.float)
    model(x, y)
    model.cal_loss(x, inject_domain=[], warmup_beta=0.1)
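
if __name__ == "__main__":
    # Minimal CPU smoke test; d_dim and zd_dim are illustrative values only.
    test_fun(d_dim=10, zd_dim=50, device="cpu")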