import torch.nn as nn
from domid.compos.VAE_blocks import UnFlatten, get_output_shape
class ConvolutionalEncoder(nn.Module):
    """
    Convolutional VAE encoder.

    A stack of (Conv2d -> BatchNorm2d -> LeakyReLU) stages followed by a flatten,
    with two parallel linear heads producing the mean and the log-variance of
    the latent code.
    """

    def __init__(self, zd_dim, num_channels=3, num_filters=(32, 64, 128), i_w=28, i_h=28, k=(3, 3, 3)):
        """
        VAE Encoder

        :param zd_dim: dimension of the latent space
        :param num_channels: number of channels of the input
        :param num_filters: sequence with the number of filters for each convolutional layer
        :param i_w: width of the input
        :param i_h: height of the input
        :param k: sequence of kernel sizes, one per convolutional layer; it is indexed by layer
            position, so it must have at least ``len(num_filters)`` entries (extra entries are ignored)
        """
        super().__init__()

        # Channel counts at every layer boundary: input channels first, then one entry per conv layer.
        # Building a fresh list leaves the caller's sequence (and the default) untouched and accepts
        # tuples as well as lists.
        channels = [num_channels, *num_filters]

        modules = []
        for i, (in_channels, out_channels) in enumerate(zip(channels[:-1], channels[1:])):
            # stride=2 means every stage downsamples the spatial dimensions.
            modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=k[i], stride=2, padding=1))
            modules.append(nn.BatchNorm2d(out_channels))
            modules.append(nn.LeakyReLU())
        modules.append(nn.Flatten())
        self.encod = nn.Sequential(*modules)

        # Size of the flattened feature vector that feeds the two linear heads.
        # NOTE(review): get_output_shape is a project helper; it is assumed to return the output
        # shape of `self.encod` for an input of the given shape - verify against its definition.
        # The probe shape lists i_w before i_h although Conv2d reads the last two dims as (H, W);
        # with integer (square) kernel sizes the flattened size is the same for either ordering.
        self.h_dim = get_output_shape(self.encod, (1, num_channels, i_w, i_h))[1]

        self.mu_layer = nn.Linear(self.h_dim, zd_dim)
        self.log_sigma2_layer = nn.Linear(self.h_dim, zd_dim)

    def get_z(self, x):
        """
        Return only the latent mean for ``x`` (runs the full forward pass).

        :param x: input data
        :return: the ``mu`` output of :meth:`forward`
        """
        mu, _ = self.forward(x)
        return mu

    def get_log_sigma2(self, x):
        """
        Return only the latent log-variance for ``x`` (runs the full forward pass).

        :param x: input data
        :return: the ``log_sigma2`` output of :meth:`forward`
        """
        _, log_sigma2 = self.forward(x)
        return log_sigma2

    def forward(self, x):
        """
        Encode ``x`` into the parameters of the latent distribution.

        :param x: input data
        :return mu: output of the mean head, with ``zd_dim`` features per sample
        :return log_sigma2: output of the log-variance head, with ``zd_dim`` features per sample
        """
        z = self.encod(x)
        mu = self.mu_layer(z)
        log_sigma2 = self.log_sigma2_layer(z)
        return mu, log_sigma2
class ConvolutionalDecoder(nn.Module):
    """
    Convolutional VAE decoder.

    A linear layer followed by an unflatten step and a stack of transposed
    convolutions. The final layer emits ``2 * num_channels`` channels, which
    :meth:`forward` splits into two tensors of ``num_channels`` channels each.
    """

    def __init__(self, prior, zd_dim, domain_dim, h_dim, num_channels=3, num_filters=(32, 64, 128), k=(4, 4, 4)):
        """
        VAE Decoder

        :param prior: prior/likelihood selector; when equal to ``"Bern"`` a sigmoid is applied to the
            reconstruction in :meth:`forward`, any other value leaves it unbounded
        :param zd_dim: dimension of the latent space, which is the input space of the decoder
        :param domain_dim: number of extra input features; the first linear layer takes
            ``zd_dim + domain_dim`` input features
        :param h_dim: dimension of the first hidden layer, which is a linear layer
        :param num_channels: number of channels of the output; the output will have twice as many channels,
            e.g., 3 channels for the mean and 3 channels for log-sigma if num_channels is 3
        :param num_filters: sequence with the number of filters for each convolutional layer,
            given in *reverse* order
        :param k: sequence of kernel sizes for each convolutional layer; the hidden layers use
            ``k[0], k[1], ...`` in order and the final output layer uses ``k[-1]``
        """
        # FIXME kernel size as an input from command line? different for HER2 and MNIST
        super().__init__()
        self.prior = prior
        self.num_channels = num_channels
        self.linear = nn.Linear(zd_dim + domain_dim, h_dim)
        self.sigmoid_layer = nn.Sigmoid()
        # NOTE(review): UnFlatten is a project helper; it presumably reshapes the flat h_dim vector
        # into a feature map with num_filters[-1] channels - verify against its definition.
        self.unflat = UnFlatten(num_filters[-1])

        # Channel counts in decoding order: widest feature map first, output channels last.
        # A fresh list is built so the caller's sequence (and the default) is never mutated.
        channels = [num_channels, *num_filters][::-1]

        modules = []
        # All transitions except the last one; the last is handled separately because it
        # produces 2 * num_channels channels and has no normalisation/activation after it.
        for i in range(len(channels) - 2):
            modules.append(
                nn.ConvTranspose2d(channels[i], channels[i + 1], kernel_size=k[i], stride=2, padding=1)
            )
            modules.append(nn.BatchNorm2d(channels[i + 1]))
            modules.append(nn.LeakyReLU())
        modules.append(nn.ConvTranspose2d(channels[-2], num_channels * 2, kernel_size=k[-1], stride=2, padding=1))
        self.decod = nn.Sequential(*modules)

    def forward(self, z):
        """
        Decode a latent representation.

        :param z: latent space representation; the first linear layer expects
            ``zd_dim + domain_dim`` features per sample
        :return x_pro: reconstructed data, i.e. the first ``num_channels`` channels of the decoder
            output, passed through a sigmoid when ``prior == "Bern"``
        :return log_sigma: the remaining channels of the decoder output
        """
        z = self.linear(z)
        z = self.unflat(z)
        x_decoded = self.decod(z)

        # First half of the channels is the reconstruction; only the "Bern" prior squashes it to (0, 1).
        x_pro = x_decoded[:, : self.num_channels, :, :]
        if self.prior == "Bern":
            x_pro = self.sigmoid_layer(x_pro)

        # NOTE(review): earlier documentation called this the log-variance of the reconstruction
        # while the variable is named log_sigma - confirm which convention the loss expects.
        log_sigma = x_decoded[:, self.num_channels :, :, :]
        return x_pro, log_sigma