# Source code for domid.compos.GNN

import torch
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

from domid.compos.GNN_layer import GNNLayer


class GNN(Module):
    """Stack of five ``GNNLayer`` modules blended with encoder features.

    Between consecutive layers the running hidden state is mixed with the
    matching encoder representation (``tra1``, ``tra2``, ``tra3``, ``z``)
    using the weight ``sigma``; see :meth:`forward`.
    """

    def __init__(self, n_input, n_enc_1, n_enc_2, n_enc_3, n_z, n_clusters, device):
        """Create the layers ``gnn_1`` .. ``gnn_5``.

        Consecutive entries of ``(n_input, n_enc_1, n_enc_2, n_enc_3, n_z,
        n_clusters)`` are passed as the first two ``GNNLayer`` arguments,
        presumably (in_features, out_features) -- confirm against GNNLayer.

        :param n_input: first size argument of ``gnn_1``
        :param n_enc_1: size between ``gnn_1`` and ``gnn_2``
        :param n_enc_2: size between ``gnn_2`` and ``gnn_3``
        :param n_enc_3: size between ``gnn_3`` and ``gnn_4``
        :param n_z: size between ``gnn_4`` and ``gnn_5``
        :param n_clusters: second size argument of ``gnn_5``
        :param device: forwarded unchanged to every ``GNNLayer``
        """
        super().__init__()
        sizes = (n_input, n_enc_1, n_enc_2, n_enc_3, n_z, n_clusters)
        size_pairs = zip(sizes[:-1], sizes[1:])
        # Built strictly in order gnn_1 .. gnn_5 and stored under those
        # attribute names, which forward() looks up directly.
        for index, (size_in, size_out) in enumerate(size_pairs, start=1):
            setattr(self, f"gnn_{index}", GNNLayer(size_in, size_out, device))

    def _flatten_if_needed(self, x):
        """Merge every dimension after the first; inputs with <= 2 dims pass through."""
        if len(x.shape) <= 2:
            return x
        return torch.flatten(x, start_dim=1, end_dim=-1)

    def forward(self, x, adj, tra1, tra2, tra3, z, sigma=0.5):
        """Propagate ``x`` through the five layers, mixing in encoder features.

        The input to each of ``gnn_2`` .. ``gnn_5`` is
        ``(1 - sigma) * h + sigma * f``, where ``h`` is the previous layer's
        output and ``f`` is, in order, ``tra1``, ``tra2``, ``tra3`` and ``z``.

        :param x: image batch; flattened to 2-D if it has more than two dims
        :param adj: adjacency matrix from the constructed graph for the batch
            of images; passed unchanged to every layer
        :param tra1: features from the first layer of the encoder
            (flattened like ``x``)
        :param tra2: features from the second layer of the encoder
            (flattened like ``x``)
        :param tra3: features from the third layer of the encoder
            (flattened like ``x``)
        :param z: latent features from the encoder (used as-is, not flattened)
        :param sigma: weight of the encoder features in each blend;
            ``1 - sigma`` weights the GNN hidden state
        :return: output of ``gnn_5`` (called with ``activation=False``), the
            hidden layer that is used for clustering
        """
        x = self._flatten_if_needed(x)
        encoder_feats = [self._flatten_if_needed(t) for t in (tra1, tra2, tra3)]

        hidden = self.gnn_1(x, adj)
        middle_layers = (self.gnn_2, self.gnn_3, self.gnn_4)
        for layer, feat in zip(middle_layers, encoder_feats):
            hidden = layer((1 - sigma) * hidden + sigma * feat, adj)
        # The final layer is invoked without its activation.
        return self.gnn_5((1 - sigma) * hidden + sigma * z, adj, activation=False)