Source code for domid.models.model_dec

import os

import torch
import torch.nn as nn
from domainlab.utils.utils_classif import logit2preds_vpic

from domid.compos.DEC_clustering_layer import DECClusteringLayer
from domid.models.a_model_cluster import AModelCluster


[docs]def mk_dec(parent_class=AModelCluster): class ModelDEC(parent_class): def __init__( self, zd_dim, d_dim, device, i_c, i_h, i_w, bs, L=5, random_batching=False, model_method="cnn", prior="Bern", dim_inject_y=0, pre_tr_weight_path=None, feat_extract="vae", ): """ DEC model (Xie et al. 2015 "Unsupervised Deep Embedding for Clustering Analysis") with fully connected encoder and decoder. :param zd_dim: dimension of the latent space :param d_dim: number of clusters for the clustering task :param device: device to use, e.g., "cuda" or "cpu" :param i_c: number of channels of the input image :param i_h: height of the input image :param i_w: width of the input image :param args: command line arguments """ super(ModelDEC, self).__init__() self.n_clusters = d_dim self.d_dim = d_dim self.zd_dim = zd_dim self.device = device self.alpha = 1 # FIXME self.hidden = zd_dim self.cluster_centers = None self.warmup_beta = 0 self.dim_inject_y = dim_inject_y self.model_method = model_method self.prior = prior self.feat_extract = feat_extract self.random_batching = random_batching self.pre_tr_weight_path = pre_tr_weight_path self.model = "dec" if self.feat_extract == "vae": from domid.compos.cnn_VAE import ConvolutionalDecoder, ConvolutionalEncoder elif self.feat_extract == "ae": from domid.compos.cnn_AE import ConvolutionalDecoder, ConvolutionalEncoder self.encoder = ConvolutionalEncoder(zd_dim=zd_dim, num_channels=i_c, i_w=i_w, i_h=i_h).to(device) self.decoder = ConvolutionalDecoder( prior=prior, zd_dim=zd_dim, # 50 domain_dim=self.dim_inject_y, h_dim=self.encoder.h_dim, num_channels=i_c, ).to(device) self.autoencoder = self.encoder self.clusteringlayer = DECClusteringLayer( self.n_clusters, self.hidden, None, self.alpha, self.device ) # learnable parameter - cluster center self.mu_c = self.clusteringlayer.cluster_centers self.log_pi = nn.Parameter( torch.FloatTensor( self.d_dim, ) .fill_(1.0 / self.d_dim) .log(), requires_grad=True, ) if self.pre_tr_weight_path: self.encoder.load_state_dict( torch.load(os.path.join(self.pre_tr_weight_path, "encoder.pt"), map_location=self.device) ) self.decoder.load_state_dict( torch.load(os.path.join(self.pre_tr_weight_path, "decoder.pt"), map_location=self.device) ) print("AE is using pretrained weights. No need for pretraining epochs. ") self.log_sigma2_c = nn.Parameter(torch.FloatTensor(self.d_dim, self.zd_dim).fill_(0), requires_grad=True) self.random_ind = [] def target_distribution(self, q_): """ Corresponds to equation 3 from the paper. Calculates the target distribution for the Kullback-Leibler divergence loss. :param q_: A tensor of the predicted cluster probabilities. :return tensor: The calculated target distribution """ weight = (q_**2) / torch.sum(q_, 0) return (weight.t() / torch.sum(weight, 1)).t() def _inference(self, x): """ :return tensor preds_c: One hot encoded tensor of the predicted cluster assignment (shape: [batch_size, self.d_dim]). :return tensor probs_c: Tensor of the predicted cluster probabilities; this is q(c|x) per eq. (16) or gamma_c in eq. (12) (shape: [batch_size, self.d_dim]). :return tensor z: Tensor of the latent space representation (shape: [batch_size, self.zd_dim]) :return tensor z_mu: Tensor of the mean of the latent space representation (shape: [batch_size, self.zd_dim]) :return tensor z_sigma2_log: Tensor of the log of the variance of the latent space representation (shape: [batch_size, self.zd_dim]) :return tensor mu_c: Tensor of the estimated cluster means (shape: [self.d_dim, self.zd_dim]) :return tensor log_sigma2_c: Tensor of the estimated cluster variances (shape: [self.d_dim, self.zd_dim]) :return tensor pi: Tensor of the estimated cluster prevalences, p(c) (shape: [self.d_dim]) :return tensor logits: Tensor where each column contains the log-probability p(c)p(z|c) for cluster c=0,...,self.d_dim-1 (shape: [batch_size, self.d_dim]). """ z_mu = self.encoder.get_z(x) z = z_mu z_sigma2_log = self.encoder.get_log_sigma2(x) probs_c = self.clusteringlayer(z_mu) # in dec it is preds_c, logits, *_ = logit2preds_vpic(probs_c) # preds c is oen hot encoded mu_c = self.mu_c # print(mu_c[0, :5]) log_sigma2_c = self.log_sigma2_c pi = self.log_pi return preds_c, probs_c, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits def infer_d_v_2(self, x, inject_domain): results = self._inference(x) if len(inject_domain) > 0: zy = torch.cat((results[2], inject_domain), 1) else: zy = results[2] x_pro, *_ = self.decoder(zy) preds, probs, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits = ( r.cpu().detach() if r is not None else None for r in results ) return preds, z_mu, z, log_sigma2_c, probs, x_pro def _cal_kl_loss(self, x, inject_tensor, warmup_beta=0.1): """ Calculates the KL-divergence loss between the predicted probabilities and the target distribution. :param x: input tensor/image :param inject_tensor: tensor to inject (not used in DEC, only used in CDVaDE :param warmup_beta: warm-up beta value :return tensor loss (float): calculated KL-divergence loss value """ preds, probs, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits = self._inference(x) target = self.target_distribution(probs).detach() loss_function = nn.KLDivLoss(reduction="batchmean") loss = loss_function(probs.log(), target) if self.warmup_beta != warmup_beta: print(logits[0, :], target[0, :]) print(loss) self.warmup_beta = warmup_beta return loss return ModelDEC