Source code for domid.models.model_m2yd

import torch
import torch.distributions as dist
import torch.nn.functional as F
from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import (
    DecoderConcatLatentFCReshapeConvGatedConv,
)
from domainlab.compos.vae.compos.encoder import LSEncoderLinear as LSEncoderDense
from domainlab.models.a_model_classif import AModelClassif
from domainlab.utils.utils_class import store_args
from domainlab.utils.utils_classif import get_label_na, logit2preds_vpic

from domid.compos.nn_net import Net_MNIST
from domid.utils.perf_cluster import PerfCluster


def mk_m2yd(parent_class=AModelClassif):
    class ModelXY2D(AModelClassif):
        """
        Let zd be a continuous vector; each component of zd represents an "attention"
        weight. For a cluster, that means the bigger the $zd_k$ value, the more likely
        the assignment to cluster component $k$. Note $zd_k ~ N(0,1)$.

        Computational structure:

        generative path: (zd, y) -> x (image)
        generative path: N(0, I) -> zd (prior for zd)

        variational posterior path:
        1. x -> y,
        2. [y, feat(x)] -> z_d

        FIXME: if we change the variational inference order, i.e. instead of x->y then
        [y, feat(x)]->z_d, we first do x->d (style extraction, texture prediction) and
        then [feat(x), d]->y, will this be better?

        KL divergence between the posterior path and the generative (prior) path:
        1. x->y: auxiliary path (not regularized by the generative path, but by
           supervised learning), so no KL term for y
        2. KL(q(z_d)|p(z_d)), with q(z_d): [y, feat(x)]->z_d and p(z_d): N(0, I)

        Auxiliary path: supervised learning of x->y

        FIXME: the original M2 has prior Gaussian(0, I) for $z_d$; will this hinder
        learning of $z_d$ on each component, since the prior drags each component
        towards zero?
        """

        @store_args
        def __init__(self, list_str_y, y_dim, zd_dim, gamma_y, device, i_c, i_h, i_w, dim_feat_x=10):
            """
            :param y_dim: classification task class-label dimension
            :param zd_dim: dimension of latent variable $z_d$
            :param gamma_y: weight of the auxiliary supervised loss on x->y
            """
            super().__init__(list_str_y=list_str_y, net_classifier=None)
            self.infer_y_from_x = Net_MNIST(y_dim, self.i_h)
            self._net_classifier = self.infer_y_from_x  # FIXME: this is a hack, and may not be needed at all
            self.d_dim = zd_dim  # number of domains
            self.feat_x2concat_y = Net_MNIST(self.dim_feat_x, self.i_h)
            # FIXME: shall we share parameters between infer_y_from_x and self.feat_x2concat_y?
            self.infer_domain = LSEncoderDense(z_dim=self.zd_dim, dim_input=self.dim_feat_x + self.y_dim)
            self.gamma_y = gamma_y
            # LN: location scale encoder
            self.decoder = DecoderConcatLatentFCReshapeConvGatedConv(
                z_dim=zd_dim + y_dim, i_c=self.i_c, i_w=self.i_w, i_h=self.i_h
            )

        def cal_logit_y(self, tensor_x):
            """
            Calculate the logits for softmax classification.
            """
            return self.infer_y_from_x(tensor_x)

        def infer_y_vpicn(self, tensor_x):
            """
            :param tensor_x: input tensor
            :return:
                - vec_one_hot - (list) one-hot encoded classification output
                - prob - (list) softmax probabilities per class
                - ind - (int) index of the class with the maximal score
                - confidence - (float) maximum probability (already included in prob)
                - na_class - (string) class label of the maximum-probability class
            """
            with torch.no_grad():
                logit_y = self.infer_y_from_x(tensor_x)
            vec_one_hot, prob, ind, confidence = logit2preds_vpic(logit_y)
            na_class = get_label_na(ind, self.list_str_y)
            return vec_one_hot, prob, ind, confidence, na_class

        def infer_d_v(self, tensor_x):
            with torch.no_grad():
                y_hat_logit = self.infer_y_from_x(tensor_x)
                feat_x = self.feat_x2concat_y(tensor_x)
                feat_y_x = torch.cat((y_hat_logit, feat_x), dim=1)
                q_zd, zd_q = self.infer_domain(feat_y_x)
            vec_one_hot, *_ = logit2preds_vpic(q_zd.mean)
            return vec_one_hot

        def forward(self, tensor_x, vec_y, vec_d=None):
            y_hat_logit = self.infer_y_from_x(tensor_x)
            feat_x = self.feat_x2concat_y(tensor_x)
            feat_y_x = torch.cat((y_hat_logit, feat_x), dim=1)
            q_zd, zd_q = self.infer_domain(feat_y_x)
            return q_zd, zd_q, y_hat_logit

        def cal_loss(self, x, y, d=None, others=None):
            q_zd, zd_q, y_hat = self.forward(x, y)
            z_con = torch.cat((zd_q, y), dim=1)  # FIXME: pay attention to the order
            nll, x_mean, x_logvar = self.decoder(z_con, x)
            pzd_loc = torch.zeros(1, self.zd_dim).to(self.device)
            pzd_scale = torch.ones(self.zd_dim).to(self.device)
            pzd = dist.Normal(pzd_loc, pzd_scale)
            pzd_log_prob = pzd.log_prob(zd_q)
            qzd_log_prob = q_zd.log_prob(zd_q)
            zd_p_minus_zd_q = torch.sum(pzd_log_prob - qzd_log_prob, dim=1)
            # FIXME: use the analytical expression, since estimating the KL divergence by sampling has high variance
            _, y_target = y.max(dim=1)  # y is the observed class label, not the cluster label!
            lc_y = F.cross_entropy(y_hat, y_target, reduction="none")
            loss = nll - zd_p_minus_zd_q + self.gamma_y * lc_y
            return loss.mean()

        def cal_perf_metric(self, loader_tr, device, loader_te=None):
            """
            Clustering performance metric on the training and test/validation sets.
            """
            # FIXME: this is a hack, because self.perf_metric is actually defined somewhere else (not in this file)
            self.perf_metric = PerfCluster(self.d_dim)
            metric_te = None
            metric_tr = None
            with torch.no_grad():
                metric_tr = self.perf_metric.cal_acc(self, loader_tr, device)
                if loader_te is not None:
                    metric_te = self.perf_metric.cal_acc(self, loader_te, device)
            # FIXME: the None values are a hack due to the clusteringOnly observer used with M2YD;
            # eventually, the M2YD model needs a specialized observer (probably the clustering
            # observer, which needs refactoring, as opposed to the clusteringOnly one)
            return metric_tr, metric_te, None, None

    return ModelXY2D
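

# The FIXME in cal_loss above flags that the single-sample estimate of
# KL(q(z_d) || p(z_d)) has high variance. Below is a minimal sketch of the analytical
# alternative; it is an illustration only and assumes (without verifying here) that
# LSEncoderDense returns a diagonal torch.distributions.Normal as q_zd, so the
# closed-form Normal-Normal KL provided by torch.distributions.kl_divergence applies.
def _sketch_analytical_kl(q_zd, zd_dim, device):
    """Hypothetical helper (not part of the model above): closed-form KL(q(z_d) || N(0, I))."""
    pzd = dist.Normal(torch.zeros(zd_dim).to(device), torch.ones(zd_dim).to(device))
    kl = dist.kl_divergence(q_zd, pzd)  # shape (batch, zd_dim): one term per latent component
    return kl.sum(dim=1)  # per-sample KL; in cal_loss this would replace -zd_p_minus_zd_q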


def test_fun():
    # Build the model through the mk_m2yd factory (ModelXY2D itself is only defined
    # inside mk_m2yd); list_str_y is required by AModelClassif.
    model = mk_m2yd()(
        list_str_y=[str(i) for i in range(10)],
        y_dim=10,
        zd_dim=8,
        gamma_y=3500,
        device=torch.device("cpu"),
        i_c=3,
        i_h=28,
        i_w=28,
    )
    x = torch.rand(2, 3, 28, 28)
    import numpy as np

    a = np.zeros((2, 10))
    a = np.double(a)
    a[0, 1] = 1.0
    a[1, 8] = 1.0
    y = torch.tensor(a, dtype=torch.float)
    model(x, y)
    model.cal_loss(x, y)
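

# Hedged usage sketch of cluster inference: as the class docstring explains, the largest
# component of z_d acts as the cluster assignment, and infer_d_v realizes this by taking
# the one-hot argmax (via logit2preds_vpic) of q_zd.mean. The concrete dimensions below
# (3x28x28 inputs, ten class labels "0".."9", eight clusters) are illustrative
# assumptions mirroring test_fun, not values prescribed by the module.
def _sketch_cluster_inference():
    model = mk_m2yd()(
        list_str_y=[str(i) for i in range(10)],
        y_dim=10,
        zd_dim=8,
        gamma_y=3500,
        device=torch.device("cpu"),
        i_c=3,
        i_h=28,
        i_w=28,
    )
    x = torch.rand(4, 3, 28, 28)
    vec_one_hot = model.infer_d_v(x)  # one-hot over the zd_dim=8 cluster components
    return vec_one_hot.argmax(dim=1)  # per-image cluster index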