import torch
import torch.nn as nn
import torch.nn.functional as F
from domid.compos.VAE_blocks import UnFlatten, cnn_decoding_block, cnn_encoding_block, get_output_shape
class ConvolutionalEncoder(nn.Module):
    """
    Convolutional auto-encoder encoder.

    Three stride-2 ``Conv2d`` layers, each followed by batch normalisation and
    ReLU, and a final linear layer that maps the flattened feature map to the
    latent vector.
    """

    def __init__(self, zd_dim, num_channels=3, num_filters=(32, 64, 128), i_w=28, i_h=28, k=(3, 3, 3)):
        """
        AE Encoder

        :param zd_dim: dimension of the latent space
        :param num_channels: number of channels of the input
        :param num_filters: sequence with the number of filters for each convolutional layer;
            only the first three entries are used
        :param i_w: width of the input
        :param i_h: height of the input
        :param k: sequence of kernel sizes for each convolutional layer;
            only the first three entries are used
        """
        super().__init__()
        # Store a private copy so that mutating ``self.num_filters`` can never
        # leak into the caller's sequence or into the shared default value.
        self.num_filters = list(num_filters)

        self.conv_block1 = nn.Conv2d(num_channels, num_filters[0], k[0], stride=2, padding=1)
        self.conv_block2 = nn.Conv2d(num_filters[0], num_filters[1], k[1], stride=2, padding=1)
        self.conv_block3 = nn.Conv2d(num_filters[1], num_filters[2], k[2], stride=2, padding=1)

        # The same three conv modules are additionally registered as ``self.encod``
        # (convolutions only, no batch norm / ReLU); it is used here to determine
        # the shape of the last feature map.
        self.encod = nn.Sequential(self.conv_block1, self.conv_block2, self.conv_block3)

        # NOTE(review): PyTorch conv layers expect (N, C, H, W) while width and height
        # are passed as (i_w, i_h) here. Only the product of the output dimensions is
        # used below, so with integer (square) kernel sizes the swapped order should
        # not change ``h_dim`` -- confirm before using non-square kernels.
        # ``get_output_shape`` is a project helper; presumably it pushes a dummy
        # tensor of the given shape through ``self.encod`` -- verify in VAE_blocks.
        hidden_output = get_output_shape(self.encod, (1, num_channels, i_w, i_h))
        self.h_dim = hidden_output[1] * hidden_output[2] * hidden_output[3]

        self.z_layer = nn.Linear(self.h_dim, zd_dim)
        self.bsnorm1 = nn.BatchNorm2d(num_filters[0])
        self.bsnorm2 = nn.BatchNorm2d(num_filters[1])
        self.bsnorm3 = nn.BatchNorm2d(num_filters[2])

    def get_z(self, x):
        """
        Encode ``x`` and return only the latent representation.

        :param x: input data
        :return: the latent vector ``z``, i.e. the last element returned by :meth:`forward`
        """
        return self.forward(x)[-1]

    def get_log_sigma2(self, x):
        """
        Return the log-variance of the latent distribution.

        This encoder has no variance head, so the result is always ``None``.

        :param x: input data (unused)
        :return: ``None``
        """
        return None

    def forward(self, x):
        """
        :param x: input data
        :return: tuple ``(enc_h1, enc_h2, enc_h3, z)`` with the activations after each
            conv + batch-norm + ReLU stage and the latent vector computed from the
            flattened last activation
        """
        enc_h1 = F.relu(self.bsnorm1(self.conv_block1(x)))
        enc_h2 = F.relu(self.bsnorm2(self.conv_block2(enc_h1)))
        enc_h3 = F.relu(self.bsnorm3(self.conv_block3(enc_h2)))
        # Flatten everything except the batch dimension before the linear layer.
        z = self.z_layer(torch.flatten(enc_h3, 1, -1))
        return enc_h1, enc_h2, enc_h3, z
class ConvolutionalDecoder(nn.Module):
    """
    Convolutional auto-encoder decoder.

    A linear layer followed by a stack of stride-2 transposed convolutions that
    maps a latent vector (concatenated with the domain vector) back to an image.
    """

    def __init__(self, prior, zd_dim, domain_dim, h_dim, num_channels=3, num_filters=(32, 64, 128), k=(4, 4, 4)):
        """
        AE Decoder

        :param prior: name of the prior; when equal to ``"Bern"`` a sigmoid is applied to the
            decoder output in :meth:`forward`, otherwise the raw output is returned
        :param zd_dim: dimension of the latent space
        :param domain_dim: dimension of the domain vector; the decoder input has
            ``zd_dim + domain_dim`` features
        :param h_dim: dimension of the first hidden layer, which is a linear layer
        :param num_channels: number of channels of the output
        :param num_filters: sequence with the number of filters for each convolutional layer, given
            in encoder order; it is traversed in *reverse* here, so the first transposed
            convolution takes ``num_filters[-1]`` input channels
        :param k: sequence of kernel sizes; ``k[i]`` is used for the i-th transposed convolution
            and ``k[-1]`` for the last one
        """
        # TODO: take the kernel sizes from the command line? Different values are
        # needed for HER2 and MNIST.
        super().__init__()
        self.prior = prior
        self.num_channels = num_channels
        self.linear = nn.Linear(zd_dim + domain_dim, h_dim)
        self.sigmoid_layer = nn.Sigmoid()
        # ``UnFlatten`` is a project helper; presumably it reshapes the flat linear
        # output into a feature map with ``num_filters[-1]`` channels -- verify in VAE_blocks.
        self.unflat = UnFlatten(num_filters[-1])

        # Channel sizes from the decoder input to its output, e.g. [128, 64, 32, 3].
        # ``list(...)`` makes a copy, so tuples are accepted and the caller's
        # sequence is never modified.
        channels = [num_channels] + list(num_filters)
        channels.reverse()

        modules = []
        # Every layer except the last one is ConvTranspose2d -> BatchNorm2d -> LeakyReLU.
        for i, (c_in, c_out) in enumerate(zip(channels[:-2], channels[1:-1])):
            modules.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=k[i], stride=2, padding=1))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.LeakyReLU())
        # The last layer maps to the output channels and has no norm / activation.
        modules.append(nn.ConvTranspose2d(channels[-2], num_channels, kernel_size=k[-1], stride=2, padding=1))
        self.decod = nn.Sequential(*modules)

    def forward(self, z):
        """
        :param z: latent space representation with ``zd_dim + domain_dim`` features
        :return x_pro: reconstructed data with ``num_channels`` channels; passed through a
            sigmoid when the prior is ``"Bern"``, returned unchanged otherwise
        """
        hidden = self.unflat(self.linear(z))
        x_decoded = self.decod(hidden)
        # The last layer already produces exactly ``num_channels`` channels; the
        # slice is kept to preserve the original behavior.
        x_pro = x_decoded[:, 0 : self.num_channels, :, :]
        if self.prior == "Bern":
            x_pro = self.sigmoid_layer(x_pro)
        return x_pro