"""Source code for domid.compos.GNN."""
import torch
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from domid.compos.GNN_layer import GNNLayer
class GNN(Module):
    """Five-layer graph neural network that maps node features to cluster logits.

    Each hidden representation is blended with the matching intermediate
    feature of a companion autoencoder before being passed to the next
    graph layer (SDCN-style fusion controlled by ``sigma``).
    """

    def __init__(self, n_input, n_enc_1, n_enc_2, n_enc_3, n_z, n_clusters, device):
        super(GNN, self).__init__()
        # One graph layer per encoder stage; the last layer emits one
        # value per cluster (used for soft cluster assignment).
        self.gnn_1 = GNNLayer(n_input, n_enc_1, device)
        self.gnn_2 = GNNLayer(n_enc_1, n_enc_2, device)
        self.gnn_3 = GNNLayer(n_enc_2, n_enc_3, device)
        self.gnn_4 = GNNLayer(n_enc_3, n_z, device)
        self.gnn_5 = GNNLayer(n_z, n_clusters, device)

    def _flatten_if_needed(self, x):
        # Collapse all dimensions after the batch axis so that image-shaped
        # tensors become (batch, features); 2-D inputs pass through untouched.
        if x.dim() > 2:
            return torch.flatten(x, 1, -1)
        return x

    def forward(self, x, adj, tra1, tra2, tra3, z, sigma=0.5):
        """
        :param x: image batch
        :param adj: adjacency matrix from the constructed graph for the batch of images
        :param tra1: features from the first layer of the encoder
        :param tra2: features from the second layer of the encoder
        :param tra3: features from the third layer of the encoder
        :param z: latent features from the encoder
        :param sigma: blend weight between the GNN state and the encoder feature
        :return: hidden layer that is used for clustering
        """
        x, tra1, tra2, tra3 = [self._flatten_if_needed(t) for t in (x, tra1, tra2, tra3)]
        h = self.gnn_1(x, adj)
        # Blend each GNN state with the corresponding encoder feature,
        # then propagate through the next graph layer.
        for layer, skip in ((self.gnn_2, tra1), (self.gnn_3, tra2), (self.gnn_4, tra3)):
            h = layer((1 - sigma) * h + sigma * skip, adj)
        # Final layer produces cluster logits, so no activation is applied.
        return self.gnn_5((1 - sigma) * h + sigma * z, adj, activation=False)