# Source code for domid.models.model_sdcn

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from domainlab.utils.utils_classif import logit2preds_vpic

from domid.compos.cnn_AE import ConvolutionalDecoder, ConvolutionalEncoder
from domid.compos.GNN import GNN
from domid.compos.linear_AE import LinearDecoderAE, LinearEncoderAE
from domid.models.a_model_cluster import AModelCluster


def mk_sdcn(parent_class=AModelCluster):
    class ModelSDCN(parent_class):
        """
        ModelSDCN is a class that implements the SDCN model (Bo D et al. 2020).
        The model is composed of a convolutional (or linear) encoder and decoder,
        a GNN and a clustering layer.
        """

        def __init__(
            self,
            zd_dim,
            d_dim,
            device,
            i_c,
            i_h,
            i_w,
            bs,
            task,
            L=5,
            random_batching=False,
            model_method="cnn",
            prior="Bern",
            dim_inject_y=0,
            pre_tr_weight_path=None,
            feat_extract="vae",
            graph_method=None,
        ):
            """
            :param int zd_dim: dimension of the latent space z (n_z).
            :param int d_dim: number of clusters/domains.
            :param device: torch device the submodules are moved to.
            :param int i_c: number of input channels.
            :param int i_h: input image height.
            :param int i_w: input image width.
            :param int bs: batch size (used to draw random sub-batches for the "wsi" task).
            :param str task: task name; drives the default graph construction method.
            :param int L: number of Monte-Carlo samples (kept for API compatibility).
            :param bool random_batching: whether random batching is used.
            :param str model_method: "linear" for a fully-connected AE, anything else for CNN.
            :param str prior: prior distribution for the convolutional decoder.
            :param int dim_inject_y: dimension of labels injected into the decoder.
            :param str pre_tr_weight_path: directory with pre-trained "encoder.pt"/"decoder.pt"
                (required — construction raises ValueError without it).
            :param str feat_extract: feature-extraction flavor tag.
            :param str graph_method: explicit graph construction method; overrides task defaults.
            """
            super(ModelSDCN, self).__init__()
            self.zd_dim = zd_dim
            self.d_dim = d_dim
            self.device = device
            self.L = L
            self.loss_epoch = 0
            self.dim_inject_y = dim_inject_y
            self.prior = prior
            self.model_method = model_method
            self.random_batching = random_batching
            self.pre_tr_weight_path = pre_tr_weight_path
            self.feat_extract = feat_extract
            self.task = task
            self.model = "sdcn"
            # FIX: `bs` was accepted but never stored, yet `self.bs` is read
            # below when building random indices for the "wsi" task.
            self.bs = bs

            n_clusters = d_dim
            n_z = zd_dim
            n_input = i_c * i_h * i_w

            # Trainable cluster centers for the Student's-t soft assignment (eq. 10 in the paper).
            self.cluster_layer = nn.Parameter(torch.Tensor(self.d_dim, self.zd_dim))
            torch.nn.init.xavier_normal_(self.cluster_layer.data)

            if self.model_method == "linear":
                n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3 = (
                    500,
                    500,
                    2000,
                    2000,
                    500,
                    500,
                )
                self.encoder = LinearEncoderAE(n_enc_1, n_enc_2, n_enc_3, n_input, n_z)
                self.decoder = LinearDecoderAE(n_dec_1, n_dec_2, n_dec_3, n_input, n_z)
            else:
                self.encoder = ConvolutionalEncoder(zd_dim=zd_dim, num_channels=i_c, i_w=i_w, i_h=i_h).to(device)
                self.decoder = ConvolutionalDecoder(
                    prior=prior,
                    zd_dim=zd_dim,
                    domain_dim=self.dim_inject_y,
                    h_dim=self.encoder.h_dim,
                    num_channels=i_c,
                ).to(device)
                # Flattened feature-map sizes after each 2x downsampling stage;
                # these become the widths of the GNN layers.
                n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3 = (
                    int((i_w / 2) ** 2 * self.encoder.num_filters[0]),
                    int((i_w / 4) ** 2 * self.encoder.num_filters[1]),
                    int((i_w / 8) ** 2 * self.encoder.num_filters[2]),
                    int((i_w / 8) ** 2 * self.encoder.num_filters[2]),
                    int((i_w / 4) ** 2 * self.encoder.num_filters[1]),
                    int((i_w / 2) ** 2 * self.encoder.num_filters[0]),
                )
                print("Filter sizes for GNN", n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3)

            # SDCN requires a pre-trained autoencoder; fail fast if it is missing.
            if self.pre_tr_weight_path:
                self.encoder.load_state_dict(
                    torch.load(os.path.join(self.pre_tr_weight_path, "encoder.pt"), map_location=self.device)
                )
                self.decoder.load_state_dict(
                    torch.load(os.path.join(self.pre_tr_weight_path, "decoder.pt"), map_location=self.device)
                )
                print("Pre-trained weights loaded")
            else:
                raise ValueError("Pre-trained weight path is not provided")

            self.gnn_model = GNN(n_input, n_enc_1, n_enc_2, n_enc_3, n_z, n_clusters, device)

            if torch.cuda.device_count() > 1:  # multiple GPUs available
                print("Using DataParallel with {} GPUs.".format(torch.cuda.device_count()))
                self.encoder = torch.nn.DataParallel(self.encoder)
                self.decoder = torch.nn.DataParallel(self.decoder)
            elif torch.cuda.device_count() == 1:
                print("Using a single GPU.")
            else:
                print("Using CPU(s).")

            self.v = 1.0  # degrees of freedom of the Student's-t kernel
            self.counter = 0
            self.q_activation = torch.zeros((10, 100))
            self.kl_loss_running = 0
            self.re_loss_running = 0
            self.ce_loss_running = 0

            # Graph construction method: an explicit argument wins, otherwise
            # fall back to a task-specific default. FIX: previously the attribute
            # was left unset for tasks other than mnist/wsi with graph_method=None.
            self.graph_method = graph_method
            if self.graph_method is None:
                if "wsi" in self.task:
                    self.graph_method = "patch_distance"
                elif "mnist" in self.task:
                    self.graph_method = "heat"

            if self.task == "wsi":
                # 66 random index subsets of one third of the batch each.
                self.random_ind = [torch.randint(0, self.bs, (int(self.bs / 3),)) for _ in range(66)]
            else:
                self.random_ind = []

        def _inference(self, x, inject_tensor=None):
            """
            Run encoder + GNN and compute the soft cluster assignments.

            :param x: [batch_size, n_channels, height, width]
            :return:
                - preds_c - one-hot predictions from the GNN logits, [batch_size, n_clusters]
                - probs_c - softmax of the GNN output ("pred" in the paper), [batch_size, n_clusters]
                - z - latent representation, [batch_size, n_z]
                - logits - Student's-t soft assignment q (eq. 10), [batch_size, n_clusters]
            """
            if self.model_method == "linear":
                x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            enc_h1, enc_h2, enc_h3, z = self.encoder(x)
            # NOTE(review): self.adj is expected to be set externally (e.g. by the
            # trainer) before inference — it is not created in __init__.
            h = self.gnn_model(x, self.adj.to(self.device), enc_h1, enc_h2, enc_h3, z)
            probs_c = F.softmax(h, dim=1)  # [batch_size, n_clusters]

            # Dual self-supervised module: Student's-t similarity between z and
            # the cluster centers, normalized per sample (q in the paper).
            q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2)) / self.v
            q = q.pow((self.v + 1.0) / 2.0)
            q = (q.t() / torch.sum(q, 1)).t()
            logits = q.type(torch.float32)

            preds_c, *_ = logit2preds_vpic(h)

            return preds_c, probs_c, z, logits

        def infer_d_v_2(self, x):
            """
            Used for tensorboard visualizations only.

            :return: predictions, latent codes, cluster probabilities and the
                decoder reconstruction of the batch.
            """
            results = self._inference(x)
            z = results[2]
            x_pro = self.decoder(z)

            preds_c, probs_c, z, logits = (r.cpu().detach() for r in results)
            return preds_c, z, probs_c, x_pro

        def target_distribution(self, q):
            """
            Compute the target distribution p, where p_i = (sum_j q_j)^2 / sum_j^2 q_j.
            Corresponds to equation (12) from the paper.
            """
            weight = q**2 / q.sum(0)
            return (weight.t() / weight.sum(1)).t()

        def _cal_kl_loss(self, x, inject_tensor=None, warmup_beta=None):
            """
            Compute the loss of the model. Concentrates two different objectives,
            i.e. clustering objective and classification objective, in one loss
            function. Corresponds to equation (15) in the paper.

            :param tensor x: Input tensor of a shape [batchsize, 3, horizontal dim, vertical dim].
            :param float warmup_beta: Warmup coefficient for the KL divergence term (unused here).
            :return tensor loss: Loss tensor (double precision scalar).
            """
            preds_c, probs_c, z, logits = self._inference(x)
            q = logits       # q in the paper
            pred = probs_c   # pred in the paper
            x_bar = self.decoder(z)
            q = q.data
            p = self.target_distribution(q)

            # FIX: was `self.model == "linear"`, but self.model is always "sdcn";
            # the flattening must key on self.model_method (as in _inference) so
            # that x matches the flat reconstruction of the linear decoder.
            if self.model_method == "linear":
                x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))

            kl_loss = F.kl_div(q.log(), p, reduction="batchmean")
            ce_loss = F.kl_div(pred.log(), p, reduction="batchmean")
            re_loss = F.mse_loss(x, x_bar)

            loss = 0.1 * kl_loss + 0.01 * ce_loss + re_loss

            # Cache the individual terms for tensorboard reporting.
            self.kl_loss_running = kl_loss
            self.ce_loss_running = ce_loss
            self.re_loss_running = re_loss

            return loss.type(torch.double)

        def cal_loss_for_tensorboard(self):
            """Return the last computed (kl, ce, reconstruction) loss terms."""
            return self.kl_loss_running, self.ce_loss_running, self.re_loss_running

        def hyper_init(self, functor_scheduler):
            """hyper_init.

            :param functor_scheduler: hyperparameter scheduler class/factory.
            """
            return functor_scheduler(trainer=None)

        def hyper_update(self, epoch, fun_scheduler):
            """hyper_update.

            :param epoch: current epoch number.
            :param fun_scheduler: the hyperparameter scheduler object.
            """
            dict_rst = fun_scheduler(epoch)  # __call__ of the hyperparameter scheduler
            self.alpha = dict_rst["alpha"]

    return ModelSDCN