Source code for domid.compos.cnn_AE

import torch
import torch.nn as nn
import torch.nn.functional as F

from domid.compos.VAE_blocks import UnFlatten, cnn_decoding_block, cnn_encoding_block, get_output_shape


class ConvolutionalEncoder(nn.Module):
    def __init__(self, zd_dim, num_channels=3, num_filters=[32, 64, 128], i_w=28, i_h=28, k=[3, 3, 3]):
        """AE Encoder

        :param zd_dim: dimension of the latent space
        :param num_channels: number of channels of the input
        :param num_filters: list of number of filters for each convolutional layer
        :param i_w: width of the input
        :param i_h: height of the input
        :param k: list of kernel sizes for each convolutional layer
        """
        super(ConvolutionalEncoder, self).__init__()
        self.num_filters = num_filters
        self.conv_block1 = nn.Conv2d(num_channels, num_filters[0], k[0], stride=2, padding=1)
        self.conv_block2 = nn.Conv2d(num_filters[0], num_filters[1], k[1], stride=2, padding=1)
        self.conv_block3 = nn.Conv2d(num_filters[1], num_filters[2], k[2], stride=2, padding=1)
        modules = [self.conv_block1, self.conv_block2, self.conv_block3]
        # self.encod is only used to infer the flattened feature size below;
        # the actual forward pass interleaves batch norm and ReLU between the convolutions.
        self.encod = nn.Sequential(*modules)
        hidden_output = get_output_shape(self.encod, (1, num_channels, i_w, i_h))
        self.h_dim = hidden_output[1] * hidden_output[2] * hidden_output[3]
        self.z_layer = nn.Linear(self.h_dim, zd_dim)
        self.bsnorm1 = nn.BatchNorm2d(num_filters[0])
        self.bsnorm2 = nn.BatchNorm2d(num_filters[1])
        self.bsnorm3 = nn.BatchNorm2d(num_filters[2])

    def get_z(self, x):
        *_, z = self.forward(x)
        return z

    def get_log_sigma2(self, x):
        # The plain AE encoder has no log-variance head, so there is nothing to return.
        return None

    def forward(self, x):
        """
        :param x: input data
        """
        enc_h1 = self.conv_block1(x)
        enc_h1 = F.relu(self.bsnorm1(enc_h1))
        enc_h2 = self.conv_block2(enc_h1)
        enc_h2 = F.relu(self.bsnorm2(enc_h2))
        enc_h3 = self.conv_block3(enc_h2)
        enc_h3 = F.relu(self.bsnorm3(enc_h3))
        z = self.z_layer(torch.flatten(enc_h3, 1, -1))
        return enc_h1, enc_h2, enc_h3, z
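# Usage sketch (illustrative only, not part of the original module): encode a
# batch of MNIST-sized inputs with the default settings. The three stride-2,
# padding-1 convolutions shrink the spatial size 28 -> 14 -> 7 -> 4, so the
# flattened feature size is h_dim = 128 * 4 * 4 = 2048.
def _demo_encoder():
    encoder = ConvolutionalEncoder(zd_dim=10, num_channels=3, i_w=28, i_h=28)
    x = torch.randn(8, 3, 28, 28)  # batch of 8 three-channel 28x28 images
    enc_h1, enc_h2, enc_h3, z = encoder(x)  # per-layer feature maps and latent code
    assert z.shape == (8, 10)  # latent codes of dimension zd_dim
    assert encoder.get_z(x).shape == (8, 10)  # get_z returns only the latent code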
class ConvolutionalDecoder(nn.Module):
    def __init__(self, prior, zd_dim, domain_dim, h_dim, num_channels=3, num_filters=[32, 64, 128], k=[4, 4, 4]):
        """AE Decoder

        :param prior: output distribution; if "Bern", a sigmoid is applied to the reconstruction
        :param zd_dim: dimension of the latent space, which is the input space of the decoder
        :param domain_dim: dimension of the domain representation concatenated to the latent code, if any
        :param h_dim: dimension of the first hidden layer, which is a linear layer
        :param num_channels: number of channels of the output
        :param num_filters: list of number of filters for each convolutional layer, given in *reverse* order
        :param k: list of kernel sizes for each convolutional layer
        """
        # FIXME: take the kernel size as a command-line input? It differs between HER2 and MNIST.
        super(ConvolutionalDecoder, self).__init__()
        self.prior = prior
        self.num_channels = num_channels
        self.linear = nn.Linear(zd_dim + domain_dim, h_dim)
        self.sigmoid_layer = nn.Sigmoid()
        self.unflat = UnFlatten(num_filters[-1])
        # Prepend the output channels and reverse, so the decoder mirrors the encoder.
        num_filters = [num_channels] + num_filters
        num_filters.reverse()
        modules = []
        for i in range(len(num_filters) - 2):
            modules.append(
                nn.ConvTranspose2d(num_filters[i], num_filters[i + 1], kernel_size=k[i], stride=2, padding=1)
            )
            modules.append(nn.BatchNorm2d(num_filters[i + 1]))
            modules.append(nn.LeakyReLU())
        modules.append(nn.ConvTranspose2d(num_filters[-2], num_channels, kernel_size=k[-1], stride=2, padding=1))
        self.decod = nn.Sequential(*modules)

    def forward(self, z):
        """
        :param z: latent space representation (concatenated with a domain representation, if domain_dim > 0)
        :return x_pro: reconstructed data with num_channels channels, squashed through a sigmoid if the prior is Bernoulli
        """
        z = self.linear(z)
        z = self.unflat(z)
        x_decoded = self.decod(z)
        if self.prior == "Bern":
            x_pro = self.sigmoid_layer(x_decoded[:, 0 : self.num_channels, :, :])
        else:
            x_pro = x_decoded[:, 0 : self.num_channels, :, :]
        return x_pro
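# Usage sketch (illustrative only, not part of the original module): decode latent
# codes back to images, mirroring the encoder sketch above. h_dim must match the
# flattened encoder output (128 * 4 * 4 = 2048 for 28x28 inputs, assuming UnFlatten
# restores the (128, 4, 4) feature maps). Each kernel-4, stride-2, padding-1
# transposed convolution exactly doubles the spatial size: 4 -> 8 -> 16 -> 32.
def _demo_decoder():
    decoder = ConvolutionalDecoder(prior="Bern", zd_dim=10, domain_dim=0, h_dim=2048, num_channels=3)
    z = torch.randn(8, 10)  # e.g., the output of ConvolutionalEncoder.get_z
    x_pro = decoder(z)
    assert x_pro.shape == (8, 3, 32, 32)  # sigmoid-squashed since prior == "Bern"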