import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from domainlab.utils.utils_classif import logit2preds_vpic
from domid.compos.cnn_AE import ConvolutionalDecoder, ConvolutionalEncoder
from domid.compos.GNN import GNN
from domid.compos.linear_AE import LinearDecoderAE, LinearEncoderAE
from domid.models.a_model_cluster import AModelCluster
def mk_sdcn(parent_class=AModelCluster):
    class ModelSDCN(parent_class):
        """
        ModelSDCN implements the Structural Deep Clustering Network (SDCN) model (Bo et al., 2020).
        The model is composed of an autoencoder (convolutional or linear encoder and decoder), a GNN and a clustering layer.
        """
        def __init__(
            self,
            zd_dim,
            d_dim,
            device,
            i_c,
            i_h,
            i_w,
            bs,
            task,
            L=5,
            random_batching=False,
            model_method="cnn",
            prior="Bern",
            dim_inject_y=0,
            pre_tr_weight_path=None,
            feat_extract="vae",
            graph_method=None,
        ):
            """
            Build the autoencoder, load its pre-trained weights, and create the GNN and cluster layer.

            :param zd_dim: dimension of the latent space (second dimension of the cluster layer).
            :param d_dim: number of clusters (first dimension of the cluster layer).
            :param device: device the convolutional encoder/decoder are moved to; also used as
                `map_location` for the pre-trained weights and passed to the GNN.
            :param i_c: number of input channels.
            :param i_h: input image height.
            :param i_w: input image width.
            :param bs: batch size; stored as `self.bs` and used to size the random index
                subsets drawn when `task == "wsi"`.
            :param task: task name; "mnist" in the name selects the "heat" graph method,
                "wsi" in the name selects "patch_distance".
            :param L: stored as `self.L`; not otherwise used in this constructor.
            :param random_batching: stored as `self.random_batching`.
            :param model_method: "linear" builds a linear autoencoder, any other value builds
                the convolutional one.
            :param prior: passed to the convolutional decoder.
            :param dim_inject_y: passed to the convolutional decoder as `domain_dim`.
            :param pre_tr_weight_path: directory containing `encoder.pt` and `decoder.pt`.
            :param feat_extract: stored as `self.feat_extract`.
            :param graph_method: if not None, overrides the graph method derived from `task`.
            :raises ValueError: if `pre_tr_weight_path` is empty or None.
            """
            super(ModelSDCN, self).__init__()
            self.zd_dim = zd_dim
            self.d_dim = d_dim
            self.device = device
            # Fix: the batch size was accepted but never stored, although it is read
            # below as `self.bs` when task == "wsi".
            self.bs = bs
            self.L = L
            self.loss_epoch = 0
            self.dim_inject_y = dim_inject_y
            self.prior = prior
            self.model_method = model_method
            self.random_batching = random_batching
            self.pre_tr_weight_path = pre_tr_weight_path
            self.feat_extract = feat_extract
            self.task = task
            self.model = "sdcn"

            n_clusters = d_dim
            n_z = zd_dim
            n_input = i_c * i_h * i_w

            # Learnable cluster centres, one row per cluster.
            self.cluster_layer = nn.Parameter(torch.Tensor(self.d_dim, self.zd_dim))
            torch.nn.init.xavier_normal_(self.cluster_layer.data)

            if self.model_method == "linear":
                n_enc_1, n_enc_2, n_enc_3 = 500, 500, 2000
                n_dec_1, n_dec_2, n_dec_3 = 2000, 500, 500
                self.encoder = LinearEncoderAE(n_enc_1, n_enc_2, n_enc_3, n_input, n_z)
                self.decoder = LinearDecoderAE(n_dec_1, n_dec_2, n_dec_3, n_input, n_z)
            else:
                self.encoder = ConvolutionalEncoder(zd_dim=zd_dim, num_channels=i_c, i_w=i_w, i_h=i_h).to(device)
                self.decoder = ConvolutionalDecoder(
                    prior=prior,
                    zd_dim=zd_dim,
                    domain_dim=self.dim_inject_y,
                    h_dim=self.encoder.h_dim,
                    num_channels=i_c,
                ).to(device)
                # Flattened sizes of the three encoder feature maps, used as GNN layer widths.
                # NOTE(review): only i_w is used here, so this presumably assumes square
                # inputs (i_h == i_w) that are halved at every encoder stage - confirm.
                num_filters = self.encoder.num_filters
                n_enc_1 = int((i_w / 2) ** 2 * num_filters[0])
                n_enc_2 = int((i_w / 4) ** 2 * num_filters[1])
                n_enc_3 = int((i_w / 8) ** 2 * num_filters[2])
                n_dec_1, n_dec_2, n_dec_3 = n_enc_3, n_enc_2, n_enc_1
                print("Filter sizes for GNN", n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3)

            # Pre-trained autoencoder weights are mandatory for this model.
            if self.pre_tr_weight_path:
                self.encoder.load_state_dict(
                    torch.load(os.path.join(self.pre_tr_weight_path, "encoder.pt"), map_location=self.device)
                )
                self.decoder.load_state_dict(
                    torch.load(os.path.join(self.pre_tr_weight_path, "decoder.pt"), map_location=self.device)
                )
                print("Pre-trained weights loaded")
            else:
                raise ValueError("Pre-trianed weight path is not provided")

            self.gnn_model = GNN(n_input, n_enc_1, n_enc_2, n_enc_3, n_z, n_clusters, device)

            n_gpus = torch.cuda.device_count()
            if n_gpus > 1:
                print("Using DataParallel with {} GPUs.".format(n_gpus))
                # Only the encoder and decoder are wrapped; the GNN is left unwrapped.
                self.encoder = torch.nn.DataParallel(self.encoder)
                self.decoder = torch.nn.DataParallel(self.decoder)
            elif n_gpus == 1:
                print("Using a single GPU.")
            else:
                print("Using CPU(s).")

            self.v = 1.0  # degrees of freedom of the Student's t kernel used for q
            self.counter = 0
            self.q_activation = torch.zeros((10, 100))
            self.kl_loss_running = 0
            self.re_loss_running = 0
            self.ce_loss_running = 0

            # Graph construction method: task-based default, overridden by an explicit argument.
            # NOTE(review): when the task name contains neither "mnist" nor "wsi" and
            # `graph_method` is None, `self.graph_method` is not assigned here.
            if "mnist" in self.task:
                self.graph_method = "heat"
            if "wsi" in self.task:
                self.graph_method = "patch_distance"
            if graph_method is not None:
                self.graph_method = graph_method

            if self.task == "wsi":
                # 66 random index subsets, each holding a third of a batch.
                self.random_ind = [torch.randint(0, self.bs, (int(self.bs / 3),)) for _ in range(0, 66)]
            else:
                self.random_ind = []
        def _inference(self, x, inject_tensor=None):
            """
            Run the encoder, the GNN and the clustering layer on a batch.

            :param x: input batch [batch_size, n_channels, height, width]; it is flattened to
                [batch_size, n_channels * height * width] when `model_method == "linear"`.
            :param inject_tensor: not used by this method.
            :return:
                - preds_c - cluster predictions obtained from the GNN output via `logit2preds_vpic`
                - probs_c - softmax (over dim 1) of the GNN output, [batch_size, n_clusters]
                - z - latent representation produced by the encoder, [batch_size, n_z]
                - logits - soft cluster assignment q (Student's t kernel between z and the
                  cluster centres, normalised per sample), [batch_size, n_clusters], float32
            """
            if self.model_method == "linear":
                x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))

            enc_h1, enc_h2, enc_h3, z = self.encoder(x)
            # NOTE(review): `self.adj` is not assigned in this class's constructor;
            # presumably it is set externally before inference - confirm.
            h = self.gnn_model(x, self.adj.to(self.device), enc_h1, enc_h2, enc_h3, z)
            probs_c = F.softmax(h, dim=1)

            # Dual self-supervised module: Student's t similarity between every latent
            # vector and every cluster centre.
            sq_dist = torch.sum(torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2)
            # Fix: divide the squared distance (not the whole kernel) by v, i.e.
            # q = (1 + d / v) ** (-(v + 1) / 2). Identical to the previous expression for v == 1.
            q = 1.0 / (1.0 + sq_dist / self.v)
            q = q.pow((self.v + 1.0) / 2.0)
            # Normalise so that each sample's assignments sum to one.
            q = (q.t() / torch.sum(q, 1)).t()
            logits = q.type(torch.float32)

            preds_c, *_ = logit2preds_vpic(h)
            return preds_c, probs_c, z, logits
        def infer_d_v_2(self, x):
            """
            Used for tensorboard visualizations only.
            """
            # import pdb; pdb.set_trace()
            results = self._inference(x)
            z = results[2]
            # print(results[2].shape, inject_domain.shape, zy.shape)
            x_pro = self.decoder(z)
            preds_c, probs_c, z, logits = (r.cpu().detach() for r in results)
            return preds_c, z, probs_c, x_pro
        def target_distribution(self, q):
            """
            Compute the target distribution p from the soft assignments q.

            With f_j = sum_i q_ij (sum over dim 0), the result is
            p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'), so every row of p sums to one.
            Corresponds to equation (12) from the paper.

            :param q: 2-D tensor of soft assignments (rows: samples, columns: clusters).
            :return: tensor of the same shape as q.
            """
            cluster_mass = q.sum(0)
            sharpened = q.pow(2) / cluster_mass
            row_totals = sharpened.sum(1, keepdim=True)
            return sharpened / row_totals
        def _cal_kl_loss(self, x, inject_tensor=None, warmup_beta=None):
            """
            Compute the loss of the model.
            Concentrate two different objectives, i.e. clustering objective and classification objective, in one loss function.
            Corresponds to equation (15) in the paper.
            :param tensor x: Input tensor of a shape [batchsize, 3, horzintal dim, vertical dim].
            :param float warmup_beta: Warmup coefficient for the KL divergence term.
            :return tensor loss: Loss tensor.
            """
            preds_c, probs_c, z, logits = self._inference(x)
            # logits is q in the paper
            # probs_c is pred in the code
            q = logits
            pred = probs_c
            x_bar = self.decoder(z)
            q = q.data
            # if self.counter==1:
            p = self.target_distribution(q)
            # self.local_tb.add_histogram('p', p, self.counter)
            if self.model == "linear":
                x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            kl_loss = F.kl_div(q.log(), p, reduction="batchmean")
            ce_loss = F.kl_div(pred.log(), p, reduction="batchmean")
            re_loss = F.mse_loss(x, x_bar)
            loss = 0.1 * kl_loss + 0.01 * ce_loss + re_loss
            self.kl_loss_running = kl_loss
            self.ce_loss_running = ce_loss
            self.re_loss_running = re_loss
            return loss.type(torch.double)
        def cal_loss_for_tensorboard(self):
            return self.kl_loss_running, self.ce_loss_running, self.re_loss_running
        def hyper_init(self, functor_scheduler):
            """
            Instantiate the hyperparameter scheduler.

            :param functor_scheduler: callable invoked with `trainer=None`.
            :return: whatever `functor_scheduler(trainer=None)` returns.
            """
            scheduler = functor_scheduler(trainer=None)
            return scheduler
        def hyper_update(self, epoch, fun_scheduler):
            """hyper_update.
            :param epoch:
            :param fun_scheduler: the hyperparameter scheduler object
            """
            dict_rst = fun_scheduler(epoch)  # the __call__ method of hy
            # perparameter scheduler
            self.alpha = dict_rst["alpha"]
    return ModelSDCN