Source code for domid.compos.linear_VAE

import numpy as np
import torch
import torch.nn as nn

from domid.compos.VAE_blocks import linear_block


class LinearEncoder(nn.Module):
    """
    VAE encoder built from fully connected layers.

    The input is flattened per sample, passed through three hidden blocks and
    projected by two separate linear heads to the mean and the log-variance of
    the latent distribution.
    """

    def __init__(self, zd_dim, input_dim=(3, 28, 28), features_dim=(500, 500, 2000)):
        """
        VAE Encoder with linear layers

        :param zd_dim: dimension of the latent space
        :param input_dim: dimensions of the input, e.g., (3, 28, 28) for MNIST in RGB format
        :param features_dim: sequence with the dimensions of the three hidden layers
        """
        super().__init__()
        # np.prod returns a NumPy scalar; keep a plain Python int so that the
        # layer sizes and the shape check in forward() use builtin integers.
        self.input_dim = int(np.prod(input_dim))
        # linear_block is imported from domid.compos.VAE_blocks; its result is
        # unpacked into nn.Sequential, so it is expected to yield nn.Module objects.
        self.encod = nn.Sequential(
            *linear_block(self.input_dim, features_dim[0]),
            *linear_block(features_dim[0], features_dim[1]),
            *linear_block(features_dim[1], features_dim[2]),
        )
        self.mu_layer = nn.Linear(features_dim[2], zd_dim)
        self.log_sigma2_layer = nn.Linear(features_dim[2], zd_dim)

    def get_z(self, x):
        """
        :param x: input data, see :meth:`forward`
        :return: the first output of :meth:`forward` (``mu``)
        """
        mu, _ = self.forward(x)
        return mu

    def get_log_sigma2(self, x):
        """
        :param x: input data, see :meth:`forward`
        :return: the second output of :meth:`forward` (``log_sigma2``)
        """
        _, log_sigma2 = self.forward(x)
        return log_sigma2

    def forward(self, x):
        """
        :param x: input data with the batch on the first axis; all remaining
            axes are flattened and must contain ``input_dim`` values per sample
            (e.g. ``(batch, 3, 28, 28)`` for the default ``input_dim``)
        :return: tuple ``(mu, log_sigma2)``, the outputs of the two linear heads,
            each with ``zd_dim`` values per sample
        :raises ValueError: if the flattened sample size differs from the
            configured ``input_dim``
        """
        # Flatten everything but the batch axis instead of hard-coding three
        # channels, so the encoder works for whatever input_dim it was built with.
        flat = torch.flatten(x, start_dim=1)
        if flat.shape[1] != self.input_dim:
            raise ValueError(
                f"expected {self.input_dim} features per sample after flattening, "
                f"got {flat.shape[1]} from input of shape {tuple(x.shape)}"
            )
        hidden = self.encod(flat)
        mu = self.mu_layer(hidden)
        log_sigma2 = self.log_sigma2_layer(hidden)
        return mu, log_sigma2
class LinearDecoder(nn.Module):
    """
    VAE decoder built from fully connected layers.

    A latent vector is passed through three hidden blocks and projected by two
    separate linear heads whose outputs are reshaped to the original input
    dimensions.
    """

    def __init__(self, prior, zd_dim, input_dim=(3, 28, 28), features_dim=(500, 500, 2000)):
        """
        VAE Decoder

        :param prior: if equal to "Bern", a sigmoid is applied to the output of
            the mean head in :meth:`forward`; for any other value the mean head
            output is returned without activation
        :param zd_dim: dimension of the latent space
        :param input_dim: dimension of the original input / output reconstruction,
            e.g., (3, 28, 28) for MNIST in RGB format
        :param features_dim: sequence with the dimensions of the hidden layers, given in
            reverse order (the decoder goes from features_dim[2] down to features_dim[0])
        """
        # Initialise nn.Module first: attribute assignment on a module is only
        # guaranteed to work for every kind of value after its __init__ has run.
        super().__init__()
        self.prior = prior
        self.input_dim = input_dim
        # np.prod returns a NumPy scalar; use a plain Python int for the layer sizes.
        output_size = int(np.prod(input_dim))
        # linear_block is imported from domid.compos.VAE_blocks; its result is
        # unpacked into nn.Sequential, so it is expected to yield nn.Module objects.
        self.decod = nn.Sequential(
            *linear_block(zd_dim, features_dim[2]),
            *linear_block(features_dim[2], features_dim[1]),
            *linear_block(features_dim[1], features_dim[0]),
        )
        self.mu_layer = nn.Linear(features_dim[0], output_size)
        self.log_sigma_layer = nn.Linear(features_dim[0], output_size)
        self.activation = nn.Sigmoid()

    def forward(self, z):
        """
        :param z: latent space representation, with the batch on the first axis
        :return: tuple ``(x_pro, log_sigma)``; both are reshaped to
            ``(batch, *input_dim)``. ``x_pro`` is the reconstructed data (passed
            through a sigmoid when ``prior == "Bern"``) and ``log_sigma`` is the
            output of the second linear head, which is computed for every prior.
        """
        hidden = self.decod(z)
        x_pro = self.mu_layer(hidden)
        if self.prior == "Bern":
            # if Bernoulli distribution sigmoid activation to mu is applied
            x_pro = self.activation(x_pro)
        log_sigma = self.log_sigma_layer(hidden)

        batch_size = x_pro.shape[0]
        x_pro = torch.reshape(x_pro, (batch_size, *self.input_dim))
        log_sigma = torch.reshape(log_sigma, (batch_size, *self.input_dim))
        return x_pro, log_sigma