import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from domainlab.utils.utils_classif import logit2preds_vpic
from domid.compos.cnn_VAE import ConvolutionalDecoder, ConvolutionalEncoder
from domid.compos.linear_VAE import LinearDecoder, LinearEncoder
from domid.models.a_model_cluster import AModelCluster
# from tensorboardX import SummaryWriter
def mk_vade(parent_class=AModelCluster):
    """
    Build the VaDE model class on top of ``parent_class``.

    :param parent_class: base class the returned ``ModelVaDE`` class inherits from
    :return: the ``ModelVaDE`` class itself (not an instance)
    """

    class ModelVaDE(parent_class):
        def __init__(
            self,
            zd_dim,
            d_dim,
            device,
            i_c,
            i_h,
            i_w,
            bs,
            L=5,
            random_batching=False,
            model_method="cnn",
            prior="Bern",
            dim_inject_y=0,
            pre_tr_weight_path=None,
            feat_extract="vae",
        ):
            """
            VaDE model (Jiang et al. 2017 "Variational Deep Embedding:
            An Unsupervised and Generative Approach to Clustering") with
            either a linear or a convolutional encoder/decoder pair.

            :param zd_dim: dimension of the latent space
            :param d_dim: number of clusters for the clustering task
            :param device: device to use, e.g., "cuda" or "cpu"
            :param i_c: number of channels of the input image
            :param i_h: height of the input image
            :param i_w: width of the input image
            :param bs: stored as ``self.bs``; not used inside this class (presumably the batch size)
            :param L: number of Monte Carlo samples drawn per call when estimating the reconstruction term of the ELBO
            :param random_batching: stored as ``self.random_batching``; not used inside this class
            :param model_method: ``"linear"`` selects the linear encoder/decoder; any other value selects the convolutional ones
            :param prior: ``"Bern"`` selects a binary cross-entropy reconstruction loss; any other value selects a scaled squared-error loss. Also forwarded to the decoder.
            :param dim_inject_y: number of extra dimensions concatenated to the latent code before decoding (convolutional decoder only)
            :param pre_tr_weight_path: stored as ``self.pre_tr_weight_path``; not used inside this class
            :param feat_extract: stored as ``self.feat_extract``; not used inside this class
            """
            super().__init__()
            self.zd_dim = zd_dim
            self.d_dim = d_dim
            self.device = device
            self.L = L
            self.loss_epoch = 0
            self.prior = prior
            self.dim_inject_y = dim_inject_y
            self.model_method = model_method
            self.model = "vade"
            self.random_batching = random_batching
            self.bs = bs
            self.pre_tr_weight_path = pre_tr_weight_path
            self.feat_extract = feat_extract

            if self.model_method == "linear":
                self.encoder = LinearEncoder(zd_dim=zd_dim, input_dim=(i_c, i_h, i_w)).to(device)
                self.decoder = LinearDecoder(prior=prior, zd_dim=zd_dim, input_dim=(i_c, i_h, i_w)).to(device)
                if self.dim_inject_y > 0:
                    warnings.warn("linear model decoder does not support label injection")
            else:
                self.encoder = ConvolutionalEncoder(zd_dim=zd_dim, num_channels=i_c, i_w=i_w, i_h=i_h).to(device)
                self.decoder = ConvolutionalDecoder(
                    prior=prior,
                    zd_dim=zd_dim,
                    domain_dim=self.dim_inject_y,
                    h_dim=self.encoder.h_dim,
                    num_channels=i_c,
                ).to(device)
            print(self.encoder)
            print(self.decoder)

            # Mixture parameters: log cluster weights start uniform (log(1/d_dim)),
            # cluster means and log-variances start at zero.
            self.log_pi = nn.Parameter(
                torch.FloatTensor(self.d_dim).fill_(1.0 / self.d_dim).log(),
                requires_grad=True,
            )
            self.mu_c = nn.Parameter(torch.FloatTensor(self.d_dim, self.zd_dim).fill_(0), requires_grad=True)
            self.log_sigma2_c = nn.Parameter(torch.FloatTensor(self.d_dim, self.zd_dim).fill_(0), requires_grad=True)

        @staticmethod
        def _concat_injection(z, inject):
            """Return ``z`` with ``inject`` appended along dim 1, or ``z`` unchanged when ``inject`` is ``None`` or empty."""
            if inject is not None and len(inject) > 0:
                return torch.cat((z, inject), 1)
            return z

        def _inference(self, x):
            """Auxiliary function for inference

            :param tensor x: Input tensor of a shape [batch_size, i_c, i_h, i_w].
            :return tensor preds_c: One hot encoded tensor of the predicted cluster assignment (shape: [batch_size, self.d_dim]).
            :return tensor probs_c: Tensor of the predicted cluster probabilities; this is q(c|x) per eq. (16) or gamma_c in eq. (12) (shape: [batch_size, self.d_dim]).
            :return tensor z: Tensor of the latent space representation (shape: [batch_size, self.zd_dim])
            :return tensor z_mu: Tensor of the mean of the latent space representation (shape: [batch_size, self.zd_dim])
            :return tensor z_sigma2_log: Tensor of the log of the variance of the latent space representation (shape: [batch_size, self.zd_dim])
            :return tensor mu_c: Tensor of the estimated cluster means (shape: [self.d_dim, self.zd_dim])
            :return tensor log_sigma2_c: Tensor of the log of the estimated cluster variances (shape: [self.d_dim, self.zd_dim])
            :return tensor pi: Tensor of the estimated cluster prevalences, p(c) (shape: [self.d_dim])
            :return tensor logits: Tensor where each column contains the log-probability p(c)p(z|c) for cluster c=0,...,self.d_dim-1 (shape: [batch_size, self.d_dim]).
            """
            z_mu, z_sigma2_log = self.encoder(x)
            # reparametrization trick: z = mu + sigma * noise
            z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
            pi = F.softmax(self.log_pi, dim=0)
            mu_c = self.mu_c
            log_sigma2_c = self.log_sigma2_c
            # shape [batch_size, self.d_dim]; column c holds log( p(c) p(z|c) )
            logits = torch.log(pi.unsqueeze(0)) + self.gaussian_pdfs_log(z, mu_c, log_sigma2_c)
            preds_c, probs_c, *_ = logit2preds_vpic(logits)
            return preds_c, probs_c, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits

        def infer_d_v_2(self, x, inject_domain):
            """
            Used for tensorboard visualizations only.

            :param x: tensor with input data
            :param inject_domain: tensor appended to the latent code before decoding; may be empty
            :return: tuple (preds, z_mu, z, log_sigma2_c, probs, x_pro). All entries except
                x_pro are detached and moved to the CPU; x_pro is returned exactly as
                produced by the decoder.
            """
            results = self._inference(x)
            zy = self._concat_injection(results[2], inject_domain)
            x_pro, *_ = self.decoder(zy)
            preds, probs, z, z_mu, _, _, log_sigma2_c, _, _ = (r.cpu().detach() for r in results)
            return preds, z_mu, z, log_sigma2_c, probs, x_pro

        def cal_loss(self, x, inject_domain, warmup_beta):
            """Function that is called in trainer_vade to calculate loss

            :param x: tensor with input data
            :param inject_domain: tensor appended to the latent code before decoding; may be empty
            :param warmup_beta: weight applied to every non-reconstruction term of the ELBO
            :return: ELBO loss
            """
            return self._cal_ELBO_loss(x, inject_domain, warmup_beta)

        def _cal_reconstruction_loss(self, x, inject_tensor=None):
            """Reconstruction loss of the deterministic pass (the latent code is the encoder mean, no sampling).

            :param x: tensor with input data
            :param inject_tensor: optional tensor appended to the latent code before decoding;
                ``None`` or an empty sequence means no injection
            :return: reconstruction loss, see ``_cal_reconstruction_loss_helper``
            """
            z_mu = self.encoder.get_z(x)
            # NOTE(review): the returned log-variance is never used. The call is kept because
            # dropping an encoder pass could change module state (e.g. normalization
            # statistics) -- confirm it is side-effect free and remove it.
            self.encoder.get_log_sigma2(x)
            zy = self._concat_injection(z_mu, inject_tensor)
            x_pro, *_ = self.decoder(zy)
            return self._cal_reconstruction_loss_helper(x, x_pro, None)

        def _cal_reconstruction_loss_helper(self, x, x_pro, log_sigma):
            """Reconstruction loss between the input and its reconstruction.

            :param x: tensor with input data
            :param x_pro: reconstruction of ``x`` produced by the decoder
            :param log_sigma: unused; kept so the signature stays compatible with existing callers
            :return: binary cross-entropy (default mean reduction) when ``self.prior == "Bern"``;
                otherwise the batch mean of the per-sample summed 0.5 * squared error,
                divided by sigma**2 with the fixed sigma = 0.9 (tensor of shape [1])
            """
            if self.prior == "Bern":
                return F.binary_cross_entropy(x_pro, x)
            sigma = torch.Tensor([0.9]).to(self.device)  # mean sigma of all images
            # sum over the three non-batch dimensions (x is expected to be 4-dimensional)
            per_sample = torch.sum(0.5 * (x - x_pro) ** 2, dim=(1, 2, 3))
            return torch.mean(per_sample, 0) / sigma**2

        def _cal_ELBO_loss(self, x, inject_domain, warmup_beta):
            """ELBO loss function

            Using SGVB estimator and the reparametrization trick calculates ELBO loss.
            Calculates loss between encoded input and input using ELBO equation (12) in the paper.
            The reconstruction term is averaged over ``self.L`` Monte Carlo samples.

            :param tensor x: Input tensor of a shape [batch_size, i_c, i_h, i_w].
            :param inject_domain: tensor appended to each sampled latent code before decoding; may be empty
            :param warmup_beta: weight applied to every term except the reconstruction term
            :return: the negative ELBO, averaged over the batch
            """
            _, probs, _, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, _ = self._inference(x)
            eps = 1e-10

            L_rec = 0.0
            for _ in range(self.L):
                z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu  # shape [batch_size, self.zd_dim]
                zy = self._concat_injection(z, inject_domain)
                x_pro, log_sigma = self.decoder(zy)
                L_rec += self._cal_reconstruction_loss_helper(x, x_pro, log_sigma)  # FIXME
            L_rec /= self.L
            # --> this is the -"first line" of eq (12) in the paper with additional averaging over the batch.
            loss = L_rec * x.size(1)

            # inner sum dimensions:
            # [1, d_dim, zd_dim] + exp([batch_size, 1, zd_dim] - [1, d_dim, zd_dim]) + ([batch_size, 1, zd_dim] - [1, d_dim, zd_dim])^2 / exp([1, d_dim, zd_dim])
            # = [batch_size, d_dim, zd_dim] -> sum of zd_dim dimensions
            per_cluster = torch.sum(
                log_sigma2_c.unsqueeze(0)
                + torch.exp(z_sigma2_log.unsqueeze(1) - log_sigma2_c.unsqueeze(0))
                + (z_mu.unsqueeze(1) - mu_c.unsqueeze(0)).pow(2) / torch.exp(log_sigma2_c.unsqueeze(0)),
                2,
            )
            # the next sum is over d_dim dimensions, the mean is over the batch
            # --> overall, this is -"second line of eq. (12)" with additional mean over the batch
            loss = loss + 0.5 * warmup_beta * torch.mean(torch.sum(probs * per_cluster, 1))

            # dimensions: [batch_size, d_dim] * log([1, d_dim] / [batch_size, d_dim]), where the sum is over d_dim dimensions --> [batch_size] --> mean over the batch --> a scalar
            # FIXME: (+eps) is a hack to avoid NaN. Is there a better way?
            loss = loss - warmup_beta * torch.mean(torch.sum(probs * torch.log(pi.unsqueeze(0) / (probs + eps)), 1))

            # dimensions: mean( sum( [batch_size, zd_dim], 1 ) ) where the sum is over zd_dim dimensions and mean over the batch
            # --> overall, this is -"third line of eq. (12)" with additional mean over the batch
            loss = loss - 0.5 * warmup_beta * torch.mean(torch.sum(1.0 + z_sigma2_log, 1))
            return loss

        def gaussian_pdfs_log(self, x, mus, log_sigma2s):
            """Log-density of ``x`` under each of the ``self.d_dim`` cluster Gaussians.

            :param x: tensor of shape [batch_size, self.zd_dim]
            :param mus: cluster means; row c is used for cluster c
            :param log_sigma2s: log of the cluster variances; row c is used for cluster c
            :return: tensor of shape [batch_size, self.d_dim]; column c is the log-density under cluster c
            """
            loglik = [
                self.gaussian_pdf_log(x, mus[c, :], log_sigma2s[c, :]).view(-1, 1) for c in range(self.d_dim)
            ]
            return torch.cat(loglik, 1)

        @staticmethod
        def gaussian_pdf_log(x, mu, log_sigma2):
            """
            Log-density of a single diagonal Gaussian; building block of gaussian_pdfs_log.

            :param x: tensor of shape [batch_size, self.zd_dim]
            :param mu: mean for the cluster distribution
            :param log_sigma2: log of the variance parameters of the cluster distribution
            :return: tensor with the Gaussian log probabilities of shape [batch_size]
            """
            return -0.5 * (
                torch.sum(
                    np.log(np.pi * 2) + log_sigma2 + (x - mu).pow(2) / torch.exp(log_sigma2),
                    1,
                )
            )

    return ModelVaDE
 
