# Source code for domid.models.model_ae

import os

import torch
import torch.nn as nn
from domainlab.dsets.utils_data import mk_fun_label2onehot
from domainlab.utils.utils_classif import logit2preds_vpic
from sklearn.cluster import KMeans

from domid.compos.cnn_AE import ConvolutionalDecoder, ConvolutionalEncoder
from domid.compos.linear_AE import LinearDecoderAE, LinearEncoderAE
from domid.models.a_model_cluster import AModelCluster


def mk_ae(parent_class=AModelCluster):
    """Factory that builds and returns an auto-encoder clustering model class.

    :param parent_class: base class providing the clustering-model contract
        (training loop hooks, pre-train loss, etc.); defaults to AModelCluster.
    :return: the ModelAE class (not an instance), parameterized by parent_class.
    """

    class ModelAE(parent_class):
        """Auto-encoder with KMeans-based domain clustering on the latent code."""

        def __init__(
            self,
            zd_dim,
            d_dim,
            device,
            i_c,
            i_h,
            i_w,
            bs,
            L=5,
            random_batching=False,
            model_method="cnn",
            prior="Bern",
            dim_inject_y=0,
            pre_tr_weight_path=None,
            feat_extract="vae",
        ):
            """
            :param zd_dim: dimension of the latent code z
            :param d_dim: number of domains/clusters
            :param device: torch device to place the CNN encoder/decoder on
            :param i_c: input channels
            :param i_h: input height
            :param i_w: input width
            :param bs: batch size (stored by convention; not read in this class)
            :param L: number of Monte-Carlo samples (kept for interface parity)
            :param random_batching: whether batches are drawn randomly
            :param model_method: "linear" for MLP encoder/decoder, anything else
                selects the convolutional pair
            :param prior: decoder output distribution family (e.g. "Bern")
            :param dim_inject_y: size of the label/domain vector concatenated to z
                before decoding (0 disables injection)
            :param pre_tr_weight_path: directory holding "encoder.pt"/"decoder.pt"
                to warm-start from; None trains from scratch
            :param feat_extract: feature-extractor tag (stored, not read here)
            """
            super(ModelAE, self).__init__()
            self.zd_dim = zd_dim
            self.d_dim = d_dim
            self.device = device
            self.L = L
            self.loss_epoch = 0
            self.batch_zero = True
            # FIX: original assigned dim_inject_y twice (first to 0, then to the
            # argument); a single assignment is sufficient and less confusing.
            self.dim_inject_y = dim_inject_y
            self.model_method = model_method
            self.prior = prior
            self.feat_extract = feat_extract
            self.random_batching = random_batching
            self.pre_tr_weight_path = pre_tr_weight_path
            self.model = "ae"

            n_z = zd_dim
            n_input = i_c * i_h * i_w
            # Hidden-layer widths for the linear (MLP) encoder/decoder pair.
            n_enc_1, n_enc_2, n_enc_3 = 500, 500, 2000
            n_dec_1, n_dec_2, n_dec_3 = 2000, 500, 500

            # Learnable cluster centroids (d_dim centroids in z-space).
            self.cluster_layer = nn.Parameter(torch.Tensor(self.d_dim, self.zd_dim))
            torch.nn.init.xavier_normal_(self.cluster_layer.data)

            if self.model_method == "linear":
                self.encoder = LinearEncoderAE(n_enc_1, n_enc_2, n_enc_3, n_input, n_z)
                self.decoder = LinearDecoderAE(n_dec_1, n_dec_2, n_dec_3, n_input, n_z)
            else:
                self.encoder = ConvolutionalEncoder(
                    zd_dim=zd_dim, num_channels=i_c, i_w=i_w, i_h=i_h
                ).to(device)
                self.decoder = ConvolutionalDecoder(
                    prior=prior,
                    zd_dim=zd_dim,
                    domain_dim=self.dim_inject_y,
                    h_dim=self.encoder.h_dim,
                    num_channels=i_c,
                ).to(device)

            if self.pre_tr_weight_path:
                # Warm-start both halves from pre-trained checkpoints; map_location
                # keeps CPU-only runs working with GPU-saved weights.
                self.encoder.load_state_dict(
                    torch.load(
                        os.path.join(self.pre_tr_weight_path, "encoder.pt"),
                        map_location=self.device,
                    )
                )
                self.decoder.load_state_dict(
                    torch.load(
                        os.path.join(self.pre_tr_weight_path, "decoder.pt"),
                        map_location=self.device,
                    )
                )
                print("Pre-trained weights loaded")

            self.counter = 0
            self.random_ind = []

        def distance_between_clusters(self, cluster_layer):
            """Return the (d_dim, d_dim) matrix of pairwise Euclidean distances
            between cluster centroids.

            :param cluster_layer: tensor of shape (d_dim, zd_dim)
            """
            # FIX: the original looped over all centroid pairs calling
            # torch.cdist on single points (O(d^2) Python iterations); one
            # vectorized cdist call computes the identical matrix.
            return torch.cdist(cluster_layer, cluster_layer)

        def _inference(self, x, inject_tensor=None):
            """Encode a batch, cluster the latent codes with KMeans, and return
            the tuple expected by the AModelCluster interface.

            :param x: input batch; flattened first when model_method == "linear"
            :param inject_tensor: unused here (kept for interface parity)
            :return: (preds_c, probs_c, z, z_mu, z_sigma2_log,
                      mu_c, log_sigma2_c, pi, logits)
            """
            if self.model_method == "linear":
                x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            enc_h1, enc_h2, enc_h3, z = self.encoder(x)

            z_np = z.detach().cpu().numpy()
            kmeans = KMeans(n_clusters=self.d_dim)
            kmeans.fit(z_np)

            z_mu = torch.mean(z, dim=0)
            z_sigma2_log = torch.std(z, dim=0)
            pi = torch.Tensor([0])

            predictions = kmeans.labels_
            preds_c = mk_fun_label2onehot(self.d_dim)(predictions)
            # FIX: the original called kmeans.fit_transform here, which REFITS
            # KMeans a second time — so labels_ (used for preds_c above) and
            # these distance "logits" could come from two different clusterings.
            # transform() reuses the single fit above, keeping them consistent.
            # NOTE(review): these logits are cluster *distances* (smaller means
            # closer); verify logit2preds_vpic treats them with the intended sign.
            logits = torch.Tensor(kmeans.transform(z_np)).to(self.device)
            _, probs_c, *_ = logit2preds_vpic(logits)

            # z_mu/z_sigma2_log are returned twice to match the VaDE-style
            # 9-tuple interface (per-batch stats double as cluster stats here).
            return preds_c, probs_c, z, z_mu, z_sigma2_log, z_mu, z_sigma2_log, pi, logits

        def infer_d_v_2(self, x, inject_domain):
            """
            Used for tensorboard visualizations only.
            """
            results = self._inference(x)
            if len(inject_domain) > 0:
                zy = torch.cat((results[2], inject_domain), 1)
            else:
                zy = results[2]
            x_pro = self.decoder(zy)

            preds, probs, z, z_mu, z_sigma2_log, mu_c, log_sigma2_c, pi, logits = (
                r.cpu().detach() for r in results
            )
            return preds, z_mu, z, log_sigma2_c, probs, x_pro

        def cal_loss(self, x, inject_domain, warmup_beta=None):
            """Delegate to the parent's pre-training (reconstruction) loss.

            :param warmup_beta: unused for the plain AE (kept for interface parity)
            """
            return self._cal_pretrain_loss(x, inject_domain)

    return ModelAE
def test_fun(d_dim, zd_dim, device):
    """Smoke-test: build a ModelAE, run a forward pass and the loss on a tiny batch.

    :param d_dim: number of domains/clusters
    :param zd_dim: latent dimension
    :param device: ignored; the test always runs on CPU
    """
    import numpy as np

    device = torch.device("cpu")
    # FIX: the original instantiated `ModelAE` directly, but that class only
    # exists inside mk_ae(); it also omitted the required i_c/i_h/i_w/bs args.
    model = mk_ae()(
        zd_dim=zd_dim,
        d_dim=d_dim,
        device=device,
        i_c=3,
        i_h=28,
        i_w=28,
        bs=2,
    )
    x = torch.rand(2, 3, 28, 28)

    # Two one-hot label rows used as the injected conditioning tensor.
    a = np.zeros((2, 10))
    a = np.double(a)
    a[0, 1] = 1.0
    a[1, 8] = 1.0
    y = torch.tensor(a, dtype=torch.float)

    model(x, y)
    # FIX: cal_loss requires an inject_domain argument; pass an empty list
    # (no injection), matching the `len(inject_domain) > 0` convention above.
    model.cal_loss(x, [])