Source code for domid.compos.VAE_blocks
import numpy as np
import torch
import torch.nn as nn
def get_output_shape(model, image_dim):
    """Infer the output shape of `model` by passing a random tensor of shape `image_dim` through it."""
    return model(torch.rand(*(image_dim))).data.shape
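As a quick, illustrative sanity check (the layer and input size below are arbitrary choices, not part of the module), the helper reports a layer's output shape for a dummy batch:

    conv = nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1)
    print(get_output_shape(conv, (1, 3, 32, 32)))  # torch.Size([1, 8, 16, 16])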
def cnn_encoding_block(in_c, out_c, kernel_size=(4, 4), stride=2, padding=1):
    """Convolutional encoder block: Conv2d -> BatchNorm2d -> LeakyReLU."""
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size, stride, padding),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(),  # default negative_slope=0.01
    ]
    return layers
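For illustration only (channel counts and input size are made up, reusing the imports above), a small encoder can be assembled by unpacking a few blocks into nn.Sequential; with the default kernel, stride, and padding each block halves the spatial resolution:

    enc = nn.Sequential(*cnn_encoding_block(3, 32), *cnn_encoding_block(32, 64))
    x = torch.rand(4, 3, 64, 64)
    print(enc(x).shape)  # torch.Size([4, 64, 16, 16])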
def cnn_decoding_block(in_c, out_c, kernel_size=(3, 3), stride=2, padding=1):
    """Transposed-convolutional decoder block: ConvTranspose2d -> BatchNorm2d -> LeakyReLU."""
    layers = [nn.ConvTranspose2d(in_c, out_c, kernel_size, stride, padding), nn.BatchNorm2d(out_c), nn.LeakyReLU()]
    return layers
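A matching sketch for the decoding side (again with invented shapes): each transposed-convolution block roughly doubles the spatial size, here 8 -> 15 -> 29 because no output_padding is used:

    dec = nn.Sequential(*cnn_decoding_block(64, 32), *cnn_decoding_block(32, 3))
    z = torch.rand(4, 64, 8, 8)
    print(dec(z).shape)  # torch.Size([4, 3, 29, 29])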
class UnFlatten(nn.Module):
    """Reshape a flat (batch, channels * N * N) tensor back into (batch, channels, N, N) feature maps."""

    def __init__(self, num_channels):
        super(UnFlatten, self).__init__()
        self.num_channels = num_channels

    def forward(self, input):
        filter_size = self.num_channels
        # assumes square feature maps, i.e. input.shape[1] / num_channels is a perfect square
        N = int(np.sqrt(input.shape[1] / filter_size))
        return input.view(input.size(0), filter_size, N, N)
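As a usage sketch (the channel count, spatial size, and batch size are invented for the example), UnFlatten turns a flat code back into square feature maps:

    unflatten = UnFlatten(num_channels=16)
    flat = torch.rand(2, 16 * 7 * 7)  # (batch, channels * N * N)
    print(unflatten(flat).shape)      # torch.Size([2, 16, 7, 7])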
def linear_block(in_c, out_c):
    """Fully connected block: Linear -> ReLU (in-place)."""
    layers = [nn.Linear(in_c, out_c), nn.ReLU(True)]
    return layers
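A minimal, illustrative use of linear_block (dimensions chosen arbitrarily) stacks two blocks into a small MLP:

    mlp = nn.Sequential(*linear_block(784, 256), *linear_block(256, 64))
    print(mlp(torch.rand(8, 784)).shape)  # torch.Size([8, 64])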