Source code for domid.models.a_model_cluster

import abc

import torch
import torch.nn as nn
import torch.nn.functional as F

from domid.utils.perf_cluster import PerfCluster


class AModelCluster(nn.Module):
    """
    Operations that all clustering models should have.
    """

    def __init__(self):
        super(AModelCluster, self).__init__()
        self._decoratee = None  # FIXME: do we pass it to every model?

    def create_perf_obj(self, task):
        """
        Sets up the performance metric used for evaluation.
        """
        self.task = task
        self.perf_metric = PerfCluster(task.dim_y)
        return self.perf_metric

    def cal_perf_metric(self, loader_tr, device, loader_te=None):
        """
        Clustering performance metric on the training and test/validation sets.
        """
        metric_tr = None
        metric_te = None
        with torch.no_grad():
            metric_tr = self.perf_metric.cal_acc(self, loader_tr, device)
            if loader_te is not None:
                metric_te = self.perf_metric.cal_acc(self, loader_te, device)
        r_score_tr = None
        r_score_te = None
        # Correlation scores are only computed for tasks that define
        # calc_corr (e.g. the HER2 task).
        if hasattr(self.task, "calc_corr"):
            with torch.no_grad():
                r_score_tr, r_score_te = self.task.calc_corr(self, loader_tr, loader_te, device)
        return metric_tr, metric_te, r_score_tr, r_score_te

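    # Hypothetical trainer-side usage of the two methods above (the names
    # task, loader_tr, and loader_te are placeholders, not part of this
    # module):
    #
    #     model.create_perf_obj(task)
    #     metric_tr, metric_te, r_tr, r_te = model.cal_perf_metric(
    #         loader_tr, device, loader_te=loader_te
    #     )
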
    def cal_loss(self, tensor_x, inj_tensor=torch.Tensor([]), warmup_beta=None):
        """
        Calculates the total loss for the model: the reconstruction loss
        plus the KL loss.
        """
        total_loss = self._cal_reconstruction_loss(tensor_x, inj_tensor)
        kl_loss = self._cal_kl_loss(tensor_x, inj_tensor)
        total_loss += kl_loss
        return total_loss

    def infer_d_v(self, x):
        """
        Predict the cluster/domain of the input data.
        Corresponds to equation (16) in the paper.

        :param tensor x: Input tensor of shape [batchsize, 3, horizontal dim, vertical dim].
        :return tensor preds: One-hot encoded tensor of the predicted cluster assignment.
        """
        preds, *_ = self._inference(x)
        return preds.cpu().detach()

    def extend(self, model):
        """
        Extend the loss of this model with the loss of the decoratee.
        """
        self._decoratee = model

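    # Decorator-pattern sketch: extend() stores a second clustering model so
    # that its KL term can be folded into this model's loss via
    # _extend_loss below. Hypothetical usage (both model names are
    # placeholders):
    #
    #     model_a.extend(model_b)
    #     kl_terms = model_a._extend_loss(tensor_x, tensor_y, tensor_d)
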
    def _extend_loss(self, tensor_x, tensor_y, tensor_d, others=None):
        """
        Combine losses from two models (decorator pattern).
        """
        if self._decoratee is not None:
            return self._decoratee._cal_kl_loss(tensor_x, tensor_y, tensor_d, others)
        return None, None

    @abc.abstractmethod
    def _cal_pretrain_loss(self, tensor_x, inject_tensor=torch.Tensor([])):
        """
        Pretraining loss for the model.
        """
        return self._cal_reconstruction_loss(tensor_x, inject_tensor)

    def _cal_reconstruction_loss(self, tensor_x, inject_domain=torch.Tensor([])):
        """
        Mean-squared-error loss between the input and its reconstruction,
        optionally conditioning the decoder on an injected domain tensor.
        """
        if self.model_method == "linear":
            # Linear models operate on flattened inputs.
            tensor_x = torch.reshape(tensor_x, (tensor_x.shape[0], -1))
        z = self.encoder.get_z(tensor_x)
        if len(inject_domain) > 0:
            zy = torch.cat((z, inject_domain), 1)
        else:
            zy = z
        x_pro = self.decoder(zy)
        # Some decoders return a tuple; the reconstruction is its first element.
        if isinstance(x_pro, tuple):
            x_pro = x_pro[0]
        loss = F.mse_loss(x_pro, tensor_x)
        return loss

    @abc.abstractmethod
    def _cal_kl_loss(self, q, p):
        # FIXME: the KL loss is different for each model; redefine it in every subclass?
        raise NotImplementedError
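
# A minimal sketch of how a concrete subclass might satisfy this interface.
# ToyClusterModel, _ToyEncoder, and cluster_head are hypothetical names
# invented for illustration (they are not part of domid), and the
# uniform-prior KL term merely stands in for a real model's KL loss, which,
# per the FIXME above, differs from model to model.


class _ToyEncoder(nn.Module):
    """Hypothetical encoder exposing the get_z interface assumed above."""

    def __init__(self, d_in, d_z):
        super().__init__()
        self.lin = nn.Linear(d_in, d_z)

    def get_z(self, x):
        # Deterministic latent code; a real clustering model would typically
        # use a probabilistic encoder.
        return self.lin(x)


class ToyClusterModel(AModelCluster):
    """Hypothetical concrete subclass, for illustration only."""

    def __init__(self, d_in=12, d_z=4, n_clusters=3):
        super().__init__()
        self.model_method = "linear"  # tells the base class to flatten inputs
        self.encoder = _ToyEncoder(d_in, d_z)
        self.decoder = nn.Linear(d_z, d_in)
        self.cluster_head = nn.Linear(d_z, n_clusters)

    def _inference(self, x):
        x = torch.reshape(x, (x.shape[0], -1))
        logits = self.cluster_head(self.encoder.get_z(x))
        preds = F.one_hot(logits.argmax(dim=1), num_classes=logits.shape[1])
        return preds, logits

    def _cal_kl_loss(self, tensor_x, inj_tensor=torch.Tensor([])):
        # Stand-in KL term: softmax cluster posterior against a uniform
        # prior. The signature follows cal_loss's call site rather than the
        # abstract (q, p) signature, mirroring the FIXME in the base class.
        _, logits = self._inference(tensor_x)
        log_q = F.log_softmax(logits, dim=1)
        p = torch.full_like(log_q, 1.0 / log_q.shape[1])
        return F.kl_div(log_q, p, reduction="batchmean")


if __name__ == "__main__":
    model = ToyClusterModel()
    x = torch.randn(8, 3, 2, 2)   # flattens to d_in = 12
    loss = model.cal_loss(x)      # reconstruction (MSE) + KL term
    preds = model.infer_d_v(x)    # one-hot cluster assignments, shape [8, 3]
    print(loss.item(), preds.shape)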