def test_fun(d_dim, zd_dim, device):
    """
    Smoke check: build a VaDE model on the CPU and run it once on a random batch.

    :param d_dim: number of clusters
    :param zd_dim: dimension of the latent space
    :param device: ignored -- the check always runs on the CPU; the parameter is kept
        so existing callers do not break
    """
    device = torch.device("cpu")
    # ModelVaDE only exists inside mk_vade(), so the class has to be built first.
    # i_c / i_h / i_w / bs mirror the shape of the random batch created below.
    model = mk_vade()(d_dim=d_dim, zd_dim=zd_dim, device=device, i_c=3, i_h=28, i_w=28, bs=2)
    x = torch.rand(2, 3, 28, 28)

    # one-hot labels for the two samples (classes 1 and 8 out of 10)
    a = np.zeros((2, 10))
    a[0, 1] = 1.0
    a[1, 8] = 1.0
    y = torch.tensor(a, dtype=torch.float)

    # NOTE(review): forward() is not defined in ModelVaDE itself; this call relies on
    # the parent class providing a compatible forward(x, y) -- confirm.
    model(x, y)
    # The model is built with dim_inject_y=0, so nothing is injected; warmup_beta=1.0
    # applies the full weight to the non-reconstruction ELBO terms.
    model.cal_loss(x, [], 1.0)